import importlib
import os
os.environ['PROJ_LIB'] = '/home/jan/anaconda3/share/proj' #anaconda3 miniconda3
import sys
sys.path.append('../misc')
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
from nc_dump import ncdump
import math
import scipy.stats
import pickle
#import xarray as xr
import pandas as pd
import matplotlib.colors as mcol
from matplotlib.colors import LinearSegmentedColormap
#my scripts:

import tools_plot as jplt
importlib.reload(jplt) 
import tools_analysis as jan
importlib.reload(jan)  

def movingaverage(values, window):
    """Return the simple moving average of `values` over `window` samples.

    Only fully overlapping windows are kept (np.convolve mode 'valid'), so the
    result has ``values.size - window + 1`` elements.

    Raises:
        ValueError: if `window` is not between 1 and the number of values. A
            window longer than the series would otherwise silently return
            sums of the whole series divided by `window`, which are not
            moving averages.
    """
    values = np.asarray(values, dtype=float)
    window = int(window)
    if not 1 <= window <= values.size:
        raise ValueError(
            f'window must be between 1 and {values.size}, got {window}')
    weights = np.repeat(1.0, window) / window
    return np.convolve(values, weights, 'valid')

##############################################################################
### Set the Emission Scenario Parameters and generate the forcing files
##############################################################################

# Emission-scenario parameters
duration = 1200.         # years in which 95% of the emissions fall (mu +- 2 sigma); short: 1200., long: 6000., single: 100.

c_total = 5300.          # total carbon release in GtC; low: 2500., medium: 5300., high: 7500., or 0.

sulf_inj_tot = 125.      # total sulfur injected into the stratosphere in GtS [0, 50, 125, 250, 500]
sulf_inj_single = 250.   # sulfur per single injection in TgS (= MtS) [9, 50, 100, 200, 250, 500, 1000]
distribution = 'normal'  # probability distribution of the eruption dates: 'normal' or 'uniform'

scen_schmidt = False      # False, or 10 / 50: number of years with one eruption per day
sparse_eruptions = False  # True: also write a thinned-out (every 25th eruption) 1200y history

# Derive the scenario length and the name fragments used in all file names.
if duration == 1200. or duration == 6000.:
    scenario_length = 5000. if duration == 1200. else 12000.  # years
    c_scenario_string = '_' + str(int(duration)) + 'y_' + str(int(c_total)) + 'GtC'
    n_string = 'multi'
    sulf_scenario_string = ('_' + n_string + '_' + str(int(duration)) + 'y_'
                            + str(int(sulf_inj_tot)) + 'GtS_'
                            + str(int(sulf_inj_single)) + 'MtS')  # _rev
    if sparse_eruptions == True:
        sulf_scenario_string += '_sparse'
elif duration == 100.:
    scenario_length = 100.
    n_string = 'single'
    sulf_scenario_string = '_' + n_string + '_' + str(int(sulf_inj_single)) + 'MtS'
    c_scenario_string = 'none'
elif scen_schmidt == False:
    # Fail here instead of with a NameError on scenario_length further down.
    raise ValueError('duration must be 100., 1200. or 6000., got %r' % (duration,))

if scen_schmidt != False:
    # Schmidt-type scenarios: one eruption per day for `scen_schmidt` years.
    _schmidt_strings = {10: '12000MtS_0p44_10yrs', 50: '60000MtS_0p44_50yrs'}
    if scen_schmidt not in _schmidt_strings:
        raise ValueError('scen_schmidt must be False, 10 or 50, got %r' % (scen_schmidt,))
    scenario_length = 100.
    n_string = 'schmidt'
    # 1200 MtS per year (12000 MtS / 10 yrs, 60000 MtS / 50 yrs) spread over
    # 365 daily eruptions (~3.29 MtS/day), scaled by 0.44, the fraction of
    # Schmidt's sulfur that effectively forms aerosols -> ~1.45 MtS per eruption.
    sulf_inj_single = (1200. * 0.44) / 365
    sulf_scenario_string = '_' + n_string + '_' + _schmidt_strings[scen_schmidt]
    c_scenario_string = 'none'

        
# Annual time axis and the shape of the Gaussian emission pulse.
years = np.arange(1, scenario_length + 1, 1)
sigma = duration / 4        # years; mu +- 2 sigma spans `duration`
mu = 3 * sigma + 400        # year of the emission peak
solar_constant = 1339.      # W/m^2, see SBC_UPDATES/POTSDAM2/ini.F -> S0

# Input/output file names, keyed on the scenario name fragments.
if scen_schmidt != False:
    eruption_dates_filename = f'./examples/eruption_dates_{scen_schmidt}y.pickle'
else:
    eruption_dates_filename = f'./examples/eruption_dates_{int(duration)}y.pickle'

c_scenario_filename = f'./examples/carbon_emission{c_scenario_string}.dat'
# Schmidt scenarios use the same naming pattern as all other scenarios.
eruption_history_filename = f'./examples/eruptions{sulf_scenario_string}_rev_0p1.nc'
aod_filename = f'./examples/aod{sulf_scenario_string}_rev_0p1_ave.nc'
solarv_filename = f'./examples/solarv{sulf_scenario_string}_rev_0p1.dat'


### Carbon Emission Scenario:
# Gaussian emission pulse whose integral over time equals c_total:
#   c_rate(t) = c_total / sqrt(2 pi sigma^2) * exp(-(t - mu)^2 / (2 sigma^2))
c_rate_max = c_total / np.sqrt(2 * sigma**2 * np.pi)  # peak rate; max 9.97, min 0.66 GtC/yr
c_rate = c_rate_max * np.exp(-1. / 2 * ((years - mu) / sigma)**2)

# Two-column table: year, emission rate (GtC/yr).
c_emission_scenario = np.column_stack((years, c_rate)).astype(float)
DataOut = np.column_stack((years, c_rate))
np.savetxt(c_scenario_filename, DataOut, fmt=('%i', '%.6f'))

### Sulfur Aerosol Forcing Scenario:
if duration == 100.:
    # A single eruption, by default on 1 July of year 3 at the equator.
    n_erupt = 1
    #year_erupt=3; month_erupt=3; day_erupt=21;
    year_erupt = 3; month_erupt = 7; day_erupt = 1
    ssi_erupt = sulf_inj_single; hemi_erupt = 1; lat_erupt = 0
    if sulf_inj_single == 9.:  # Pinatubo (Toohey 2016)
        year_erupt = 3; month_erupt = 6; day_erupt = 15
        ssi_erupt = 9; hemi_erupt = 1; lat_erupt = 15.1

if scen_schmidt != False:
    # One eruption per day for `scen_schmidt` years at 21 S, starting in year 3.
    # Day indices are mapped onto a 365-day (no-leap) calendar. The previous
    # 30-day-month mapping put day-of-year 360-364 into a non-existent month 13.
    n_erupt = 365 * scen_schmidt
    _month_lengths = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
    _month_starts = np.concatenate(([0], np.cumsum(_month_lengths)[:-1]))  # 0-based first day of each month
    _day_index = np.arange(n_erupt)
    _day_of_year = _day_index % 365
    _month_index = np.searchsorted(_month_starts, _day_of_year, side='right') - 1
    year_erupt = (_day_index // 365 + 3).astype(float)
    month_erupt = (_month_index + 1).astype(float)
    day_erupt = (_day_of_year - _month_starts[_month_index] + 1).astype(float)
    lat_erupt = np.full(n_erupt, -21.)
    hemi_erupt = np.ones(n_erupt)
    ssi_erupt = np.full(n_erupt, sulf_inj_single)
        
        
        
if duration == 1200 or duration == 6000:
    # Reference eruption history whose relative injection sizes are reused for
    # every multi-eruption scenario. Its single injections are presumably
    # ~200 MtS (per the file names), hence the normalisation by 200.
    if duration == 1200:
        nc_f = './examples/eruptions_multi_1200y_100GtS_200MtS.nc'
    else:
        nc_f = './examples/eruptions_multi_6000y_500GtS_200MtS.nc'
    with Dataset(nc_f, 'r') as nc_fid:
        #nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
        injections = nc_fid.variables['ssi'][:]
    injections_relative = injections / 200

# NOTE: `and` binds tighter than `or`, so this branch also runs for the 1200y
# scenario when sparse_eruptions is True. The eruption-history writer later in
# the script relies on n_erupt, ssi_erupt, ... being defined in that case, so
# the original evaluation order is kept and only made explicit here.
if duration == 1200. or (duration == 6000. and sparse_eruptions != True):
    # Number of eruptions needed to inject sulf_inj_tot GtS in portions of
    # sulf_inj_single MtS.
    n_erupt = int(np.round(sulf_inj_tot / (0.001 * sulf_inj_single)))
    # Scale the reference pattern to the requested single-injection size; the
    # element-wise copy requires len(injections_relative) == n_erupt.
    sulf_inj_singles = sulf_inj_single * injections_relative
    ssi_erupt = np.empty(n_erupt)
    ssi_erupt[:] = sulf_inj_singles[:]
    hemi_erupt = np.ones(n_erupt)

    # Eruption dates are drawn once and stored, so that all runs share them.
    # Set to True to draw a new set and overwrite the stored pickle.
    regenerate_eruption_dates = False
    if regenerate_eruption_dates:
        lower = 1 + 400
        upper = 3. / 2 * duration + 400
        if distribution == 'normal':
            sigma = duration / 4
            mu = 3 * sigma + 400
            year_erupt = np.rint(scipy.stats.truncnorm.rvs(
                (lower - mu) / sigma, (upper - mu) / sigma,
                loc=mu, scale=sigma, size=n_erupt))
        elif distribution == 'uniform':
            year_erupt = np.rint(np.random.uniform(low=lower, high=upper, size=n_erupt))
        else:
            raise ValueError("distribution must be 'normal' or 'uniform', got %r" % (distribution,))
        # randint excludes `high`: months 1-12 (December was previously never
        # drawn); days stay 1-29 as before.
        month_erupt = np.random.randint(low=1, high=13, size=n_erupt)
        day_erupt = np.random.randint(low=1, high=30, size=n_erupt)
        lat_erupt = np.random.randint(low=-20, high=20, size=n_erupt)
        # Full chronological sort by (year, month, day). The former single pass
        # of neighbour swaps left runs of >2 same-year eruptions unsorted and
        # separated months from their days.
        order = np.lexsort((day_erupt, month_erupt, year_erupt))
        year_erupt, month_erupt = year_erupt[order], month_erupt[order]
        day_erupt, lat_erupt = day_erupt[order], lat_erupt[order]
        with open(eruption_dates_filename, 'wb') as f:
            pickle.dump([day_erupt, month_erupt, year_erupt, lat_erupt], f)

    # Only load pickles produced by this script: pickle.load can run arbitrary code.
    with open(eruption_dates_filename, 'rb') as f:
        day_erupt, month_erupt, year_erupt, lat_erupt = pickle.load(f, encoding='latin1')
       
# Sparse variant of the 1200y eruption history.
if sparse_eruptions == True:
    # Keep every 25th eruption of the reference history (starting at index 1)
    # and scale each injection by 2.5. With the reference file's presumed
    # ~200 MtS injections, this gives ~500 MtS each (cf. the output file name).
    step_e = 25
    erup_hist_filename = './examples/eruptions_multi_1200y_100GtS_200MtS.nc'
    erup_hist_filename_sparse = './examples/eruptions_multi_1200y_10GtS_500MtS_sparse.nc'

    with Dataset(erup_hist_filename, 'r') as nc_fid:
        #nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
        year_e = nc_fid.variables['year'][1::step_e]
        month_e = nc_fid.variables['month'][1::step_e]
        day_e = nc_fid.variables['day'][1::step_e]
        lat_e = nc_fid.variables['lat'][1::step_e]
        ssi_e = 2.5 * nc_fid.variables['ssi'][1::step_e]
        hemi_e = nc_fid.variables['hemi'][1::step_e]

    # Size the dimension from the eruptions actually selected. The old
    # hard-coded float 500./step_e silently assumed 500 input eruptions.
    nerup = len(year_e)

    sparse_fields = (('year', year_e, 'year'), ('month', month_e, 'month'),
                     ('day', day_e, 'day'), ('lat', lat_e, 'degrees N'),
                     ('ssi', ssi_e, 'Tg[S]'), ('hemi', hemi_e, ''))
    with Dataset(erup_hist_filename_sparse, 'w', format='NETCDF4_CLASSIC') as eruption_history:
        eruption_history.createDimension('nerup', nerup)
        # Define all variables first, then write the data.
        sparse_vars = []
        for name, _, units in sparse_fields:
            var = eruption_history.createVariable(name, np.float64, ('nerup',))
            var.units = units
            sparse_vars.append(var)
        for var, (_, values, _) in zip(sparse_vars, sparse_fields):
            var[:] = values
    
     
# Generate the eruption history used as EVA input. The context manager closes
# the file even if a write fails (e.g. on a length mismatch with n_erupt).
with Dataset(eruption_history_filename, 'w', format='NETCDF4_CLASSIC') as eruption_history:
    eruption_history.createDimension('nerup', n_erupt)
    year = eruption_history.createVariable('year', np.float64, ('nerup',))
    month = eruption_history.createVariable('month', np.float64, ('nerup',))
    day = eruption_history.createVariable('day', np.float64, ('nerup',))
    lat = eruption_history.createVariable('lat', np.float64, ('nerup',))
    ssi = eruption_history.createVariable('ssi', np.float64, ('nerup',))
    hemi = eruption_history.createVariable('hemi', np.float64, ('nerup',))
    year.units = 'year'
    month.units = 'month'
    day.units = 'day'
    lat.units = 'degrees N'
    ssi.units = 'Tg[S]'
    hemi.units = ''
    year[:] = year_erupt
    month[:] = month_erupt
    day[:] = day_erupt
    ssi[:] = ssi_erupt
    hemi[:] = hemi_erupt
    lat[:] = lat_erupt


# -> run EVA on the eruption history written above,
# then calculate the global annual mean AOD in Ferret, e.g.:
# SAVE/NCFORMAT=4/FILE=aod_single_100MtS_ave.nc  aod550[j=@ave,l=1:1200:12@ave]
  

# Calculate the solar-constant forcing from the EVA output (global annual mean AOD).
# A forcing of 21 W/m^2 per unit AOD550 is expressed as an equivalent reduction
# of the solar constant: dS0 = 4 * dF / (1 - 0.3). The +0.58 accounts for the
# background aerosol.
with Dataset(aod_filename, 'r') as nc_fid:
    #nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
    aod_yearly = nc_fid.variables['AOD550'][:]
years_a = np.arange(0, len(aod_yearly), 1)
solar_constant_forced = solar_constant - 4 * 21 * 1. / (1 - 0.3) * aod_yearly + 0.58
DataOut = np.column_stack((years_a, solar_constant_forced))
np.savetxt(solarv_filename, DataOut, fmt=('%6.2f', '%6.2f'))


##############################################################################
### Compare the different Emission Scenarios
##############################################################################

### Find out maximum number of eruptions per decade (and per century) and the sulfur injection in it
def _max_injections_in_window(inj_years, inj_amounts, window):
    """Return (max eruption count, max summed injection) over any run of
    `window` consecutive years.

    Eruption years are rounded to integers and binned per year. A sliding sum
    over the annual bins (np.convolve, mode 'valid') then covers every window
    that lies inside the record; windows sticking out of it only see a subset
    of those eruptions. If the record spans fewer than `window` years, the
    whole record counts as one window.

    Replaces the former scan, which had two problems:
      - It used a hard-coded year axis 1-5000, missing late eruptions of the
        6000y scenarios (which reach year 3/2*6000+400 = 9400).
      - Its test |y - c| <= window/2 - 1 spanned only window-1 years.
    """
    yrs = np.rint(np.asarray(inj_years, dtype=float)).astype(int)
    amounts = np.asarray(inj_amounts, dtype=float)
    if yrs.size == 0:
        return 0., 0.
    yrs = yrs - yrs.min()  # np.bincount needs non-negative indices
    kernel = np.ones(window)
    counts = np.convolve(np.bincount(yrs), kernel, 'valid')
    totals = np.convolve(np.bincount(yrs, weights=amounts), kernel, 'valid')
    return np.max(counts), np.max(totals)


erupt_files = ['eruptions_multi_1200y_50GtS_100MtS.nc',
               'eruptions_multi_1200y_125GtS_250MtS.nc',
               'eruptions_multi_1200y_250GtS_500MtS.nc',
               'eruptions_multi_1200y_500GtS_1000MtS.nc',
               'eruptions_multi_1200y_600GtS_1200MtS_rev_0p1.nc',
               'eruptions_multi_6000y_125GtS_50MtS_rev_0p1.nc',
               'eruptions_multi_6000y_250GtS_100MtS_rev.nc',
               'eruptions_multi_6000y_500GtS_200MtS_rev.nc']
maxInjMtS_df = pd.DataFrame([], columns=['maxInjNr_10yr', 'maxInjNr_100yr', 'MaxInjMtS_10yr', 'MaxInjMtS_100yr'], index=erupt_files)
for erupt_file in erupt_files:
    with Dataset('./examples/' + erupt_file, 'r') as nc_fid:
        injs = nc_fid.variables['ssi'][:]
        injs_yrs = nc_fid.variables['year'][:]
    for window, columns in ((10, ['maxInjNr_10yr', 'MaxInjMtS_10yr']),
                            (100, ['maxInjNr_100yr', 'MaxInjMtS_100yr'])):
        maxInjMtS_df.loc[erupt_file, columns] = list(
            _max_injections_in_window(injs_yrs, injs, window))
    

### Calculate Peak Aerosol Forcing
# Peak reduction of the solar constant relative to the first year, for the raw
# annual series and for 10- and 100-year running means. The forcing columns
# convert it back with F = (1 - 0.3) / 4 * dS0.
solarv_files = ['solarv_multi_1200y_50GtS_100MtS_rev_0p1.dat',
                'solarv_multi_1200y_125GtS_250MtS_rev_0p1.dat',
                'solarv_multi_1200y_250GtS_500MtS_rev_0p1.dat',
                'solarv_multi_1200y_500GtS_1000MtS_rev_0p1.dat',
                'solarv_multi_1200y_600GtS_1200MtS_rev_0p1.dat',
                'solarv_multi_6000y_125GtS_50MtS_rev_0p1.dat',
                'solarv_multi_6000y_250GtS_100MtS_rev_0p1.dat',
                'solarv_multi_6000y_500GtS_200MtS_rev_0p1.dat']
maxRadForcing_df = pd.DataFrame([], columns=['maxS0red_1yr', 'maxF_1yr', 'maxS0red_10yr', 'maxF_10yr', 'maxS0red_100yr', 'maxF_100yr'], index=solarv_files)
for solarv_file in solarv_files:
    years_a, solar_constant_forced = np.loadtxt('./examples/' + solarv_file, unpack=True)
    for window in (1, 10, 100):
        if window == 1:
            smoothed = solar_constant_forced
        else:
            smoothed = movingaverage(solar_constant_forced, window)
        max_red = np.max(solar_constant_forced[0] - smoothed)
        maxRadForcing_df.loc[solarv_file, [f'maxS0red_{window}yr', f'maxF_{window}yr']] = [max_red, 0.7 / 4 * max_red]

solarv_single_files = ['solarv_single_9MtS_rev_0p1.dat',
                       'solarv_single_50MtS_rev_0p1.dat',
                       'solarv_single_100MtS_rev_0p1.dat',
                       'solarv_single_200MtS_rev_0p1.dat',
                       'solarv_single_250MtS_rev_0p1.dat',
                       'solarv_single_500MtS_rev_0p1.dat',
                       'solarv_single_1000MtS_rev_0p1.dat']
# Single eruptions for which a monthly zonal-mean AOD file is available.
monthly_aod_single_files = ['solarv_single_9MtS_rev_0p1.dat',
                            'solarv_single_250MtS_rev_0p1.dat',
                            'solarv_single_1000MtS_rev_0p1.dat']
maxRadForcing_single_df = pd.DataFrame([], columns=['maxS0red', 'maxF', 'maxF_monthly'], index=solarv_single_files)
for solarv_single_file in solarv_single_files:
    years_a, solar_constant_forced = np.loadtxt('./examples/' + solarv_single_file, unpack=True)
    maxS0red = np.max(solar_constant_forced[0] - solar_constant_forced)
    maxRadForcing_single_df.loc[solarv_single_file, ['maxS0red', 'maxF']] = [maxS0red, 0.7 / 4 * maxS0red]

    if solarv_single_file in monthly_aod_single_files:
        # 'solarv_single_9MtS_rev_0p1.dat' -> './examples/aod_single_9MtS_rev_0p1_zonalmean.nc'
        nc_f = './examples/aod' + solarv_single_file[6:-4] + '_zonalmean.nc'
        with Dataset(nc_f, 'r') as nc_fid:
            aod_monthly = nc_fid.variables['AOD550'][:]
        # Peak AOD increase times 21 W/m^2 per unit AOD.
        maxRadForcing_single_df.loc[solarv_single_file, ['maxF_monthly']] = [np.max((aod_monthly - aod_monthly[0])) * 21]

    
    

    
##############################################################################
### Multiple Pulse Scenarios for LOSCAR simulations
##############################################################################

# Multi-pulse scenario: four carbon pulses on a 1-Myr annual axis. Pulse 1
# starts at the beginning of the record; the later onsets follow Paris et al. 2016.
duration_c_multi = int(1e6)
years_c_multi = np.arange(1, duration_c_multi + 1, 1)
c_rate_multi = np.zeros(duration_c_multi)

start_2 = int(np.round((201.565 - 201.5035) * 1e6))  # following Paris et al. 2016
start_3 = int(np.round((201.565 - 201.2895) * 1e6))  # following Paris et al. 2016
start_4 = int(np.round((201.565 - 200.916) * 1e6))   # following Paris et al. 2016

pulses = [
    ('./examples/carbon_emission_1200y_5300GtC.dat', 0),        # pulse 1
    ('./examples/carbon_emission_1200y_5300GtC.dat', start_2),  # pulse 2
    ('./examples/carbon_emission_6000y_5300GtC.dat', start_3),  # pulse 3 (long)
    ('./examples/carbon_emission_6000y_5300GtC.dat', start_4),  # pulse 4 (long)
]
for filename, start in pulses:
    years, c_rate = np.loadtxt(filename, unpack=True)
    c_rate_multi[start:start + len(years)] = c_rate

# Save the combined time series.
# NOTE(review): four 5300 GtC pulses sum to ~21200 GtC; the '22100GtC' in the
# file name looks like a digit swap -- confirm before relying on the name.
DataOut = np.column_stack((years_c_multi, c_rate_multi))
np.savetxt('./examples/carbon_emission_1Myr_22100GtC_3rd4thPulseLong.dat', DataOut, fmt=('%i', '%.6f'))












