"""
test_MPI_Spawn.py

Test the interface between FMS and Aeolus2, to be invoked via
MPI_Comm_spawn.  This file contains a main program for coupled runs,
and functions to handle the communication with FMS.  Data from incoming
MPI messages is unpacked and passed as numpy arrays to functions which
are responsible for doing the real work.  Their return values are packed
into MPI messages and sent back to FMS.

"""

import sys
import os
import time
start_time = time.time()
#import argparse
#import pathlib
#import math                     as ma
#import mpmath                   as mp

import numpy                    as np
#import scipy.integrate          as integrate #Gives access to the ODE integration package
#import scipy.special            as ss
#import scipy.io                 as sio
#from scipy.sparse import linalg as spla

from mpi4py import MPI
#from multiprocessing import Pool

#import dedalus.public           as de
## Load config options
#from dedalus.tools.config import config

#import sphere_wrapper           as sph
#from NetCDFOutput import NetCDFOutput


# aeolus2.py implements the real work.
#from aeolus2 import aeolus2_init_grid, aeolus2_init, \
#    aeolus2_restart, aeolus2_finish, \
#    aeolus2_get_stock_pe, \
#    aeolus2_update_down, aeolus2_update_up


# Default relative tolerance used with double_eq().
epsilon = 1e-10


def double_eq(v1, v2, e):
    """
    Return whether v1 and v2 agree to within the relative tolerance e.

    The difference |v1 - v2| is compared against e times the magnitude
    of the mean of the two values.  The comparison is written without a
    division, so identical values (including two zeros) always compare
    equal, and values of opposite sign with zero mean compare unequal
    instead of triggering a division by zero.

    Works on scalars and, elementwise, on numpy arrays.
    """
    diff = np.fabs(np.subtract(v1, v2))
    scale = np.fabs(np.add(v1, v2)) / 2
    # diff == 0 covers the v1 == v2 == 0 case, where scale is 0 as well.
    return (diff == 0) | (diff < e * scale)

# Local communicator of this process group, this task's rank in it, and
# the number of tasks in it.
comm      = MPI.COMM_WORLD
rank      = comm.rank
size      = comm.size

# Communicator to the parent process group; this program is meant to be
# started via MPI_Comm_spawn (see module docstring).
# NOTE(review): mpi4py is expected to return MPI.COMM_NULL (not None)
# here when the process was not spawned -- confirm against the mpi4py docs.
parentcomm = MPI.Comm.Get_parent()

## Define some variables on global scope.
# All of them are placeholders at import time (None, 0 or False).
# TODO: check which of these variables are really needed globally
L_dealias     = None
L_max         = None # spherical harmonic order
S_max         = None # spin order (leave fixed)
domain        = None
S             = None # Sphere object

# Indices of the borders of the compute domain in the global grid.
# Note: following python conventions, the end indices are one
# _beyond_ the actual end.
ics  = 0 # start index; named "ics" because "is" is a reserved word in python
ice  = None
jcs  = 0
jce  = None

ncout = None     # NetCDFOutput handle
axes_t = None    # time series of scalars
axes_ty = None   # time series of zonal-mean values
axes_yx = None   # static maps
axes_tyx = None  # time series of maps

# start time of experiment
Time_init_year = None
Time_init_month = None
Time_init_day = None
Time_init_seconds = None
# start time of this run, within the experiment
Time_days = None
Time_seconds = None

# length of one time step, split into days and seconds
# (presumably the same days/seconds split as above -- TODO confirm)
Time_step_days = None
Time_step_seconds = None

# configuration switches, both off by default
without_topo = False
update_land_frac_always = False

nr_tracers = None # number of tracers to exchange with coupler
q_ind = None      # index of humidity tracer
co2_ind = None    # index of CO2 tracer


def interface_check():
    """
    Run the start-up self test of the MPI link to the parent side.

    Steps, in order:
      1. Redirect stdout and stderr of this task to a per-rank log file.
      2. Verify that a parent communicator exists and that the local
         size, the parent communicator's local size and its remote size
         all agree.
      3. Send this task's pid to every remote task (tag 42) and receive
         one pid per remote task (tag 43).
      4. After a barrier, receive from the remote task with the same
         rank a known integer (tag 44), a known double (tag 45) and a
         2x3 test array in Fortran order (tag 46), and compare them
         against the expected values.  Mismatches are only reported in
         the log; they do not abort the run.
      5. Send back one double (tag 47) and the test array (tag 48),
         then enter a final barrier.

    Raises:
        RuntimeError: if there is no parent communicator, or if the
            communicator sizes are inconsistent.
    """
    # The log file deliberately stays open and installed as stdout/stderr
    # for the remaining lifetime of the process.
    sys.stdout = open("fms.py.out-"
                      +os.getenv("SLURM_JOBID", default="")+"-"
                      +os.getenv("SLURM_NNODES", default="")+"-"
                      +os.getenv("SLURM_NTASKS", default="")+"-"
                      +"%(rank)02d"%{"rank":rank}, "w")
    sys.stderr = sys.stdout

    print("pid",os.getpid(),"commandline",sys.executable,sys.argv)
    mypid = np.array(os.getpid(), dtype='i')
    print("test_MPI_Spawn.py host/pid ", os.uname()[1], "/", mypid, " MPI rank ", rank, " of ", size, flush=True)

    # mpi4py reports "no parent" as MPI.COMM_NULL, not as None; test for
    # both so that a standalone start is diagnosed here instead of
    # failing later with an MPI error on the null communicator.
    if parentcomm is None or parentcomm == MPI.COMM_NULL:
        print("no parent communicator")
        print("test_MPI_Spawn.py: but dont know how to handle standalone runs yet", flush=True)
        raise RuntimeError("test_MPI_Spawn.py: no parent communicator")
    parentsize = parentcomm.size
    parentrank = parentcomm.rank
    remotesize = parentcomm.Get_remote_size()

    print("test_MPI_Spawn.py: task ", rank, " is task ", parentrank,
          " in parent of size ", parentsize, " remote size ", remotesize)
    if parentsize != remotesize:
        print("Error: local size of parent communicator must be equal to remote size of parent communicator", parentsize, remotesize, flush=True)
        raise RuntimeError("local size of parent communicator (%d) != remote size (%d)"
                           % (parentsize, remotesize))
    if size != remotesize:
        print("Error: size of local communicator must be equal to remote size of parent communicator", size, remotesize, flush=True)
        raise RuntimeError("size of local communicator (%d) != remote size (%d)"
                           % (size, remotesize))

    # Exchange pids with every task on the remote side.
    for i in range(remotesize):
        parentcomm.Send([mypid, 1, MPI.INT], dest=i, tag=42)

    remotepid = np.array(42, dtype='i')
    for i in range(remotesize):
        parentcomm.Recv([remotepid, 1, MPI.INT], source=MPI.ANY_SOURCE, tag=43)
        print("aeolus2fms.py task ", rank, " received pid ", remotepid, flush=True)

    sys.stdout.flush()
    sys.stderr.flush()
    print('interface_check entering first barrier', flush=True)
    parentcomm.Barrier()

    # Test 1: a known integer pattern.
    deadbeef = np.empty(1, dtype='i')
    print('interface_check receiving deadbeef', flush=True)
    parentcomm.Recv([deadbeef, 1, MPI.INT], source=rank, tag=44)
    if deadbeef[0] != 0xeadbeef: # 0xdeadbeef yields overflow error on 32bit machines
        print("Error: test1 failed: 0xeadbeef != ", deadbeef, flush=True)

    # Test 2: a known double, expected to arrive bit-identical.
    pitest = np.empty(1, dtype=np.float64)
    print('interface_check receiving pitest', flush=True)
    parentcomm.Recv([pitest, 1, MPI.DOUBLE], source=rank, tag=45)
    if pitest[0] != np.pi:
        print("Error: test2 failed: %(pitest)30.28f != %(M_PI)30.28f"
              %{"pitest":pitest[0],"M_PI":np.pi}, flush=True)

    # Fortran arrays are index 1:n , numpy arrays are denoted [0:n]
    #    but last element is [n-1]
    # If we use python indices for calculations, we must add 1 to
    # yield the same results as in Fortran!
    testarray1d = np.empty((2*3), dtype=np.float64)
    print('interface_check receiving testarray2d', flush=True)
    parentcomm.Recv([testarray1d, testarray1d.size, MPI.DOUBLE], source=rank, tag=46)
    # Interpret the flat buffer as a 2x3 array in Fortran (column-major) order.
    testarray2d = np.reshape(testarray1d, (2, 3), order='F')
    for j in range(3):
        for i in range(2):
            testval = testarray2d[i,j]
            # Expected value encodes the 1-based indices as i.j
            myval = (i+1)+(j+1)/10.0
            if not double_eq(testval, myval, epsilon):
                print("Error: test5 failed: testarray2d(%(i)d,%(j)d) is %(testval)10.9g != %(myval)10.9g"
                      %{"i":i, "j":j, "testval":testval, "myval":myval}, flush=True)

    print('interface_check sending back one double', flush=True)
    pitest[0] = 8.15
    parentcomm.Send([pitest, 1, MPI.DOUBLE], dest=rank, tag=47)
    parentcomm.Send([testarray2d, testarray2d.size, MPI.DOUBLE], dest=rank, tag=48)
    print('interface_check entering final barrier', flush=True)
    parentcomm.Barrier()
    print("interface_check finished", flush=True)
# end of interface_check




def finish():
    """
    Receive the "finish" message (tag 72) from the parent communicator.

    The message consists of two 32-bit integers.  They are unpacked into
    a seconds and a days value but not used any further, because the
    call to aeolus2_finish is disabled in this test program.
    """
    params = np.empty(2, dtype=np.int32)
    parentcomm.Recv([params, params.size, MPI.INT],
                    source=MPI.ANY_SOURCE, tag=72)
    print("finish received aeolus2params", flush=True)
    # First element: seconds, second element: days.
    t_seconds, t_days = params[0], params[1]
    # Disabled: aeolus2_finish(t_seconds, t_days)
# end of finish


##
## main()
##
print("pid", os.getpid(), "commandline", sys.executable, sys.argv, flush=True)
Status = MPI.Status()

interface_check()

# Message loop; left only when a "finish" message (tag 72) is received.
# NOTE(review): the aeolus2fms_* handlers called below are not defined in
# the visible part of this file -- confirm where they come from.
while True:
    # Block until the next message from the remote task with the same
    # rank is pending.  Messages can be of several different types,
    # which are told apart by their tags.
    parentcomm.Probe(source=rank, tag=MPI.ANY_TAG, status=Status)
    tag = Status.Get_tag()

    if tag == 72:
        finish()
        break

    if tag == 50:
        aeolus2fms_init_grid()
    elif tag == 60:
        aeolus2fms_init()
    elif tag == 70:
        aeolus2fms_get_stock_pe()
    elif tag == 73:
        aeolus2fms_restart()
    elif tag == 80:
        aeolus2fms_get_bottom_mass()
    elif tag == 90:
        aeolus2fms_get_bottom_wind()
    elif tag == 100:
        aeolus2fms_update_up()
    elif tag == 200:
        aeolus2fms_update_down()
    else:
        print('Error: aeolus2fms.py does not know how to handle message tag', tag)
        raise Exception

print("test_MPI_Spawn.py finished", flush=True)
MPI.Finalize()
