#ifdef USE_PIK_ML
module pik_ml_fortran_mod
!
! Fortran side of the PIK_ML interface.
!
! The machine learning (ML) code is invoked only on the root task.
! Therefore the distributed data from all parallel domains must be
! collected onto the root task before it is sent to the ML code, and on
! the way back the ML-improved data must be distributed from the root
! task to all other tasks.
!
! The ML entry points (pik_ml_interfacecheck, pik_ml_init,
! pik_ml_precipitation) are external procedures not defined in this
! module (the comments below describe them as C++ functions). They are
! called through implicit interfaces, so the compiler does NOT check
! their argument count or types; the well-known test values passed from
! pik_ml_fortran_init exist to catch such mismatches at run time.

  use  time_manager_mod, only: time_type

  use fms_mod,         only: open_namelist_file, check_nml_error
  use mpp_mod,         only: mpp_npes, mpp_pe, mpp_root_pe, &
       mpp_error, stderr, stdout, stdlog, FATAL, NOTE, mpp_set_current_pelist, &
       mpp_clock_id, mpp_clock_begin, mpp_clock_end, mpp_sum, &
       CLOCK_COMPONENT, CLOCK_SUBCOMPONENT, CLOCK_ROUTINE, CLOCK_LOOP, lowercase, &
       input_nml_file
  use mpp_mod,         only: mpp_get_current_pelist_name
  use mpp_mod,         only: mpp_sync
  use mpp_mod,         only: mpp_broadcast
  use mpp_domains_mod, only: mpp_get_compute_domain, mpp_get_global_domain
  use mpp_domains_mod, only: mpp_global_field
  use mpp_io_mod,      only: mpp_close

  use diag_manager_mod, only: register_diag_field, send_data

  use atmos_model_mod, only: atmos_data_type

  implicit none
  private

  public :: pik_ml_fortran_init, pik_ml_fortran_precip

  character(len=48), parameter :: module_name = 'pik_ml_fortran_mod'

  ! Timing clock around the gather / ML call / scatter in pik_ml_fortran_precip.
  integer :: pik_ml_clock
  ! Diagnostic field ids; they stay -1 unless registered in pik_ml_fortran_init.
  integer :: id_lprec_in = -1, id_fprec_in = -1, id_lprec_out = -1, id_fprec_out = -1
  integer :: atm_isg, atm_ieg, atm_jsg, atm_jeg ! Atm global domain boundary indices
  integer :: atm_nxg, atm_nyg ! Atm global domain sizes
  integer :: is, ie, js, je   ! Atm compute domain of this PE (global indices)

  ! namelist interface.
  ! A run time parameter to turn on or off the actual invocation of
  ! the neural network model. default: .true.
  logical :: invoke_neural_model = .true.
  namelist /pik_ml_nml/ invoke_neural_model

contains

  !> Initialise the Fortran side of the PIK_ML interface.
  !!
  !! Reads pik_ml_nml, caches the global and compute domain bounds of
  !! Atm%domain, and (if invoke_neural_model) verifies the Fortran<->C++
  !! calling convention and calls pik_ml_init on the root PE, then
  !! registers the lprec/fprec input and output diagnostic fields.
  !!
  !! Time     - model time used when registering diagnostic fields.
  !! Atm      - atmosphere data; its domain and axes are read.
  !! dt_atmos - atmospheric time step forwarded to pik_ml_init. Although
  !!            declared optional for interface compatibility, it is
  !!            required when invoke_neural_model is .true.; absence is a
  !!            FATAL error in that case.
  subroutine pik_ml_fortran_init(Time, Atm, dt_atmos)
    type(time_type),       intent(in)  :: Time
    type(atmos_data_type), intent(in)  :: Atm
    integer, optional,     intent(in)  :: dt_atmos

    ! Some magic values for testing the interface between Fortran90 and
    ! C++ code: we pass them to the C++ functions and check there that
    ! they were received correctly.
    ! Standard Fortran only allows a BOZ literal to initialise an integer
    ! through INT(). Z'deadbeef' would overflow a 32-bit default integer,
    ! so the leading 'd' is dropped.
    integer, parameter :: deadbeef = int(z'eadbeef')
    integer, parameter :: fse = 4711
    ! Default-kind reals: onetwothree is rounded to default real here;
    ! kind_r is passed along so the receiver can check the real kind.
    real,    parameter :: onetwothree = 1234567809.0987654321D0, &
         pitest = 3.1415926535897932384626433832795029D0 ! from linux <math.h>
    ! Plain (non-parameter) variables so they can be passed by reference;
    ! assigned below rather than at declaration to avoid implicit SAVE.
    logical :: true, false
    integer :: kind_i, kind_r, kind_l
    real, dimension(3,4) :: testarray2d
    integer :: i, j
    integer :: ierr, io, unit

    true   = .true.
    false  = .false.
    kind_i = kind(deadbeef)
    kind_r = kind(pitest)
    kind_l = kind(true)

    ! call write_version_number(version, tagname)
    !----- read namelist -------

#ifdef INTERNAL_FILE_NML
    read (input_nml_file, pik_ml_nml, iostat=io)
    ierr = check_nml_error (io, 'pik_ml_nml')
#else
    unit = open_namelist_file()
    ierr=1;
    do while (ierr /= 0)
       read  (unit, nml=pik_ml_nml, iostat=io, end=10)
       ierr = check_nml_error (io, 'pik_ml_nml')
    enddo
10  call mpp_close(unit)
#endif
    write (stdout(),'(/)')
    write (stdout(), pik_ml_nml)

    pik_ml_clock = mpp_clock_id('pik_ml_precip')

    call mpp_get_global_domain(Atm%domain, atm_isg, atm_ieg, atm_jsg, atm_jeg)
    call mpp_get_compute_domain( Atm%domain, is, ie, js, je )
    atm_nxg = atm_ieg - atm_isg + 1
    atm_nyg = atm_jeg - atm_jsg + 1

    ! First assemble some parameters with well-known values to verify
    ! that the F90<->C++ interface is OK: testarray2d(i,j) = 100*i + j.
    ! Loop order follows Fortran's column-major storage.
    do j = 1, 4
       do i = 1, 3
          testarray2d(i,j) = i*100.0 + j
       end do
    end do
    write(*,*) 'MAXEXPONENT ', MAXEXPONENT(pitest)

    if (.not. invoke_neural_model) return

    ! dt_atmos is forwarded to the external pik_ml_init through an implicit
    ! interface, where passing an absent optional argument is not allowed.
    ! Fail loudly instead of handing the ML side an invalid reference.
    if (.not. present(dt_atmos)) call mpp_error(FATAL, trim(module_name)// &
         ': dt_atmos must be present when invoke_neural_model = .true.')

    if( mpp_pe() == mpp_root_pe() ) then
       write(*,*) 'calling pik_ml_init on PE ', mpp_pe()
       call pik_ml_interfacecheck( &
            deadbeef, pitest,       &
            true, false,            &
            testarray2d,            &
            kind_i, kind_r, kind_l, &
            ! input variables for interface checking
            onetwothree, fse        &
            )
       call pik_ml_init( &
            atm_isg, atm_ieg, atm_jsg, atm_jeg, dt_atmos, &
            ! atmos_comm, &
            mpp_pe(), mpp_root_pe(), mpp_npes(), &
            fse & ! minimal interface check for correct number of args
            )
    end if
    write(*,*) 'mpp_sync() after pik_ml_init() on PE ', mpp_pe()
    ! Keep the other PEs from running ahead while the root initialises the ML code.
    call mpp_sync()

    id_lprec_in = register_diag_field( 'pik_ml', 'lprec_in',       Atm%axes(1:2), Time, &
         'liquid precipitation rate as input to machine learning',        'kg/m2/s'  )
    id_fprec_in = register_diag_field( 'pik_ml', 'fprec_in',       Atm%axes(1:2), Time, &
         'frozen precipitation rate as input to machine learning',        'kg/m2/s'  )
    id_lprec_out = register_diag_field( 'pik_ml', 'lprec_out',       Atm%axes(1:2), Time, &
         'liquid precipitation rate as output from machine learning',        'kg/m2/s'  )
    id_fprec_out = register_diag_field( 'pik_ml', 'fprec_out',       Atm%axes(1:2), Time, &
         'frozen precipitation rate as output from machine learning',        'kg/m2/s'  )

  end subroutine pik_ml_fortran_init

  !> Pass the precipitation fields through the ML code and write the result back.
  !!
  !! On atmosphere PEs (Atm%pe), Atm%lprec and Atm%fprec are gathered into
  !! global-domain arrays and handed to pik_ml_precipitation on the root
  !! PE. The (possibly modified) global arrays are then broadcast over
  !! Atm%pelist, and each PE copies back its compute-domain share
  !! (is:ie, js:je). The fields are sent to the *_in diagnostics before
  !! the ML step and to the *_out diagnostics after it.
  !! Does nothing when invoke_neural_model is .false.
  !!
  !! Time - model time stamp for the diagnostics.
  !! Atm  - atmosphere data; lprec and fprec are overwritten in place.
  subroutine pik_ml_fortran_precip(Time, Atm)
    type(time_type),       intent(in)    :: Time
    type(atmos_data_type), intent(inout) :: Atm

    integer :: npts        ! number of points in one global field
    logical :: used        ! send_data returns a logical status
    ! Global-domain work arrays; allocatable (heap) because the global
    ! grid can be too large for automatic (stack) arrays.
    real, allocatable, dimension(:,:) :: global_lprec, global_fprec

    if (.not. invoke_neural_model) return

    !! To enable strict floating point exceptions inside this module.
    !! Taken from the example in the Intel Fortran Compiler Reference Manual.
    !! Most probably will not work with any other compiler
    !call clearstatusfpqq() ! do not inherit signalled FP exceptions from FMS
    !original_fpe_flags = FOR_SET_FPE(strict_fpe_flags)

    ! here we need only year and month
    !call get_date(Time, Time_year, Time_month, used, used, used, used)
    !call get_time(Time, Time_seconds, Time_days)

    if (id_lprec_in > 0) &
         used = send_data(id_lprec_in, Atm%lprec, Time)
    if (id_fprec_in > 0) &
         used = send_data(id_fprec_in, Atm%fprec, Time)

    if (Atm%pe) then
       call mpp_clock_begin(pik_ml_clock)

       ! Indexed with global indices so that each PE's share is simply (is:ie, js:je).
       allocate(global_lprec(atm_isg:atm_ieg, atm_jsg:atm_jeg))
       allocate(global_fprec(atm_isg:atm_ieg, atm_jsg:atm_jeg))

       call mpp_global_field(Atm%domain, Atm%lprec, global_lprec)
       call mpp_global_field(Atm%domain, Atm%fprec, global_fprec)

       if( mpp_pe() == mpp_root_pe() ) then
          write(*,*) 'calling pik_ml_precipitation on PE ', mpp_pe()
          call pik_ml_precipitation( &
               size(global_lprec,1),  & ! interface check
               size(global_lprec,2),  & ! interface check
               global_lprec,          & ! liquid precipitation
               global_fprec,          & ! frozen precipitation
               42                     & ! interface check
               )
       end if ! mpp_root_pe()
       call mpp_sync()

       ! Now distribute the ML-improved precip data back to the parallel
       ! atmos PEs. For now, simply broadcast the global domain, and then
       ! each task picks its local share.
       !call mpp_broadcast( data, length, from_pe, pelist )
       npts = size(global_lprec)
       call mpp_broadcast(global_lprec, npts, mpp_root_pe(), Atm%pelist)
       Atm%lprec(:,:) = global_lprec(is:ie,js:je)
       call mpp_broadcast(global_fprec, npts, mpp_root_pe(), Atm%pelist)
       Atm%fprec(:,:) = global_fprec(is:ie,js:je)

       deallocate(global_lprec, global_fprec)
       call mpp_clock_end(pik_ml_clock)
    end if ! Atm%pe

    if (id_lprec_out > 0) &
         used = send_data(id_lprec_out, Atm%lprec, Time)
    if (id_fprec_out > 0) &
         used = send_data(id_fprec_out, Atm%fprec, Time)

    !! To enable strict floating point exceptions inside this module.
    !! Taken from the example in the Intel Fortran Compiler Reference Manual.
    !! Most probably will not work with any other compiler
    !original_fpe_flags = FOR_SET_FPE(original_fpe_flags)

  end subroutine pik_ml_fortran_precip

end module pik_ml_fortran_mod
#endif
