"""
aeolus2printgrid.py

Print out information about the spherical grid that can be used
by the FMS grid creation tools to construct exchange grids etc.

make_hgrid --grid_type from_file --my_grid_file aeolus2grid-xxx.txt \
--nlon X \
--nlat Y \
--grid_name aeolus2xxx

where xxx is one of [smooth, fast, super_fast],
with model resolutions [768x384, 384x192, 192x96], respectively.

Contents of the my_grid_file written by this script (the format make_hgrid
expects):
 The first line of my_grid_file is text (it will be ignored),
 followed by nlon+1 lines of real values of x-direction supergrid bound
 locations. Then another line of text (ignored), followed by
 nlat+1 lines of real values of y-direction supergrid bound locations.
 Cell bounds and cell centers are written alternately, so an axis with
 N model cells yields 2*N+1 values; X and Y are therefore supergrid counts
 (e.g. 1536 and 768 for smooth).
"""

import sys
import os
import argparse
import math                     as ma

import numpy                    as np

#from mpi4py import MPI

import dedalus.public           as de
# Load config options
from dedalus.tools.config import config

import sphere_wrapper           as sph

# SP: setting non_equidis_grid is irrelevant for the model results.
#     The sphere object uses its own internal grid,
#     regardless of non_equidis_grid.
non_equidis_grid = 1   # 0: equidistant with respect to lon/lat

# Default grid selection (set one of them to 1); the --grid command line
# option overrides these values.
#   smooth_run: latitudinally non-equidistant fine resolution (768*384)
#               with 1 min numerical time step
#   fast_run:   time step = step min, some second order nonlinear terms
#               have been modified, resolution 384*192
#   super_fast: time step = step min, some second order nonlinear terms
#               have been modified, resolution 192*96
smooth_run, fast_run, super_fast = 1, 0, 0

epsilon = 1e-10  # default relative tolerance for double_eq


def double_eq(v1, v2, e=epsilon):
    """Return True where v1 and v2 agree to within relative tolerance e.

    The absolute difference is compared against e times the magnitude of the
    mean of v1 and v2. Exactly equal values (including 0 == 0) always compare
    equal. Works element-wise on numpy arrays as well as on scalars.
    """
    # Multiply by the mean instead of dividing by it, so that v1 == v2 == 0
    # does not turn into 0/0 = nan (which would compare as "not equal").
    return np.logical_or(v1 == v2, np.abs(v1 - v2) < e * np.abs((v1 + v2) / 2))

# Some of the user settings can be overridden via the command line.
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--grid", help="grid resolution [smooth, fast, super_fast]")

args = parser.parse_args()
print(args)

# (smooth_run, fast_run, super_fast) flag triple for each valid grid name.
_grid_flags = {
    'smooth':     (1, 0, 0),
    'fast':       (0, 1, 0),
    'super_fast': (0, 0, 1),
}

if args.grid:
    if args.grid not in _grid_flags:
        # sys.exit rather than quit(): quit() is injected by the site module
        # for interactive use and is missing under `python -S`.
        print("Error: grid must be one of [smooth, fast, super_fast] but is ", args.grid,
              file=sys.stderr)
        sys.exit(42)
    smooth_run, fast_run, super_fast = _grid_flags[args.grid]
else:
    # No --grid given: keep the file-level default flags and name the output
    # after them, with the same precedence as the L_max selection
    # (super_fast over fast_run over smooth_run), so the output file name
    # always matches the grid actually produced.
    if super_fast == 1:
        args.grid = 'super_fast'
    elif fast_run == 1:
        args.grid = 'fast'
    else:
        args.grid = 'smooth'

# Discretization parameters
L_dealias = 3/2  # dealiasing factor passed to the bases below
S_max     = 3    # spin order (leave fixed)

# Spherical harmonic order for the selected resolution. If several flags are
# set, the coarsest grid wins (super_fast over fast_run over smooth_run),
# matching the historical "last assignment wins" behaviour.
if super_fast == 1:
    L_max = 63
elif fast_run == 1:
    L_max = 127
elif smooth_run == 1:
    L_max = 255
else:
    # Without this, L_max would be undefined and the script would fail
    # later with an unrelated-looking NameError.
    sys.exit("Error: one of smooth_run, fast_run, super_fast must be set to 1")

# We need a supergrid, i.e. cell edges as well as cell centers.
# Doubling the resolution with supergridfactor=2 does not work well:
# -> Longitudes look reasonable.
#    But not the latitudes. The sphere object always produces a grid with cell
#    centers north and south of the equator, whereas the supergrid should have
#    a cell edge exactly at latitude = 0.
# Thus we compute cell bound coordinates and use those for the supergrid.
# For longitudes that is straightforward: just compute midpoints.
# For the non-uniform latitudinal grid distances we use the same simple
# approach, although it is only a rough approximation.
supergridfactor = 1  # 2 would double the resolution, see above
n_modes = supergridfactor*(L_max+1)  # modes in theta; lamda gets twice as many

# Make the bases and the (serial, comm=None) domain.
theta_basis = de.Fourier('theta', n_modes,   interval=(0, np.pi),   dealias=L_dealias)
lamda_basis = de.Fourier('lamda', 2*n_modes, interval=(0, 2*np.pi), dealias=L_dealias)
domain      = de.Domain([lamda_basis, theta_basis], grid_dtype=np.float64, comm=None)


# Set up the sphere object on this process's coefficient range
# m_start..m_end (inclusive) taken from the domain's coefficient layout.
coeff_layout = domain.distributor.coeff_layout
m_start = coeff_layout.start(1)[0]
m_len   = coeff_layout.local_shape(1)[0]
m_end   = m_start + m_len - 1
N_theta = int(supergridfactor*(L_max+1)*L_dealias)
print("L_max, S_max, L_dealias, m_start, m_len, m_end, N_theta",
      L_max, S_max, L_dealias, m_start, m_len, m_end, N_theta)

S = sph.Sphere(supergridfactor*(L_max+1)-1, S_max,
               N_theta=N_theta, m_min=m_start, m_max=m_end)

# Longitudes come from the dedalus domain grid; theta comes from the sphere
# object's own grid, restricted to this process's slice of the dealiased
# grid layout and shaped as a (1, theta_len) row.
lamda        = domain.grids(L_dealias)[0]
grid_layout  = domain.distributor.grid_layout
theta_slice  = grid_layout.slices(domain.dealias)[1]
theta_len    = domain.local_grid_shape(domain.dealias)[1]
theta_global = S.grid
theta        = theta_global[theta_slice].reshape((1, theta_len))
phi          = np.pi/2.0 - theta

# Convert the longitude grid from radians to degrees.
# The current input topographies are organized as -180 to 180 deg longitude,
# but lamda runs 0 to 2pi.
# Thus we need to shift by -180 deg to obtain a matching axis for NetCDF.
lons = lamda[:,0]/(2.*np.pi)*360.0 - 180.0
# Convert theta to degrees and shift by -90 deg.
# This results in Y running N to S, while FMS wants coordinates running S to N.
# At least this is consistent with the current input topographies.
# ncview and ferret both flip the latitudes automagically,
# with more or less warnings printed out.
# NOTE(review): phi is defined as pi/2 - theta, but this computes
# degrees(theta) - 90 == -degrees(phi), i.e. the opposite sign of phi.
# Presumably intentional to match the topography files -- confirm.
lats = theta[0,:]/(2*np.pi)*360.0 - 90.0
#print('lats', lats)

# Index range of the local domain within the global domain.
# Naming convention taken from GFDL's FMS coupler:
#   i/j/k -> lon/lat/vertical direction
#   c/d   -> compute resp. data domain (i.e. without/with halos)
#   s/e   -> start/end (end is one past the last index, as in python slices)
# "is" is a python keyword, hence the longer names.
ics, ice = 0, lons.size
jcs, jce = theta_slice.start, theta_slice.stop

print("Lons shape, Lats shape: ", lons.shape, lats.shape)

# Longitude cell bounds: midpoints between neighbouring centers, closed at
# both ends by half of the uniform grid spacing dlon.
dlon = 360.0/lons.size
print(dlon)
lonb = np.empty(lons.size+1)
lonb[0]    = lons[0] - dlon/2
lonb[1:-1] = (lons[:-1] + lons[1:])/2
lonb[-1]   = lons[-1] + dlon/2

# Latitude cell bounds: midpoints between neighbouring centers (only a rough
# approximation for the non-uniform latitudinal grid), closed at the poles.
# The pole placed first is taken from the actual ordering of lats instead of
# being hard-coded to +90, so the bounds stay monotone whether lats run
# N->S (current grid: +90 first, as before) or S->N.
pole = np.copysign(90.0, lats[0] - lats[-1])
latb = np.empty(lats.size+1)
latb[0]    = pole
latb[1:-1] = (lats[:-1] + lats[1:])/2
latb[-1]   = -pole

# Summary of the global grid and of the local index range. The index entries
# reuse ics/ice/jcs/jce so this dict cannot drift from those definitions;
# the *e entries are one past the end, following python conventions.
# Note: the [:] slices are numpy views, not copies.
gridinfo = {
   "nlons_global" : lons.size,
   "nlats_global" : theta_global.size,
   "lons"         : lons[:],
   "lonb"         : lonb[:],
   "lats"         : lats[:],
   "latb"         : latb[:],
   "ics"          : ics,
   "ice"          : ice,
   "jcs"          : jcs,
   "jce"          : jce,
}

def _write_axis(out, header, bounds, centers):
    """Write one supergrid axis to the open text file *out*.

    Emits the *header* line, then alternating cell bound / cell center
    values (one pair per center), then the final cell bound; every value is
    formatted as "15.11f" on its own line.
    """
    print(header, file=out)
    for bound, center in zip(bounds, centers):
        print(format(bound, "15.11f"), file=out)
        print(format(center, "15.11f"), file=out)
    print(format(bounds[-1], "15.11f"), file=out)


# Write the longitude and latitude supergrid coordinates, in the layout
# make_hgrid reads for --grid_type from_file.
fn = "aeolus2grid-" + args.grid + ".txt"
with open(fn, 'w') as f:
    _write_axis(f, "# aeolus2 " + args.grid + " supergrid longitudes", lonb, lons)
    _write_axis(f, "# aeolus2 " + args.grid + " supergrid latitudes", latb, lats)
