import numpy as np
import math
from netCDF4 import Dataset 
import matplotlib
import matplotlib.pyplot as plt
import mpl_toolkits.basemap
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
import pickle
import pandas as pd
from tools_analysis import grid_gen
import importlib
import shapefile
import tools_analysis as jan
importlib.reload(jan)
#import tools_analysis as jan
#reload(jan)



##Example:
#nc_f = './history_p2_nc_files/history_p2_O24p5_E0p000_P000.nc'  
#nc_fid = Dataset(nc_f, 'r')  
##nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
#ts_ann = nc_fid.variables['ts_ann'][:]  # shape is time, lat, lon as shown above
#var=ts_ann[5700, :, :]; projection='rec';varname='var'; units='degC'; time='year'
#
#nc_f = './snapshots.004103.01.01.dta.nc'  
#nc_fid = Dataset(nc_f, 'r')  
#nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
#temp = np.flipud(nc_fid.variables['temp'][0,0,:,:])
#var=temp; projection='rec';varname='var'; units='degC'; time='year'       

def _is_unset(value):
    """Return True if *value* is one of the "unset" sentinels used by plot_world.

    Optional arguments default to the string ``'None'``; a real ``None`` is
    accepted too. ``isinstance`` is checked first so that numpy arrays are
    never compared element-wise against a string inside an ``if``.
    """
    return value is None or (isinstance(value, str) and value == 'None')


def _format_stats(name, vmin, vmean, vmax, unit, digits):
    """Return ``'<name> [<min>, <mean>, <max>] <unit>'`` for plot annotations.

    With ``digits == 0`` the three values are rounded to integers; otherwise
    they are rounded to ``digits`` decimals with ``np.round``.
    """
    if digits == 0:
        shown = [str(int(round(v))) for v in (vmin, vmean, vmax)]
    else:
        shown = [str(np.round(v, digits)) for v in (vmin, vmean, vmax)]
    return name + ' [' + ', '.join(shown) + '] ' + unit


def plot_world(var, varname='var', units='degC', lim_l='None', lim_u='None', cbar_delta='None', cbar_tick_freq=2,
               projection='robin', continents='off', continents_timeslice=201, continents_array=None, simrun='None',
               rotate='no', time='year', title='default', colourmap='RdBu_r', colourmap_shifted=False,
               colourmap_midvalue=0, colourmap_normalized=False, colourmap_normalized_levels=[0], axes='on',
               title_opt='on', fig_handle='None', var_digits=2, extend='neither', cbar_orient='hor', anno='on',
               text_corner=None, text_corner_posi=[0.95, 0.925], latlabels=[False, False, False, False],
               plates='off', shallowmarine_contourf=False, interpol_var=False, smooth_var=False,
               interpol_var_cont=False, smooth_var_cont=False, anno_str='None',
               vec='off', vectorx='None', vectory='None', vec_scale=None, vec_freq=1, vec_anno='on', vec_col='k',
               vec2='off', vectorx2='None', vectory2='None', vec3='off', vectorx3='None', vectory3='None',
               cont=False, var_cont='None', var_cont_name='None', cont_anno='on', var_cont_unit='None',
               cont_label_fmt='%.2f', cont_label_units='False', cont_levs=None, cont_labels='on', linestyle=None,
               var_cont_digits=2, cont_col='k', cont_linewidth=1.1,
               conts_misc=False, conts_misc_vars=[], conts_misc_levs=[], conts_misc_cols=[], conts_misc_lws=[],
               conts_misc_linestyles=[], conts_misc_cbarmark=False,
               scat='off', var_scat='None', scat2='off', var_scat2='None', scat_colored='off', scat_colors='None',
               scat_size=50, var_scat_c='xkcd:bright blue', var_scat2_c='xkcd:green',
               hatches='off', var_hatch='None', hatch_levels='None', hatch_color='whitesmoke', hatch_style='...',
               hatch_linewidth=1.0, hatch_levels2='None', var_hatch2='None', hatch2_levels='None',
               hatch2_color='whitesmoke', hatch2_style='...', proxy_locs=[],
               dots='off', dots_colorarr=[], dots_sizearr=[], cell_numbering='off', cbar_opt=True):
    """Draw a global 2-D field with ``pcolor`` on a Basemap projection, plus optional overlays.

    The grid of ``var`` is chosen from ``len(var)`` (its number of rows) via
    ``grid_gen``: 24 -> 'atmo', 48 -> 'ocean', 180 -> 'atmo_1deg', 36 -> '5deg',
    1080 -> '1/6deg', 73 -> '2.5x3.75deg'. Contour/hatch and vector fields are
    matched to grids the same way.

    The string ``'None'`` (or ``None``) marks an optional argument as unset.

    Main parameters:
        var: 2-D field drawn with ``ax.pcolor``.
        lim_l, lim_u: colour limits; unset -> data min / max.
        cbar_delta: colour-bar level spacing; unset -> 50 steps between the limits.
        extend: 'neither', 'both', 'min' or 'max'. Any value other than 'neither'
            widens the limits by ``cbar_delta`` on the corresponding side(s).
        projection: Basemap projection name. Colour bar and annotations are only
            drawn for 'cyl', 'robin' and 'moll'.
        continents, continents_timeslice, continents_array, simrun: choose and
            outline a land-sea mask (contour level 0.5).
        cont, var_cont, ...: contour overlay of a second field (enabled by ``cont=True``).
        vec, vectorx, vectory, ...: quiver overlays (enabled by 'on').
        scat, dots, hatches, proxy_locs, shallowmarine_contourf: further overlays.
        fig_handle: ``(fig, ax)`` to draw into; unset -> a new 8x5 figure.
        cbar_opt: draw the colour bar for 'robin'/'moll' projections (default True).

    Returns:
        ``(a, fig)``: the ``pcolor`` collection and the figure.

    Side effects: with ``hatches == 'on'`` the global ``matplotlib.rcParams``
    hatch colour/linewidth are modified. Several data files are read from
    relative paths, depending on the options.
    """
    cbar = None  # stays None when no colour bar is drawn
    show_cont = cont == True  # contour overlay switch (boolean flag)

    # Annotation statistics use the fields as passed in, i.e. before any
    # interpolation/smoothing below.
    varmean = jan.global_mean(var)
    varmin = np.min(var)
    varmax = np.max(var)
    if show_cont:
        var_cont_mean = jan.global_mean(var_cont)
        var_cont_min = np.min(var_cont)
        var_cont_max = np.max(var_cont)

    ### Interpolation from atmospheric to ocean grid
    if interpol_var == True:
        var = jan.interpol_atm_to_ocn_grid(var)
    if interpol_var_cont == True:
        var_cont = jan.interpol_atm_to_ocn_grid(var_cont)

    ### Grids (selected from the number of rows of each field)
    xo, yo = grid_gen(region='ocean', cell='T')[2:4]
    if len(var) == 24:
        x, y = grid_gen(region='atmo', cell='p')[2:4]
        xt, yt = grid_gen(region='atmo', cell='T')[2:4]
        vel_unit = 'm/s'
    if len(var) == 48:
        x, y = grid_gen(region='ocean', cell='p')[2:4]
        xt, yt = grid_gen(region='ocean', cell='T')[2:4]
        vel_unit = 'cm/s'

    # Grid of the contour/hatch field(s); ocean T grid unless their size says otherwise.
    xtc, ytc = grid_gen(region='ocean', cell='T')[2:4]
    if show_cont or hatches == 'on':
        if len(var_cont) == 24 or len(var_hatch) == 24:
            xtc, ytc = grid_gen(region='atmo', cell='T')[2:4]
        if len(var_cont) == 48 or len(var_hatch) == 48:
            xtc, ytc = grid_gen(region='ocean', cell='T')[2:4]
        if len(var_cont) == 180 or len(var_hatch) == 180:
            xtc, ytc = grid_gen(region='atmo_1deg', cell='T')[2:4]
        if len(var_cont) == 36 or len(var_hatch) == 36:
            xtc, ytc = grid_gen(region='5deg', cell='T')[2:4]
    if vec == 'on':
        if len(vectorx) == 24:
            xtv, ytv = grid_gen(region='atmo', cell='T')[2:4]
        if len(vectorx) == 48:
            xtv, ytv = grid_gen(region='ocean', cell='T')[2:4]
    # For the following grids the contour/hatch grid follows the grid of var.
    if len(var) == 180:
        x, y = grid_gen(region='atmo_1deg', cell='p')[2:4]
        xt, yt = grid_gen(region='atmo_1deg', cell='T')[2:4]
        xtc, ytc = grid_gen(region='atmo_1deg', cell='T')[2:4]
        vel_unit = 'cm/s'
    if len(var) == 36:
        x, y = grid_gen(region='5deg', cell='p')[2:4]
        xt, yt = grid_gen(region='5deg', cell='T')[2:4]
        xtc, ytc = grid_gen(region='5deg', cell='T')[2:4]
    if len(var) == 1080:
        x, y = grid_gen(region='1/6deg', cell='p')[2:4]
        xt, yt = grid_gen(region='1/6deg', cell='T')[2:4]
        xtc, ytc = grid_gen(region='1/6deg', cell='T')[2:4]
    if len(var) == 73:
        x, y = grid_gen(region='2.5x3.75deg', cell='p')[2:4]
        xt, yt = grid_gen(region='2.5x3.75deg', cell='T_2')[2:4]
        xtc, ytc = grid_gen(region='2.5x3.75deg', cell='T_2')[2:4]

    if smooth_var == True:
        var = jan.interpol_other_to_3p75_grid(var, xt, yt, 3.75)
    if smooth_var_cont == True:
        var_cont = jan.interpol_other_to_3p75_grid(var_cont, xt, yt, 3.75)

    ### Rotation: columns from index len(var) onwards are moved to the front.
    # NOTE(review): originally labelled "rotate by 90 degree"; the shift equals
    # the number of rows of var, so the angle depends on the grid shape.
    if rotate == 'yes':
        var = np.ma.concatenate((var[:, len(var):], var[:, 0:len(var)]), axis=1)
        var_cont = np.ma.concatenate((var_cont[:, len(var_cont):], var_cont[:, 0:len(var_cont)]), axis=1)
        vectorx = np.ma.concatenate((vectorx[:, len(var):], vectorx[:, 0:len(var)]), axis=1)
        vectory = np.ma.concatenate((vectory[:, len(var):], vectory[:, 0:len(var)]), axis=1)

    var_disp_string = _format_stats(varname, varmin, varmean, varmax, units, var_digits)
    if show_cont:
        var_cont_disp_string = _format_stats(var_cont_name, var_cont_min, var_cont_mean, var_cont_max,
                                             var_cont_unit, var_cont_digits)

    ### Colour limits and colour-bar levels
    vmin = varmin if _is_unset(lim_l) else lim_l
    vmax = varmax if _is_unset(lim_u) else lim_u
    # An extended colour bar gets one extra cbar_delta-wide bin beyond lim_l and/or lim_u.
    if extend == 'both':
        vmin = lim_l - cbar_delta
        vmax = lim_u + cbar_delta
    elif extend == 'min':
        vmin = lim_l - cbar_delta
        vmax = lim_u
    elif extend == 'max':
        vmin = lim_l
        vmax = lim_u + cbar_delta
    if not _is_unset(cbar_delta):
        if extend in ('both', 'min', 'max'):
            boundaries = np.arange(math.floor(vmin / cbar_delta) * cbar_delta,
                                   math.ceil(vmax / cbar_delta) * cbar_delta + cbar_delta, cbar_delta)
            # Ticks skip the extension bin(s) at the extended end(s).
            if extend == 'both':
                ticks = boundaries[1:-1:cbar_tick_freq]
            elif extend == 'min':
                ticks = boundaries[1::cbar_tick_freq]
            else:
                ticks = boundaries[::cbar_tick_freq]
        elif extend == 'neither':
            boundaries = np.arange(vmin, vmax + cbar_delta, cbar_delta)
            ticks = boundaries[::cbar_tick_freq]
    else:
        boundaries = np.arange(vmin, vmax, (vmax - vmin) / 50)
        ticks = boundaries[::cbar_tick_freq]

    ### Land-sea mask and its grid (default: ocean T grid)
    xt_lomask, yt_lomask = grid_gen(region='ocean', cell='T')[2:4]
    if continents_timeslice == 201:
        with open('continental_mask.pickle', 'rb') as f:
            # encoding='latin1' lets Python 3 load pickles written by Python 2.
            continental_mask = pickle.load(f, encoding='latin1')
    elif isinstance(continents_timeslice, int):
        timeslice = '{:03d}'.format(continents_timeslice)
        geography_df = pd.read_pickle('./data_aggregated_geography_df.pkl')
        continental_mask = geography_df.loc['../c3a_model_input_output/' + timeslice + 'Ma_1000ppm',
                                            ['continental_mask']].iloc[0]
    if isinstance(continents_array, np.ndarray):
        continental_mask = np.copy(continents_array)
        lomask_grids = {36: ('5deg', 'T'), 48: ('ocean', 'T'), 180: ('atmo_1deg', 'T'),
                        1080: ('1/6deg', 'T'), 72: ('2.5x3.75deg', 'T'), 73: ('2.5x3.75deg', 'T_2')}
        n_rows = len(continents_array)
        if n_rows in lomask_grids:
            region, cell = lomask_grids[n_rows]
            xt_lomask, yt_lomask = grid_gen(region=region, cell=cell)[2:4]
        elif n_rows == 94:
            xt_lomask, yt_lomask = grid_gen(region='NCAR-NCEP_grid')[2:4]
    if not _is_unset(simrun):
        continental_mask = np.loadtxt('../c3a_meso_input/' + simrun + '/depth_adj.dat', dtype=int)
        continental_mask[continental_mask > 0] = 1

    ### Figure, axes and map projection
    if _is_unset(fig_handle):
        fig = plt.figure(figsize=(8, 5))
        ax = fig.add_subplot(111)
    else:
        fig, ax = fig_handle[0], fig_handle[1]

    m = Basemap(projection=projection, llcrnrlat=-90, urcrnrlat=90, llcrnrlon=-180, urcrnrlon=180,
                resolution='c', lon_0=0, lat_0=0, ax=ax)
    m.drawmapboundary()
    # Transform lon/lat grids into plotting coordinates of the projection.
    x, y = m(x, y)
    xo, yo = m(xo, yo)
    xtm, ytm = m(xt, yt)
    xtcm, ytcm = m(xtc, ytc)
    if continents == 'on':
        xt_lomask_m, yt_lomask_m = m(xt_lomask, yt_lomask)
    if scat == 'on':
        x_scat, y_scat = m(var_scat[:, 0], var_scat[:, 1])
    if scat2 == 'on':
        x_scat2, y_scat2 = m(var_scat2[:, 0], var_scat2[:, 1])
    if vec == 'on':
        xtvm, ytvm = m(xtv, ytv)
    if plates == 'on':
        m.readshapefile('../c3a_meso_reconstructions/plate_polygons/' + '{:03d}'.format(continents_timeslice)
                        + 'Ma_reconstruction_polygons', 'polygons', color='tab:grey', linewidth=1)

    ### Colour map, normalisation and the main field
    if colourmap_shifted == True:
        # Shift the colour map so that colourmap_midvalue sits at its centre colour.
        mid = colourmap_midvalue
        base_cmap = plt.get_cmap(colourmap)
        if abs(lim_u - mid) >= abs(lim_l - mid):
            mycmap = shiftedColorMap(base_cmap, start=((lim_l - mid) + (lim_u - mid)) / (2 * (lim_u - mid)))
        else:
            mycmap = shiftedColorMap(base_cmap, stop=(abs(lim_l - mid) + lim_u - mid) / (2 * abs(lim_l - mid)))
    else:
        mycmap = plt.get_cmap(colourmap, len(boundaries))

    if colourmap_normalized == True:
        norm = PiecewiseNorm(colourmap_normalized_levels)
        # Newer Matplotlib rejects a norm instance together with vmin/vmax;
        # set_clim applies the same limits to the norm afterwards.
        a = ax.pcolor(x, y, var, cmap=mycmap, norm=norm)
        a.set_clim(vmin, vmax)
    else:
        a = ax.pcolor(x, y, var, cmap=mycmap, vmin=vmin, vmax=vmax)

    if continents == 'on':
        if len(var) == 24:
            ax.contour(xo, yo, continental_mask, levels=[0.5], linewidths=1.75)
        else:
            ax.contour(xt_lomask_m, yt_lomask_m, continental_mask, levels=[0.5], linewidths=1.75)

    ### Contour overlay
    if show_cont:
        if len(var_cont) == 24:
            # Pad the 24-row grid with one edge-copied row at top/bottom (set to
            # lat 90/-90) and one wrapped column at each side (set to lon -180/180)
            # so the contoured field spans the full map extent.
            xtc_pad = np.pad(np.pad(xtc, ((1, 1), (0, 0)), mode='edge'), ((0, 0), (1, 1)), mode='wrap')
            xtc_pad[:, 0] = -180.
            xtc_pad[:, -1] = 180.
            ytc_pad = np.pad(np.pad(ytc, ((1, 1), (0, 0)), mode='edge'), ((0, 0), (1, 1)), mode='wrap')
            ytc_pad[0, :] = 90.
            ytc_pad[-1, :] = -90.
            var_cont_pad = np.pad(np.pad(var_cont, ((1, 1), (0, 0)), mode='edge'), ((0, 0), (1, 1)), mode='wrap')
            # Values on the -180/180 columns: mean of the two outermost data columns.
            var_cont_pad[:, 0] = (var_cont_pad[:, 1] + var_cont_pad[:, -2]) / 2
            var_cont_pad[:, -1] = var_cont_pad[:, 0]
            xtcm_pad, ytcm_pad = m(xtc_pad, ytc_pad)
            cs = ax.contour(xtcm_pad, ytcm_pad, var_cont_pad, levels=cont_levs, colors=cont_col,
                            linewidths=cont_linewidth, linestyles=linestyle, antialiased=True, zorder=6)
        else:
            cs = ax.contour(xtcm, ytcm, var_cont, levels=cont_levs, colors=cont_col,
                            linewidths=cont_linewidth, linestyles=linestyle, antialiased=True, zorder=6)
        if cont_labels != 'off':
            if cont_label_units != 'False':
                fmt = '{}{}'.format(cont_label_fmt, cont_label_units)
                ax.clabel(cs, cs.levels, fontsize=12, inline=True, fmt=fmt)
            else:
                ax.clabel(cs, fontsize=12, inline=True, fmt=cont_label_fmt)

    if conts_misc == True:
        for cc, misc_field in enumerate(conts_misc_vars):
            ax.contour(xtcm, ytcm, misc_field, levels=conts_misc_levs[cc], colors=conts_misc_cols[cc],
                       linewidths=conts_misc_lws[cc], linestyles=conts_misc_linestyles[cc],
                       antialiased=True, zorder=5)

    ### Vector overlays
    if vec == 'on':
        vec_mean = np.ma.mean(np.sqrt(vectorx**2 + vectory**2))
        sub = np.s_[::vec_freq, ::vec_freq]  # thin the arrows to every vec_freq-th cell
        if projection == 'cyl':
            q = ax.quiver(xtv[sub], ytv[sub], vectorx[sub], vectory[sub], scale=vec_scale, width=0.002,
                          color=vec_col)
            if vec_anno == 'on':
                ax.quiverkey(q, 0.80, 0.96, vec_mean, r'mean: %.2f ' % (vec_mean) + vel_unit, labelpos='E',
                             coordinates='figure', fontproperties={'size': 12})
        if projection == 'robin' or projection == 'moll':
            q = m.quiver(xtvm[sub], ytvm[sub], vectorx[sub], vectory[sub], scale=vec_scale, width=0.0015,
                         color=vec_col)
            if vec_anno == 'on':
                ax.quiverkey(q, 0.12, 0.205, vec_mean, r'mean: %.2f ' % (vec_mean) + vel_unit, labelpos='E',
                             coordinates='figure', fontproperties={'size': 12})
        # Secondary vector sets are drawn on var's T grid in projected coordinates.
        if vec2 == 'on':
            ax.quiver(xtm[sub], ytm[sub], vectorx2[sub], vectory2[sub], scale=vec_scale, color='tab:red',
                      width=0.002)
        if vec3 == 'on':
            ax.quiver(xtm[sub], ytm[sub], vectorx3[sub], vectory3[sub], scale=vec_scale, color='tab:green',
                      width=0.002)

    ### Point overlays
    if scat == 'on':
        if scat_colored == 'on':
            ax.scatter(x_scat, y_scat, edgecolors='k', zorder=5, s=scat_size, c=scat_colors, cmap=mycmap,
                       vmin=vmin, vmax=vmax)
        else:
            ax.scatter(x_scat, y_scat, edgecolors='k', c=var_scat_c, zorder=5, s=scat_size)
        if scat2 == 'on':
            ax.scatter(x_scat2, y_scat2, edgecolors='k', c=var_scat2_c, zorder=5, s=scat_size)

    if dots == 'on':
        ax.scatter(xtm, ytm, s=dots_sizearr, zorder=10, marker='o', c=dots_colorarr, edgecolors='k')

    ### Hatching (note: modifies the global matplotlib.rcParams)
    if hatches == 'on':
        matplotlib.rcParams['hatch.color'] = hatch_color
        matplotlib.rcParams['hatch.linewidth'] = hatch_linewidth
        ax.contourf(xtcm, ytcm, var_hatch, hatches=[hatch_style, None], levels=hatch_levels, colors='none',
                    zorder=5)
        if not _is_unset(hatch_levels2):
            ax.contourf(xtcm, ytcm, var_hatch, hatches=['xxx', None], levels=hatch_levels2, colors='none')
        if not _is_unset(var_hatch2):
            matplotlib.rcParams['hatch.color'] = hatch2_color
            ax.contourf(xtcm, ytcm, var_hatch2, hatches=[hatch2_style, None], levels=hatch2_levels,
                        colors='none', zorder=5)

    ### Proxy locations for the requested keywords
    if proxy_locs:
        proxy_distr_df = pd.read_pickle('../c3a_meso_reconstructions/proxy_distributions.pkl')
        proxy_keywords = ['coal', 'evap', 'glacial', 'coral']
        proxy_columns = ['lonlat_Cao2018_coal', 'lonlat_Cao2018_evap', 'lonlat_Cao2018_glacial', 'lonlat_coral']
        proxy_colors = ['tab:brown', 'gold', 'cyan', 'tab:blue']
        for pp in proxy_locs:
            proxy_idx = proxy_keywords.index(pp)
            lonlats_recon = proxy_distr_df.loc[continents_timeslice, [proxy_columns[proxy_idx]]].values[0]
            # An all-zero array is treated as "no proxy locations" and skipped.
            if not np.all(lonlats_recon == 0):
                lonlats_reconm = np.zeros(lonlats_recon.shape)
                lonlats_reconm[:, 0], lonlats_reconm[:, 1] = m(lonlats_recon[:, 0], lonlats_recon[:, 1])
                ax.scatter(lonlats_reconm[:, 0], lonlats_reconm[:, 1], edgecolors='k', linewidths=0.5, zorder=7,
                           c=proxy_colors[proxy_idx], alpha=1, s=40)

    if shallowmarine_contourf == True:
        geography_df = pd.read_pickle('./data_aggregated_geography_df.pkl')
        shallowcoastal_mask = geography_df.loc['../c3a_meso_timeslice/' + simrun,
                                               ['shallowcoastal_mask']].iloc[0]
        ax.contourf(xtcm, ytcm, 1 - shallowcoastal_mask, levels=[0.5, 1.0], colors=['gold'], alpha=0.25, zorder=4)

    ### Title, colour bar and annotations
    if title_opt != 'off':
        if title_opt == 'bottom':
            ax.set_title(title, fontsize=14, pad=-250)
        elif title == 'default':
            ax.set_title("%s in %s" % (varname, time), fontsize=14)
        else:
            ax.set_title(title, fontsize=14)

    if projection == 'cyl':
        ax.set_xticks(np.arange(-180, 240, 60))
        ax.set_yticks(np.arange(-90, 120, 30))
        if axes == 'on':
            cbar = fig.colorbar(a, ax=ax, boundaries=boundaries, ticks=ticks, orientation='vertical', pad=0.02,
                                extend=extend, shrink=0.621)
            cbar.set_label("%s (%s)" % (varname, units), fontsize=12)
            cbar.ax.set_yticklabels(cbar.ax.get_yticklabels(), fontsize=14)
            ax.set_xlabel(r'Longitude ($^{\circ}$)', fontsize=12)
            ax.set_ylabel(r'Latitude ($^{\circ}$)', fontsize=12)
            ax.tick_params(labelsize=12)
            ax.annotate(var_disp_string, xy=(0.7, -0.18), xycoords='axes fraction', fontsize=12)
            if show_cont and cont_anno != 'off':
                ax.annotate(var_cont_disp_string, xy=(-0.1, -0.12), xycoords='axes fraction', fontsize=12)

    if projection == 'moll' or projection == 'robin':
        m.drawparallels([-60, -30, 0, 30, 60], linewidth=0.5, labels=latlabels)
        m.drawmeridians(np.arange(0., 420., 60.), linewidth=0.5)
        if axes == 'on':
            if cbar_opt:
                if cbar_orient == 'hor':
                    cbar = fig.colorbar(a, ax=ax, orientation='horizontal', boundaries=boundaries, ticks=ticks,
                                        shrink=0.7, pad=0.06, extend=extend)
                if cbar_orient == 'vert':
                    cbar = fig.colorbar(a, ax=ax, orientation='vertical', boundaries=boundaries, ticks=ticks,
                                        pad=0.01, shrink=0.65, extend=extend)
                if cbar is not None:
                    cbar.set_label("%s (%s)" % (varname, units), fontsize=12)
                    cbar.ax.tick_params(labelsize=12)
            if anno == 'on':
                ax.annotate(var_disp_string, xy=(0.6, -0.058), xycoords='axes fraction', fontsize=12)
                if show_cont and cont_anno != 'off':
                    ax.annotate(var_cont_disp_string, xy=(0.02, -0.058), xycoords='axes fraction', fontsize=12)
            if not _is_unset(anno_str):
                ax.text(0.5, -0.055, anno_str, transform=ax.transAxes, ha='center', fontsize=12)

    # Mark the extra contour levels on the colour bar (only if one was drawn).
    if conts_misc == True and axes == 'on' and conts_misc_cbarmark == True and cbar is not None:
        for cc in range(len(conts_misc_vars)):
            for ccll, lev in enumerate(conts_misc_levs[cc]):
                pos = (lev - lim_l) / (lim_u - lim_l)  # level position normalised to [lim_l, lim_u]
                cbar.ax.plot([pos, pos], [0, 1], conts_misc_cols[cc][ccll], lw=3)

    if text_corner is not None:
        ax.text(text_corner_posi[0], text_corner_posi[1], text_corner, horizontalalignment='center',
                verticalalignment='center', fontweight='bold', transform=ax.transAxes, fontsize=14)

    if cell_numbering == 'on':
        for ii in range(xtm.shape[0]):
            for jj in range(xtm.shape[1]):
                ax.text(xtm[ii, jj], ytm[ii, jj], '{}\n{}'.format(ii, jj), ha='center', va='center', size=4.5,
                        color='white', alpha=0.75)

    fig.patch.set_facecolor('w')
    fig.tight_layout()
    return a, fig


def _cross_arg_unset(value):
    """Return True if *value* is an "argument not given" sentinel.

    The plotting helpers in this module default optional arguments to the
    string ``'None'``; a real ``None`` is accepted too.  The ``isinstance``
    check avoids comparing a NumPy array with a string, which is element-wise
    and therefore ambiguous inside an ``if``.
    """
    return value is None or (isinstance(value, str) and value == 'None')


def _cross_fmt_value(value, digits):
    """Format *value* for the ``[min, max]`` annotation.

    ``digits == 0`` prints an integer ("12"); otherwise the value is rounded
    to *digits* decimals with ``np.round``.
    """
    if digits == 0:
        return str(int(round(value)))
    return str(np.round(value, digits))


def _cross_overlay_grid(shape, lat_centres, edge_depths, level_depths):
    """Return the (x, y) mesh for an overlay field on the 48-column grid.

    ``(25, 48)`` fields are placed on *edge_depths*, ``(24, 48)`` fields on
    *level_depths*; latitudes are reversed like the main ``pcolor`` mesh.
    Any other shape raises ValueError (previously a later NameError).
    """
    shape = tuple(shape)
    if shape == (25, 48):
        return np.meshgrid(lat_centres[::-1], edge_depths)
    if shape == (24, 48):
        return np.meshgrid(lat_centres[::-1], level_depths)
    raise ValueError('overlay field has shape %s; expected (25, 48) or (24, 48)' % (shape,))


def plot_cross(var, varname='var', colourmap='RdBu_r', colourmap_shifted=False, colourmap_normalized=False,
               lim_l='None', lim_u='None', cbar_opt=True, cbar_delta='None', cbar_tick_freq=2, extend='neither',
               units='degC', time='year', title='default', fig_handle='None', crosssec_loc='None', axes='on',
               title_opt='on', linestyle='None', var_digits=2, ylabel='Depth (m)', ylims=False, ylevs=False,
               ylog=False, yinv=False,
               cont=False, var_cont='None', var_cont_name='None', var_cont_unit='None', cont_levs='None',
               cont_anno=False, cont_label_fmt='%.2f', cont_labels='on', var_cont_digits=2, cont_col='k',
               cont_linewidth=1.1,
               conts_misc=False, conts_misc_vars=[], conts_misc_levs=[], conts_misc_cols=[], conts_misc_lws=[],
               conts_misc_linestyles=[], conts_misc_cbarmark=False,
               vec='off', vectorx='None', vectory='None', vec_freq=1,
               colourmap_midvalue=0, colourmap_normalized_levels=None):
    """Plot a latitude/vertical cross-section of *var* with ``pcolor``.

    *var* is a 2-D field whose rows are vertical levels (``var.shape[0]``)
    and whose columns are latitudes (``var.shape[1]``).  Optional overlays:
    a line contour of *var_cont* (``cont=True``), extra contour sets
    *conts_misc_vars* (``conts_misc=True``) and a quiver plot of
    *vectorx*/*vectory* (``vec='on'``).

    Vertical levels: an explicit list/tuple/array *ylevs* is used as given.
    Otherwise they are chosen from ``var.shape[0]``.
      - 24 uses ``zt_k`` read from ``./topog.dta.nc``.
      - 14 and 20 use the hard-coded lists below.
    Cell edges are placed halfway between levels and extrapolated by half a
    step at both ends; ``ylog=True`` plots ``log`` of those edges.
    Latitudes come from ``grid_gen`` for 48, 37 or 72 columns.

    Colour scale: *lim_l*/*lim_u* override the data min/max.  ``extend``
    ('both'/'min'/'max') widens the limits by one *cbar_delta* on the
    extended side(s) and leaves that extra bin unticked.  *cbar_delta* is the
    colour-bar bin width (default: 50 bins between the limits) and
    *cbar_tick_freq* ticks every n-th boundary.  ``colourmap_shifted=True``
    re-centres the map on *colourmap_midvalue*.  ``colourmap_normalized=True``
    applies ``PiecewiseNorm(colourmap_normalized_levels)`` (default ``[0]``,
    as in ``plot_world``).

    Optional arguments default to the string ``'None'``; ``None`` also works.

    Returns ``(a, fig)``: the collection returned by ``ax.pcolor`` and the
    figure drawn on (new, or ``fig_handle[0]``).

    Raises ValueError for level/latitude counts or overlay shapes that cannot
    be resolved, and for an unknown *extend* when *cbar_delta* is given.
    """
    vel_unit = 'cm/s'  # unit shown in the quiver key

    # Read the level depths eagerly so the file is closed again right away.
    with Dataset('./topog.dta.nc', 'r') as nc_fid:
        zt_k = np.asarray(nc_fid.variables['zt_k'][:])
    y_t = grid_gen(region='ocean', cell='T')[0]

    # --- vertical levels and cell edges ------------------------------------
    if isinstance(ylevs, (list, tuple, np.ndarray)):
        levels = np.asarray(ylevs)
    elif var.shape[0] == 24:  # Climber3a MOM levels
        levels = zt_k
    elif var.shape[0] == 14:  # rounded zlevw from CLIMBER-X atm.nc
        levels = np.array([0., 202., 410., 842., 1299., 1784., 2851., 4083.,
                           5540., 7324., 9623., 12864., 18404., 30000.])
    elif var.shape[0] == 20:
        levels = np.array([0., 25.2, 50.4, 100.8, 201.6, 302.4, 403.2,
                           504., 604.8, 806.4, 1008., 1260., 1511.9, 1763.9,
                           2015.9, 2519.9, 3023.9, 3527.9, 4031.9, 4535.8])
    else:
        raise ValueError('no default vertical levels for %d rows; pass ylevs explicitly' % var.shape[0])

    half_steps = 0.5 * np.diff(levels)
    altis = np.concatenate((
        [levels[0] - half_steps[0]],
        levels[:-1] + half_steps,
        [levels[-1] + half_steps[-1]],
    ))
    if ylog == True:
        altis = np.log(altis)

    # --- latitude axis -----------------------------------------------------
    lat_grids = {48: ('ocean', 'T'), 37: ('5deg', 'p'), 72: ('2.5x3.75deg', 'p')}
    if var.shape[1] not in lat_grids:
        raise ValueError('no latitude grid for %d columns' % var.shape[1])
    region, cell = lat_grids[var.shape[1]]
    lats = grid_gen(region=region, cell=cell)[0]
    x, y = np.meshgrid(lats[::-1], altis)

    # --- overlay grids -----------------------------------------------------
    # NOTE(review): (24, 48) line contours use negated zt_k while conts_misc
    # uses +zt_k and the main mesh uses positive depths; kept as before, confirm.
    draw_cont = cont == True
    if draw_cont:
        x_c, y_c = _cross_overlay_grid(var_cont.shape, y_t, altis, -1 * zt_k)
    if conts_misc == True:
        x_cmisc, y_cmisc = _cross_overlay_grid(conts_misc_vars[0].shape, y_t, altis, zt_k)

    # --- annotation strings ------------------------------------------------
    varmin = np.min(var)
    varmax = np.max(var)
    var_disp_string = (varname + ' [' + _cross_fmt_value(varmin, var_digits) + ', '
                       + _cross_fmt_value(varmax, var_digits) + '] ' + units)
    annotate_cont = draw_cont and cont_anno == True
    if annotate_cont:
        var_cont_disp_string = (var_cont_name + ' ['
                                + _cross_fmt_value(np.min(var_cont), var_cont_digits) + ', '
                                + _cross_fmt_value(np.max(var_cont), var_cont_digits) + '] '
                                + var_cont_unit)

    # --- colour limits, boundaries and ticks -------------------------------
    vmin = varmin if _cross_arg_unset(lim_l) else lim_l
    vmax = varmax if _cross_arg_unset(lim_u) else lim_u
    if extend in ('both', 'min'):
        vmin = lim_l - cbar_delta
    if extend in ('both', 'max'):
        vmax = lim_u + cbar_delta

    if not _cross_arg_unset(cbar_delta):
        if extend == 'neither':
            boundaries = np.arange(vmin, vmax + cbar_delta, cbar_delta)
            ticks = boundaries[::cbar_tick_freq]
        elif extend in ('both', 'min', 'max'):
            boundaries = np.arange(math.floor(vmin / cbar_delta) * cbar_delta,
                                   math.ceil(vmax / cbar_delta) * cbar_delta + cbar_delta, cbar_delta)
            # Leave the extra bin(s) added for 'both'/'min' without a tick.
            first = 1 if extend in ('both', 'min') else 0
            last = -1 if extend == 'both' else None
            ticks = boundaries[first:last:cbar_tick_freq]
        else:
            raise ValueError("extend must be 'neither', 'both', 'min' or 'max', got %r" % (extend,))
    else:
        boundaries = np.arange(vmin, vmax, (vmax - vmin) / 50)
        ticks = boundaries[::cbar_tick_freq]

    # --- figure and colour map ---------------------------------------------
    if _cross_arg_unset(fig_handle):
        fig = plt.figure(figsize=(7.5, 5.15))
        ax = fig.add_subplot(111)
    else:
        fig = fig_handle[0]
        ax = fig_handle[1]

    if colourmap_shifted == True:
        mid = colourmap_midvalue
        base_cmap = plt.get_cmap(colourmap)
        if abs(lim_u - mid) >= abs(lim_l - mid):
            mycmap = shiftedColorMap(base_cmap, start=((lim_l - mid) + (lim_u - mid)) / (2 * (lim_u - mid)))
        else:
            mycmap = shiftedColorMap(base_cmap, stop=(abs(lim_l - mid) + lim_u - mid) / (2 * abs(lim_l - mid)))
    else:
        mycmap = plt.get_cmap(colourmap, len(boundaries))

    if colourmap_normalized == True:
        norm_levels = [0] if colourmap_normalized_levels is None else colourmap_normalized_levels
        # matplotlib rejects norm together with vmin/vmax, so pass only the norm.
        a = ax.pcolor(x, y, var, cmap=mycmap, norm=PiecewiseNorm(norm_levels))
    else:
        a = ax.pcolor(x, y, var, cmap=mycmap, vmin=vmin, vmax=vmax)
    ax.set_xlabel(r'Latitude ($^{\circ}$)', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)

    # --- overlays ----------------------------------------------------------
    if draw_cont:
        # Line width stays 1.5 (cont_linewidth is not applied) to keep figures unchanged.
        cont_kwargs = {'colors': cont_col, 'linewidths': 1.5, 'zorder': 6}
        if not _cross_arg_unset(cont_levs):
            cont_kwargs['levels'] = cont_levs
        cont_set = ax.contour(x_c, y_c, var_cont, **cont_kwargs)
        if cont_labels != 'off':
            ax.clabel(cont_set, fontsize=11, inline=1, fmt=cont_label_fmt)
    if conts_misc == True:
        for cc, field in enumerate(conts_misc_vars):
            ax.contour(x_cmisc, y_cmisc, field, levels=conts_misc_levs[cc], colors=conts_misc_cols[cc],
                       linewidths=conts_misc_lws[cc], linestyles=conts_misc_linestyles[cc], zorder=5)
    if vec == 'on':
        if draw_cont:
            x_v, y_v = x_c, y_c
        else:
            x_v, y_v = _cross_overlay_grid(np.shape(vectorx), y_t, altis, -1 * zt_k)
        vec_mean = np.ma.mean(np.sqrt(vectorx**2 + vectory**2))
        step = slice(None, None, vec_freq)
        q = ax.quiver(x_v[step, step], y_v[step, step], vectorx[step, step], vectory[step, step], width=0.002)
        ax.quiverkey(q, 0.80, 0.96, vec_mean, r'mean: %.2f ' % (vec_mean) + vel_unit, labelpos='E',
                     coordinates='figure', fontproperties={'size': 12})

    # --- titles, axes, colour bar ------------------------------------------
    if title_opt != 'off':
        if title == 'default':
            ax.set_title("%s in %s" % (varname, time), fontsize=13)
        else:
            ax.set_title(title, fontsize=14)

    ax.tick_params(labelsize=12)
    ax.set_xticks(np.arange(-90, 120, 30))
    ax.set_xticklabels([r'90$^{\circ}$S', r'60$^{\circ}$S', r'30$^{\circ}$S',
                        r'0$^{\circ}$', r'30$^{\circ}$N', r'60$^{\circ}$N', r'90$^{\circ}$N'])
    cbar = None
    if axes == 'on':
        if cbar_opt == True:
            cbar = fig.colorbar(a, ax=ax, boundaries=boundaries, ticks=ticks, orientation='vertical', pad=0.02)
            cbar.set_label("%s (%s)" % (varname, units), fontsize=12)
            cbar.ax.tick_params(labelsize=12)
        ax.annotate(var_disp_string, xy=(0.67, -0.12), xycoords='axes fraction', fontsize=12)
        if annotate_cont:
            ax.annotate(var_cont_disp_string, xy=(-0.16, -0.12), xycoords='axes fraction', fontsize=12)
        if not _cross_arg_unset(crosssec_loc):
            ax.annotate('[' + crosssec_loc + ']', xy=(1.02, 1.02), xycoords='axes fraction', fontsize=12)

    if conts_misc == True and conts_misc_cbarmark == True and cbar is not None:
        # NOTE(review): marks are placed in 0-1 units relative to lim_l..lim_u;
        # newer matplotlib colour bars use data units on the long axis — confirm.
        for cc in range(len(conts_misc_vars)):
            for ccll, lev in enumerate(conts_misc_levs[cc]):
                frac = (lev - lim_l) / (lim_u - lim_l)
                cbar.ax.plot([0, 1], [frac, frac], c=conts_misc_cols[cc][ccll], lw=3)

    ax.grid(True, linestyle=':')
    if isinstance(ylims, (list, tuple)):
        ax.set_ylim(ylims)
    if levels[0] > levels[1] or yinv == True:
        ax.invert_yaxis()
    if ylog == True:
        ax.set_yticks(np.log(levels))
        ax.set_yticklabels(levels)

    plt.tight_layout()
    return a, fig
    
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Offset the "center" of a colormap.

    Useful for data with a negative min and positive max when the middle of
    the colormap's dynamic range should sit at zero (or another mid value).

    Input
    -----
      cmap : The matplotlib colormap to be altered, or the name of one.
      start : Offset from the lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be 1 - vmax / (vmax + abs(vmin)).
          For example, if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to 1 - 5/(5 + 15) = 0.75.
      stop : Offset from the highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name of the new colormap. It is also registered under this
          name, replacing any earlier map registered with the same name.

    Returns
    -------
      A matplotlib LinearSegmentedColormap.
    '''
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # Regular index used to sample colours from the input map.
    reg_index = np.linspace(start, stop, 257)

    # Shifted index: 128 samples below the midpoint, 129 from it to 1.
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # plt.register_cmap was removed in matplotlib 3.9.  On 3.6-3.8,
    # re-registering an existing name raises an error, so a second call with the
    # default name failed.  Use the registry and overwrite when it is available.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        registry.register(newcmap, name=name, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap


from matplotlib.colors import Normalize

class PiecewiseNorm(Normalize):
    """Normalize data to [0, 1] piecewise-linearly through a set of levels.

    The sorted ``levels`` are mapped to evenly spaced values from 0 to 1, and
    data in between are linearly interpolated. Each interval between
    consecutive levels therefore gets an equal share of the colour map.
    Values outside the levels are clamped to 0 or 1 by ``np.interp``, so
    ``clip`` has no additional effect.
    """

    def __init__(self, levels, clip=False):
        # the input levels, sorted ascending as np.interp requires
        self._levels = np.sort(levels)
        # corresponding normalized values between 0 and 1
        self._normed = np.linspace(0, 1, len(levels))
        Normalize.__init__(self, None, None, clip)

    def __call__(self, value, clip=None):
        """Return the normalized value(s), keeping the input's mask."""
        value = np.ma.asarray(value)
        # np.interp ignores masks, so interpolate the raw data and re-attach
        # the mask; masked cells otherwise get coloured from their fill values.
        normed = np.interp(np.ma.getdata(value), self._levels, self._normed)
        return np.ma.masked_array(normed, mask=np.ma.getmask(value))

    def inverse(self, value):
        """Map normalized value(s) in [0, 1] back to data values.

        This overrides ``Normalize.inverse``, which is linear in vmin/vmax and
        does not match the piecewise mapping above (colour bars use inverse).
        """
        return np.interp(value, self._normed, self._levels)
    
def add_geological_timescale(aa,height=0.04,textoffset=0.0125,fontsize=10,fontweight='bold',stagenames=True):
    """Draw a Triassic / Jurassic / Cretaceous timescale bar at the bottom of *aa*.

    The bar is a row of rectangles along the lower y-limit of the axes. Its
    x-positions are the hard-coded boundaries -252 ... -66; presumably the
    x axis is time in Ma, negative into the past, which should be confirmed
    by the caller. The space to the axes' x-limits on both sides is filled
    in white. Period names are written on top. With ``stagenames`` the
    early/middle/late stage letters are added as well.

    Parameters
    ----------
    aa : matplotlib Axes to draw on; its current x/y limits set the extent.
    height : bar height as a fraction of the y-range.
    textoffset : label baseline above the bottom, as a fraction of the y-range.
    fontsize, fontweight : font settings for the period names (the stage
        letters reuse ``fontsize`` but are always normal weight).
    stagenames : draw the E/M/L stage letters when True.
    """
    y_lo, y_hi = aa.get_ylim()
    x_lo, x_hi = aa.get_xlim()
    bar_height = height * (y_hi - y_lo)
    label_y = y_lo + textoffset * (y_hi - y_lo)

    # Segment edges from the left axis limit to the right one; the stage
    # colours are RGB values on a 0-255 scale.
    edges = [x_lo, -252, -247, -237, -201, -174, -163.5, -145, -100.5, -66, x_hi]
    stage_rgb = (
        (168., 88., 158.),   # lower Triassic
        (188., 134., 186.),  # middle Triassic
        (199., 166., 207.),  # upper Triassic
        (0., 176., 233.),    # lower Jurassic
        (132., 207., 237.),  # middle Jurassic
        (188., 228., 250.),  # upper Jurassic
        (160., 201., 109.),  # lower Cretaceous
        (186., 210., 95.),   # upper Cretaceous
    )
    fills = ['w'] + [np.array(rgb) / 255 for rgb in stage_rgb] + ['w']

    # clip_on=False keeps the bar visible even though it sits on the bottom spine.
    for left, right, fill in zip(edges[:-1], edges[1:], fills):
        bar = matplotlib.patches.Rectangle((left, y_lo), right - left, bar_height,
                                           facecolor=fill, edgecolor='tab:grey', clip_on=False, zorder=9)
        aa.add_patch(bar)

    for label, x_pos in (('Triassic', -225), ('Jurassic', -178.5), ('Cretaceous', -105)):
        aa.text(x_pos, label_y, label, ha='center', color='k', fontweight=fontweight, fontsize=fontsize, zorder=10)

    if stagenames == True:
        stage_x = (-250, -242.5, -210, -192.5, -167.5, -155, -130, -85)
        for letter, x_pos in zip('EMLEMLEL', stage_x):
            aa.text(x_pos, label_y, letter, ha='center', color='k', fontweight='normal', fontsize=fontsize, zorder=10)
    
