"""
aeolus2fms.py

Interface between FMS and Aeolus2. It is started either by FMS via
MPI_Comm_spawn() or together with FMS in an MPMD launch.
This file contains a main program for coupled runs, and functions to
handle the communication with FMS. Data from incoming MPI messages
is unpacked and passed as numpy arrays to the functions in aeolus2.py,
which do the real work. Their return values are
packed into MPI messages and sent back to FMS.
"""

import sys
import os
import time
start_time = time.time()
#import argparse
#import pathlib
#import math                     as ma
#import mpmath                   as mp

import numpy                    as np
#import scipy.integrate          as integrate #Gives access to the ODE integration package
#import scipy.special            as ss
#import scipy.io                 as sio
#from scipy.sparse import linalg as spla

from mpi4py import MPI
#from multiprocessing import Pool

#import dedalus.public           as de
## Load config options
#from dedalus.tools.config import config

#import sphere_wrapper           as sph
#from NetCDFOutput import NetCDFOutput


# aeolus2.py implements the real work.
from aeolus2 import \
    aeolus2_init_grid, aeolus2_init, \
    aeolus2_restart, aeolus2_finish, \
    aeolus2_get_stock_pe, \
    aeolus2_get_bottom_mass, aeolus2_get_bottom_wind, \
    aeolus2_update_down, aeolus2_update_up


epsilon = 1e-10


def double_eq(v1, v2, e):
    """
    Return whether v1 and v2 agree to within the relative tolerance e.

    The absolute difference is compared with e times the magnitude of the
    mean of both values. Multiplying instead of dividing avoids dividing by
    zero when the mean is 0. Exactly equal values, including 0 and 0, always
    compare equal. Works element-wise on numpy arrays.
    """
    return np.logical_or(v1 == v2,
                         np.fabs(v1 - v2) < e * np.fabs((v1 + v2) / 2))

# ---- MPI communicators; filled in by aeolus2fms_interface_check() ----
# comm_dedalus: communicator of all Dedalus/python tasks.
#   Spawn case: MPI.COMM_WORLD; MPMD case: created via a group range.
# rank, size:   this task's rank in comm_dedalus, and its size.
comm_dedalus = rank = size = None

# comm_fms: communicator reaching all FMS tasks.
#   Spawn case: parent communicator created by Spawn; MPMD case: MPI.COMM_WORLD.
# peerrank: rank of this task's FMS partner within comm_fms.
#   Spawn case: same as rank; MPMD case: the same number as the rank in
#   comm_dedalus, but interpreted as a rank in MPI.COMM_WORLD.
comm_fms = peerrank = None


# ---- State shared between the functions below ----
# TODO: check which of these variables are really needed globally

# Borders of the compute domain in the global grid. Following python
# conventions, the start indices are 0-based and the end indices point one
# _beyond_ the actual end. (The start index cannot be called "is" --
# that is a reserved word in python.)
ics, jcs = 0, 0
ice = jce = None
# Extent of the compute domain (ice-ics, jce-jcs).
ni = nj = None

# Start time of the experiment.
Time_init_year = Time_init_month = Time_init_day = Time_init_seconds = None
# Start time of this run within the experiment. These are 1-element arrays
# so that they can be used directly as MPI receive buffers; they must stay
# two distinct objects.
Time_days = np.zeros(1, dtype=np.int32)
Time_seconds = np.zeros(1, dtype=np.int32)

# Model time step.
Time_step_days = Time_step_seconds = None

# Flags, overwritten by the parameters received in aeolus2fms_init().
without_topo = False
update_land_frac_always = False

nr_tracers = None  # number of tracers to exchange with the coupler
q_ind = None       # index of the humidity tracer
co2_ind = None     # index of the CO2 tracer


def _aeolus2fms_connect():
    """
    Set up the communicators to FMS and among the python tasks.

    Sets the module globals comm_fms, comm_dedalus, rank, size and peerrank,
    and returns the number of FMS tasks this task exchanges pids with.

    Two start-up modes are recognised:
    - MPMD: the environment variable MPMD_FMS (> 0) gives the number of FMS
      tasks, which hold the first ranks of MPI.COMM_WORLD; the python tasks
      hold the remaining ranks. MPMD_OTHER, if set, gives the number of
      python tasks and is checked against the actual number.
    - Spawn: MPI.Comm.Get_parent() yields the parent intercommunicator
      created by MPI_Comm_spawn() on the FMS side.
    In both modes python task `rank` talks to FMS task `peerrank`, which is
    the same number, so both sides must have the same number of tasks.

    Raises RuntimeError if neither mode applies or the task counts differ.
    """
    global comm_fms, peerrank, comm_dedalus, rank, size

    print("calling getenv(MPMD_FMS)", flush=True)
    mpmd_fms_npes = int(os.getenv("MPMD_FMS", default="0"))
    print("getenv(MPMD_FMS) yields", mpmd_fms_npes, flush=True)

    if mpmd_fms_npes > 0:
        # started via MPMD configuration
        print("aeolus2fms.py started in MPMD mode", flush=True)
        comm_fms = MPI.COMM_WORLD
        mpmd_dedalus_npes = int(os.getenv("MPMD_OTHER", default="0"))
        print("found MPMD_FMS=", mpmd_fms_npes, " MPMD_OTHER=", mpmd_dedalus_npes, flush=True)
        worldsize = comm_fms.size
        if mpmd_fms_npes >= worldsize:
            msg = ("MPMD_FMS=%d leaves no python tasks in MPI.COMM_WORLD of size %d"
                   % (mpmd_fms_npes, worldsize))
            print("Error:", msg, flush=True)
            raise RuntimeError(msg)
        group_world = comm_fms.Get_group()
        group_dedalus = group_world.Range_incl([(mpmd_fms_npes, worldsize - 1, 1)])
        print("created group_dedalus", flush=True)
        # Collective over MPI.COMM_WORLD: the FMS tasks must take part, too.
        comm_dedalus = comm_fms.Create(group_dedalus)
        group_dedalus.Free()
        group_world.Free()
        print("created comm_dedalus", flush=True)
        rank = comm_dedalus.rank
        size = comm_dedalus.size
        # The FMS tasks hold world ranks 0..mpmd_fms_npes-1, so the partner of
        # python task `rank` is world rank `rank` (a rank in comm_fms!).
        peerrank = rank
        # The remote side is FMS, so its size is the number of FMS tasks.
        remotesize = mpmd_fms_npes
        print("created dedalus-communicator, size ", size, " my rank ", rank, flush=True)
        print("comm_fms: size ", comm_fms.size, " my rank ", comm_fms.rank,
              " peer rank ", peerrank, flush=True)
        if mpmd_dedalus_npes > 0 and mpmd_dedalus_npes != size:
            msg = ("MPMD_OTHER=%d does not match the number of python tasks %d"
                   % (mpmd_dedalus_npes, size))
            print("Error:", msg, flush=True)
            raise RuntimeError(msg)
        if size != remotesize:
            msg = ("number of python tasks (%d) must be equal to number of FMS tasks (%d)"
                   % (size, remotesize))
            print("Error:", msg, flush=True)
            raise RuntimeError(msg)
        return remotesize

    # Spawn mode: we need a valid parent communicator.
    if comm_fms is None:
        comm_fms = MPI.Comm.Get_parent()
    # Compare handles with ==: the object returned by Get_parent() need not be
    # the MPI.COMM_NULL object itself, so an identity test is unreliable.
    if comm_fms == MPI.COMM_NULL:
        print("aeolus2fms.py: Error: no parent communicator, and no MPMD environment settings")
        print("aeolus2fms.py: but dont know how to handle standalone runs yet", flush=True)
        raise RuntimeError("aeolus2fms.py: neither spawned by FMS nor started in MPMD mode")

    print("aeolus2fms.py started via MPI_Comm_spawn()")
    fmssize = comm_fms.size
    peerrank = comm_fms.rank
    remotesize = comm_fms.Get_remote_size()
    comm_dedalus = MPI.COMM_WORLD
    rank = comm_dedalus.rank
    size = comm_dedalus.size
    print("aeolus2fms.py: task ", rank, " is task ", comm_fms.rank,
          " in parent of size ", fmssize, " remote size ", remotesize, flush=True)
    if fmssize != remotesize:
        print("Error: local size of parent communicator must be equal to remote size of parent communicator",
              fmssize, remotesize, flush=True)
        raise RuntimeError("local size %d != remote size %d of parent communicator"
                           % (fmssize, remotesize))
    if size != remotesize:
        print("Error: size of local communicator must be equal to remote size of parent communicator",
              size, remotesize, flush=True)
        raise RuntimeError("local communicator size %d != remote size %d of parent communicator"
                           % (size, remotesize))
    return remotesize


def aeolus2fms_interface_check():
    """
    Connect to FMS and check that messages between both sides arrive intact.

    After setting up the communicators (see _aeolus2fms_connect()), this task
      - sends its pid (tag 42) to each FMS task and receives one pid (tag 43)
        per FMS task,
      - receives from its peer FMS task the int 0xeadbeef (tag 44), the
        double pi (tag 45) and a 2x3 double array in Fortran order (tag 46),
        and prints an error message for every value that does not match,
      - sends back the double 8.15 (tag 47) and the 2x3 array (tag 48).
    Mismatching test values are reported but are not fatal.

    Raises RuntimeError if the communicators cannot be set up.
    """
    print("pid", os.getpid(), "commandline", sys.executable, sys.argv)
    mypid = np.array(os.getpid(), dtype='i')

    remotesize = _aeolus2fms_connect()
    print("aeolus2fms.py pid ", mypid, " MPI rank ", rank, " of ", size, flush=True)

    for i in range(remotesize):
        print("sending mypid to task ", i, flush=True)
        comm_fms.Send([mypid, 1, MPI.INT], dest=i, tag=42)

    remotepid = np.array(42, dtype='i')
    for i in range(remotesize):
        print("receiving remotepid nr ", i, flush=True)
        comm_fms.Recv([remotepid, 1, MPI.INT], source=MPI.ANY_SOURCE, tag=43)
        print("aeolus2fms.py task ", rank, " received pid ", remotepid, flush=True)

    sys.stdout.flush()
    sys.stderr.flush()

    deadbeef = np.empty(1, dtype='i')
    print('aeolus2fms_interface_check receiving deadbeef', flush=True)
    comm_fms.Recv([deadbeef, 1, MPI.INT], source=peerrank, tag=44)
    # 0xdeadbeef would not fit into a signed 32-bit int.
    if deadbeef[0] != 0xeadbeef:
        print("Error: test1 failed: 0xeadbeef != ", deadbeef, flush=True)

    pitest = np.empty(1, dtype=np.float64)
    print('aeolus2fms_interface_check receiving pitest', flush=True)
    comm_fms.Recv([pitest, 1, MPI.DOUBLE], source=peerrank, tag=45)
    if pitest[0] != np.pi:
        print("Error: test2 failed: %(pitest)30.28f != %(M_PI)30.28f"
              % {"pitest": pitest[0], "M_PI": np.pi}, flush=True)

    # The 2x3 array arrives flat in Fortran (column-major) order.
    # Fortran indices start at 1, python indices at 0, so i+1 and j+1 below
    # reproduce the values Fortran computed from its own indices.
    testarray1d = np.empty(2*3, dtype=np.float64)
    print('aeolus2fms_interface_check receiving testarray2d', flush=True)
    comm_fms.Recv([testarray1d, testarray1d.size, MPI.DOUBLE], source=peerrank, tag=46)
    testarray2d = np.reshape(testarray1d, (2, 3), order='F')
    for j in range(3):
        for i in range(2):
            testval = testarray2d[i, j]
            myval = (i+1) + (j+1)/10.0
            if not double_eq(testval, myval, epsilon):
                print("Error: test5 failed: testarray2d(%(i)d,%(j)d) is %(testval)10.9g != %(myval)10.9g"
                      % {"i": i, "j": j, "testval": testval, "myval": myval}, flush=True)

    print('aeolus2fms_interface_check sending back one double', flush=True)
    pitest[0] = 8.15
    comm_fms.Send([pitest, 1, MPI.DOUBLE], dest=peerrank, tag=47)
    print('aeolus2fms_interface_check sending back testarray2d', flush=True)
    comm_fms.Send([testarray2d, testarray2d.size, MPI.DOUBLE], dest=peerrank, tag=48)
    print("aeolus2fms_interface_check finished", flush=True)
# end of aeolus2fms_interface_check()


def aeolus2fms_init_grid():
    """
    Set up the Aeolus2 grid and check that its domain matches the FMS domain.

    Receives 7 ints (tag 50) from the peer FMS task:
        nlons_global, nlats_global, lon_0_360, ics, ice, jcs, jce
    The four domain indices come from Fortran: 1-based, with inclusive end.
    The first three values are passed to aeolus2_init_grid(). It returns
    its own domain boundaries (0-based start, end one beyond the last
    element), which must describe the same domain.
    Afterwards the global axis arrays lons_global, lonb_global, lats_global
    and latb_global are sent back as doubles (tags 51-54).

    Sets the module globals ics, ice, jcs, jce, ni and nj.
    Raises RuntimeError if the domain boundaries differ.
    """
    global ics, ice, jcs, jce, ni, nj
    gridparams = np.empty(7, dtype=np.int32)
    # gridparams[0] = nlons_global
    # gridparams[1] = nlats_global
    # gridparams[2] = lon_0_360
    # gridparams[3] = ics   (Fortran, 1-based)
    # gridparams[4] = ice   (Fortran, inclusive)
    # gridparams[5] = jcs   (Fortran, 1-based)
    # gridparams[6] = jce   (Fortran, inclusive)

    # Receive from peerrank, the same source the main loop probed.
    comm_fms.Recv([gridparams, gridparams.size, MPI.INT], source=peerrank, tag=50)

    ics, ice, jcs, jce, lons_global, lonb_global, lats_global, latb_global = \
        aeolus2_init_grid(gridparams[0], gridparams[1], gridparams[2], comm_dedalus)

    print("received FMS domain boundary indices",
          gridparams[3:7],
          " my own are ",
          ics, ice, jcs, jce, flush=True)
    # F90/FMS counts indices from 1 while python counts from 0, so the start
    # indices differ by one. The end indices are the same, because F90/FMS
    # points at the last element while python points one beyond it.
    # The barrier tries to ensure that the message above is printed by all
    # python tasks before any of them fails the check below.
    comm_dedalus.Barrier()
    if (ics != gridparams[3]-1 or
        ice != gridparams[4]   or
        jcs != gridparams[5]-1 or
        jce != gridparams[6]):
        print("Error: FMS domain boundary indices dont match Aeolus2 boundaries",
              gridparams[3:7], ics, ice, jcs, jce, flush=True)
        raise RuntimeError("FMS domain boundary indices %s dont match Aeolus2 boundaries %s"
                           % (list(gridparams[3:7]), [ics, ice, jcs, jce]))
    ni = ice - ics
    nj = jce - jcs

    print("aeolus2fms_init_grid checked grid parameters", flush=True)

    print("aeolus2fms_init_grid sending back _global_ axis info", flush=True)
    axes = (("lons_global", lons_global, 51),
            ("lonb_global", lonb_global, 52),
            ("lats_global", lats_global, 53),
            ("latb_global", latb_global, 54))
    for name, axis, tag in axes:
        print("aeolus2fms_init_grid sending " + name, flush=True)
        comm_fms.Send([axis, axis.size, MPI.DOUBLE], dest=peerrank, tag=tag)
    print("aeolus2fms_init_grid sent back axis info", flush=True)

    print("aeolus2fms_init_grid finished")
# end of aeolus2fms_init_grid()


def aeolus2fms_init():
    """
    Final steps of initialisation: receive run parameters and static fields.

    Receives from the peer FMS task
      tag 60:     14 ints -- Time_init_year, Time_init_month, Time_init_day,
                  Time_init_seconds, Time_days, Time_seconds,
                  Time_step_days, Time_step_seconds, without_topo,
                  update_land_frac_always, nr_tracers, q_ind, co2_ind,
                  followed by the check value 42;
      tags 61/62: HORO and SIGORO, only if without_topo is zero (otherwise
                  both are set to zero here);
      tag 63:     land_frac;
      tag 64:     area.
    Each field arrives as a flat Fortran-ordered (ni, nj) array of doubles.
    The parameters are stored in the module globals, and everything is
    passed on to aeolus2_init().

    Raises RuntimeError if the check value is not 42, i.e. if the parameter
    list does not have the expected length.
    """
    global Time_init_year, Time_init_month, Time_init_day, Time_init_seconds
    global Time_days, Time_seconds, Time_step_days, Time_step_seconds
    global without_topo, update_land_frac_always
    global nr_tracers, q_ind, co2_ind

    def recv_field(tag):
        """Receive one flat Fortran-ordered (ni, nj) field of doubles."""
        flat = np.zeros(ni*nj, dtype=np.float64)
        comm_fms.Recv([flat, flat.size, MPI.DOUBLE], source=peerrank, tag=tag)
        return np.reshape(flat, (ni, nj), order='F')

    aeolus2params = np.empty(14, dtype=np.int32)
    # Receive from peerrank, the same source the main loop probed.
    comm_fms.Recv([aeolus2params, aeolus2params.size, MPI.INT], source=peerrank, tag=60)
    print("aeolus2fms_init received aeolus2params", flush=True)
    # Validate the trailing check value before any global is modified.
    if aeolus2params[14-1] != 42:
        msg = "aeolus2fms_init received parameter list of wrong length"
        print(msg, flush=True)
        raise RuntimeError(msg)
    Time_init_year          = aeolus2params[1-1]
    Time_init_month         = aeolus2params[2-1]
    Time_init_day           = aeolus2params[3-1]
    Time_init_seconds       = aeolus2params[4-1]
    # Time_days/Time_seconds are global 1-element arrays: update in place.
    Time_days[0]            = aeolus2params[5-1]
    Time_seconds[0]         = aeolus2params[6-1]
    Time_step_days          = aeolus2params[7-1]
    Time_step_seconds       = aeolus2params[8-1]
    without_topo            = aeolus2params[9-1]
    update_land_frac_always = aeolus2params[10-1]
    nr_tracers              = aeolus2params[11-1]
    q_ind                   = aeolus2params[12-1]
    co2_ind                 = aeolus2params[13-1]

    if without_topo:
        # FMS sends no topography in this case; use zero fields.
        HORO = np.zeros((ni, nj), dtype=np.float64, order='F')
        SIGORO = np.zeros((ni, nj), dtype=np.float64, order='F')
    else:
        HORO = recv_field(61)
        SIGORO = recv_field(62)
        print("aeolus2fms_init received HORO and SIGORO", flush=True)
    # TODO: check if HORO needs to be smoothed before passed on
    land_frac = recv_field(63)
    area = recv_field(64)
    print("aeolus2fms_init received land_frac and area", flush=True)

    aeolus2_init((Time_init_year, Time_init_month, Time_init_day, Time_init_seconds,
                  Time_days, Time_seconds,
                  Time_step_days, Time_step_seconds,
                  without_topo, update_land_frac_always,
                  nr_tracers, q_ind, co2_ind,
                  HORO, SIGORO, land_frac, area))

    print("aeolus2fms_init finished", flush=True)
    # TODO check if we need to send a dummy response for synchronisation
    # (tag=65). Without it the F90 part continues before we are finished
    # here. Does that harm, or could it speed things up?
# end of aeolus2fms_init()


def aeolus2fms_update_up():
    """
    Handle an "update up" request from FMS (message tags 100ff).

    Receives from the peer FMS task the current year and month (tags 100,
    101), seconds and days (tags 102, 103; stored in the module-level
    Time_seconds and Time_days arrays), and six flat Fortran-ordered
    (ni, nj) fields of doubles (tags 104-109). These are passed to
    aeolus2_update_up() in aeolus2.py, and its three results lprec, fprec
    and gust are sent back as doubles (tags 150-152).
    """
    global Time_seconds, Time_days

    print("aeolus2fms_update_up", flush=True)

    def recv_field(tag):
        # MPI Recv() needs a preallocated buffer. Allocating one per call may
        # be slow, but keeping them all globally would cost memory.
        flat = np.zeros(ni*nj, dtype=np.float64)
        comm_fms.Recv([flat, flat.size, MPI.DOUBLE], source=peerrank, tag=tag)
        return np.reshape(flat, (ni, nj), order='F')

    year = np.zeros(1, dtype=np.int32)
    month = np.zeros(1, dtype=np.int32)
    time_buffers = ((year, 100), (month, 101), (Time_seconds, 102), (Time_days, 103))
    for buffer, tag in time_buffers:
        comm_fms.Recv([buffer, 1, MPI.INT], source=peerrank, tag=tag)
    print("  Time_year, Time_month, Time_seconds, Time_days",
          year, month, Time_seconds, Time_days, flush=True)

    (land_frac,
     surf_diff_delta_t, surf_diff_delta_q,
     u_star, b_star, q_star) = [recv_field(tag) for tag in range(104, 110)]

    # Hand over to the real implementation in aeolus2.py.
    lprec, fprec, gust = aeolus2_update_up((
        year[0], month[0], Time_seconds[0], Time_days[0],
        land_frac,
        surf_diff_delta_t, surf_diff_delta_q,
        u_star, b_star, q_star))

    print('aeolus2fms_update_up() Send return values')
    for tag, result in ((150, lprec), (151, fprec), (152, gust)):
        comm_fms.Send([result, result.size, MPI.DOUBLE], dest=peerrank, tag=tag)

    print("aeolus2fms_update_up finished", flush=True)
# end of aeolus2fms_update_up()


def aeolus2fms_update_down():
    """
    Handle an "update down" request from FMS (message tags 200ff).

    Receives from the peer FMS task the time in seconds and days and the day
    of the year (tags 200-202), 16 flat Fortran-ordered (ni, nj) fields of
    doubles (tags 203-218), the 1-D array solar of length nj (tag 219) and
    flux_lw_sf_up (tag 220). These are passed to aeolus2_update_down() in
    aeolus2.py, and its 21 results are sent back as doubles (tags 250-270).
    The time values are kept in local buffers; the module-level
    Time_seconds/Time_days are not modified here.
    """
    print("aeolus2fms_update_down", flush=True)

    def recv_doubles(count, tag):
        flat = np.zeros(count, dtype=np.float64)
        comm_fms.Recv([flat, flat.size, MPI.DOUBLE], source=peerrank, tag=tag)
        return flat

    def recv_field(tag):
        return np.reshape(recv_doubles(ni*nj, tag), (ni, nj), order='F')

    seconds = np.zeros(1, dtype=np.int32)
    days = np.zeros(1, dtype=np.int32)
    dayoftheyear = np.zeros(1, dtype=np.int32)
    for buffer, tag in ((seconds, 200), (days, 201), (dayoftheyear, 202)):
        comm_fms.Recv([buffer, 1, MPI.INT], source=peerrank, tag=tag)
    print("  Time_seconds, Time_days, dayoftheyear",
          seconds, days, dayoftheyear, flush=True)

    (frac_land, t_surf, albedo,
     albedo_vis_dir, albedo_nir_dir, albedo_vis_dif, albedo_nir_dif,
     rough_mom,
     u_star, b_star, q_star, dtau_du, dtau_dv,
     u_flux, v_flux,
     coszen) = [recv_field(tag) for tag in range(203, 219)]
    # solar is 1-D (length nj) !
    solar = recv_doubles(nj, 219)
    flux_lw_sf_up = recv_field(220)

    # Hand over to the real implementation in aeolus2.py. The argument tuple
    # is built before u_flux/v_flux are rebound to the returned values.
    (u_flux, v_flux,
     gust,
     flux_sw,
     flux_sw_dir, flux_sw_dif,
     flux_sw_down_vis_dir, flux_sw_down_vis_dif,
     flux_sw_down_total_dir, flux_sw_down_total_dif,
     flux_sw_vis, flux_sw_vis_dir, flux_sw_vis_dif,
     flux_lw,
     surf_diff_delta_t, surf_diff_dflux_t,
     surf_diff_dtmass,
     surf_diff_delta_u, surf_diff_delta_v,
     surf_diff_delta_co2, surf_diff_dflux_co2) = aeolus2_update_down((
        seconds[0], days[0], dayoftheyear[0],
        frac_land, t_surf, albedo,
        albedo_vis_dir, albedo_nir_dir, albedo_vis_dif, albedo_nir_dif,
        rough_mom,
        u_star, b_star, q_star, dtau_du, dtau_dv,
        u_flux, v_flux,
        coszen, solar,
        flux_lw_sf_up))

    print('aeolus2fms_update_down() Send return values')
    outgoing = (u_flux, v_flux,
                gust,
                flux_sw,
                flux_sw_dir, flux_sw_dif,
                flux_sw_down_vis_dir, flux_sw_down_vis_dif,
                flux_sw_down_total_dir, flux_sw_down_total_dif,
                flux_sw_vis, flux_sw_vis_dir, flux_sw_vis_dif,
                flux_lw,
                surf_diff_delta_t, surf_diff_dflux_t,
                surf_diff_dtmass,
                surf_diff_delta_u, surf_diff_delta_v,
                surf_diff_delta_co2, surf_diff_dflux_co2)
    for tag, result in enumerate(outgoing, start=250):
        comm_fms.Send([result, result.size, MPI.DOUBLE], dest=peerrank, tag=tag)

    print("aeolus2fms_update_down finished", flush=True)
# end of aeolus2fms_update_down()


def aeolus2fms_get_stock_pe():
    """
    Retrieve the local amount of water/heat/salt requested by FMS.

    Receives one int (tag 70) from the peer FMS task and passes it, as a
    1-element int32 array, to aeolus2_get_stock_pe(). The result is sent
    back to the peer as one double (tag 70).
    """
    idx = np.empty(1, dtype=np.int32)
    comm_fms.Recv([idx, 1, MPI.INT], source=peerrank, tag=70)

    # invoke the real implementation in aeolus2.py
    val = aeolus2_get_stock_pe(idx)

    # The buffer is sent as MPI.DOUBLE, so convert explicitly. Otherwise an
    # integer or float32 result would be sent with its own bit pattern and
    # misread on the Fortran side.
    result = np.asarray(val, dtype=np.float64)
    comm_fms.Send([result, 1, MPI.DOUBLE], dest=peerrank, tag=70)
# end of aeolus2fms_get_stock_pe()


def aeolus2fms_finish():
    """
    Clean up at the end of the run.

    Receives 2 ints (tag 72) from the peer FMS task, the time in seconds
    and in days. Stores them in the module-level Time_seconds and Time_days
    arrays (in place) and passes those arrays to aeolus2_finish().
    """
    aeolus2params = np.empty(2, dtype=np.int32)
    # Receive from peerrank, the same source the main loop probed.
    comm_fms.Recv([aeolus2params, aeolus2params.size, MPI.INT], source=peerrank, tag=72)
    print("aeolus2fms_finish received aeolus2params", flush=True)
    Time_seconds[0] = aeolus2params[1-1]
    Time_days[0]    = aeolus2params[2-1]
    aeolus2_finish(Time_seconds, Time_days)
# end of aeolus2fms_finish()


def aeolus2fms_restart():
    """
    Write a restart file.

    Receives 2 ints (tag 73) from the peer FMS task, the time in seconds
    and in days. Stores them in the module-level Time_seconds and Time_days
    arrays (in place) and passes those arrays to aeolus2_restart().
    """
    aeolus2params = np.empty(2, dtype=np.int32)
    # Receive from peerrank, the same source the main loop probed.
    comm_fms.Recv([aeolus2params, aeolus2params.size, MPI.INT], source=peerrank, tag=73)
    print("aeolus2fms_restart received aeolus2params", flush=True)
    Time_seconds[0] = aeolus2params[1-1]
    Time_days[0]    = aeolus2params[2-1]
    aeolus2_restart(Time_seconds, Time_days)
# end of aeolus2fms_restart()


def aeolus2fms_get_bottom_mass():
    """
    Send the quantities at the bottom of the atmosphere to FMS.

    Waits for a 1-int request (tag 80) from the peer FMS task, obtains the
    fields from aeolus2_get_bottom_mass() and sends t_bot, q_bot, p_bot,
    z_bot, p_surf and slp as doubles (tags 81-86). co2_bot (tag 87) is only
    sent when co2_ind >= 0, i.e. when CO2 is configured as a tracer at the
    coupler interface.
    """
    request = np.empty(1, dtype=np.int32)
    comm_fms.Recv([request, 1, MPI.INT], source=peerrank, tag=80)

    # Hand over to the real implementation in aeolus2.py.
    t_bot, q_bot, p_bot, z_bot, p_surf, slp, co2_bot = aeolus2_get_bottom_mass()

    for tag, field in enumerate((t_bot, q_bot, p_bot, z_bot, p_surf, slp), start=81):
        comm_fms.Send([field, field.size, MPI.DOUBLE], dest=peerrank, tag=tag)
    if co2_ind >= 0:
        comm_fms.Send([co2_bot, co2_bot.size, MPI.DOUBLE], dest=peerrank, tag=87)
# end of aeolus2fms_get_bottom_mass


def aeolus2fms_get_bottom_wind():
    """
    Send the wind at the bottom of the atmosphere to FMS.

    Waits for a 1-int request (tag 90) from the peer FMS task, then sends the
    two arrays returned by aeolus2_get_bottom_wind(), u_bot and v_bot, as
    doubles (tags 91 and 92).
    """
    request = np.empty(1, dtype=np.int32)
    comm_fms.Recv([request, 1, MPI.INT], source=peerrank, tag=90)

    # Hand over to the real implementation in aeolus2.py.
    u_bot, v_bot = aeolus2_get_bottom_wind()

    for tag, component in ((91, u_bot), (92, v_bot)):
        comm_fms.Send([component, component.size, MPI.DOUBLE], dest=peerrank, tag=tag)
# end of aeolus2fms_get_bottom_wind


##
## main()
##
print("pid",os.getpid(),"commandline",sys.executable,sys.argv)
sys.stdout.flush()

# Handlers for the requests FMS can send, keyed by the tag of the first
# message of each request. The main loop only probes for that message;
# the handler receives it itself.
_message_handlers = {
    50:  aeolus2fms_init_grid,
    60:  aeolus2fms_init,
    70:  aeolus2fms_get_stock_pe,
    72:  aeolus2fms_finish,
    73:  aeolus2fms_restart,
    80:  aeolus2fms_get_bottom_mass,
    90:  aeolus2fms_get_bottom_wind,
    100: aeolus2fms_update_up,
    200: aeolus2fms_update_down,
}
_FINISH_TAG = 72

try:
    aeolus2fms_interface_check()
    Status = MPI.Status()
    # yes, an endless loop. It is left after a "finish" request was handled.
    while True:
        # Blocking wait for the next incoming message from the peer FMS task.
        # Its tag tells which kind of request it starts.
        comm_fms.Probe(source=peerrank, tag=MPI.ANY_TAG, status=Status)
        tag = Status.Get_tag()
        print('aeolus2fms received tag ',tag)
        handler = _message_handlers.get(tag)
        if handler is None:
            print('Error: aeolus2fms.py does not know how to handle message tag', tag)
            raise RuntimeError("aeolus2fms.py: unknown message tag %d" % tag)
        handler()
        if tag == _FINISH_TAG:
            break
except Exception:
    # An uncaught exception would end only this task. The FMS tasks and the
    # other python tasks would then block forever in their MPI calls, so
    # report the error and abort the MPI job instead.
    import traceback
    traceback.print_exc()
    sys.stdout.flush()
    sys.stderr.flush()
    MPI.COMM_WORLD.Abort(1)

print("aeolus2fms.py finished", flush=True)
MPI.Finalize()
