import numpy as np
import math
from netCDF4 import Dataset 
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
import pickle
from tools_analysis import grid_gen
import importlib
import tools_analysis as jan
importlib.reload(jan)
#import tools_analysis as jan
#reload(jan)

##Example:
#nc_f = './history_p2_nc_files/history_p2_O24p5_E0p000_P000.nc'  
#nc_fid = Dataset(nc_f, 'r')  
##nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
#ts_ann = nc_fid.variables['ts_ann'][:]  # shape is time, lat, lon as shown above
#var=ts_ann[5700, :, :]; projection='rec';varname='var'; units='degC'; time='year'
#
#nc_f = './snapshots.004103.01.01.dta.nc'  
#nc_fid = Dataset(nc_f, 'r')  
#nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
#temp = np.flipud(nc_fid.variables['temp'][0,0,:,:])
#var=temp; projection='rec';varname='var'; units='degC'; time='year'       

def _is_unset(value):
    """Return True if *value* is a "not supplied" sentinel (``'None'`` or ``None``).

    Optional array arguments in this module default to the string ``'None'``.
    ``ndarray != 'None'`` is evaluated elementwise by recent numpy, which makes
    a plain ``if`` raise "truth value of an array is ambiguous", so the type is
    checked before any comparison.
    """
    return value is None or (isinstance(value, str) and value == 'None')


# First-axis length of a field -> grid name understood by grid_gen / jan.global_mean.
_REGION_BY_ROWS = {24: 'atmo', 48: 'ocean', 180: 'atmo_1deg'}


def _grid_region(field, argname):
    """Return the grid name for *field*, selected by its first-axis length."""
    nrows = len(field)
    if nrows not in _REGION_BY_ROWS:
        raise ValueError('%s has %d rows; expected one of %s'
                         % (argname, nrows, sorted(_REGION_BY_ROWS)))
    return _REGION_BY_ROWS[nrows]


def _stats_label(name, fmin, fmean, fmax, digits, unit):
    """Format ``name [min, mean, max] unit`` for the under-map annotation."""
    if digits == 0:
        parts = [str(int(round(v))) for v in (fmin, fmean, fmax)]
    else:
        parts = [str(np.round(v, digits)) for v in (fmin, fmean, fmax)]
    return name + ' [' + ', '.join(parts) + '] ' + unit


def _roll_columns(field, shift):
    """Move the first *shift* columns of a 2-D (masked) array to its end."""
    return np.ma.concatenate((field[:, shift:], field[:, :shift]), axis=1)


def _load_continental_mask(timeslice, mask_array):
    """Return the land/sea mask whose 0.5 contour is drawn as the coastline.

    An explicitly supplied *mask_array* wins. Otherwise timeslice 201 loads the
    pickled default mask, and any other timeslice reads that slice's
    ``depth_adj.dat`` with every positive value set to 1.
    """
    if not _is_unset(mask_array):
        return mask_array
    if timeslice == 201:
        # NOTE(review): pickle.load can execute arbitrary code; this path is
        # assumed to be a trusted, locally generated file.
        with open('../climber/continental_mask.pickle', 'rb') as f:
            return pickle.load(f, encoding='latin1')  # latin1: file written by Python 2
    tag = '{:03d}'.format(timeslice)
    mask = np.loadtxt('../MesoClim/c3a_meso_InputFiles/' + tag + 'Ma/depth_adj.dat', dtype=int)
    mask[mask > 0] = 1
    return mask


def plot_world(var, varname='var', units='degC', lim_l='None', lim_u='None', cbar_delta='None', cbar_tick_freq=2,
               projection='rec', continents='off', continents_timeslice=201, continents_array='None', rotate='no',
               time='year', title='default', colourmap='Spectral_r', axes='on', title_opt='on', fig_handle='None',
               var_digits=2, extend='neither', cbar_orient='hor', anno='on',
               vec='off', vectorx='None', vectory='None', vec_freq=1, vec_anno='on', vec_col='k',
               cont='off', var_cont='None', var_cont_name='None', cont_anno='on', var_cont_unit='None',
               cont_label_fmt='%.2f', cont_levs=None, cont_labels='on', linestyle=None, var_cont_digits=2, cont_col='k',
               scat='off', var_scat='None', var_scat2='None', scat_colored='off', scat_colors='None', scat_size=50,
               var_scat_c='xkcd:bright blue', var_scat2_c='xkcd:green',
               hatches='off', var_hatch='None', hatch_levels='None', hatch_levels2='None'):
    """Plot a 2-D field on a global map, with optional overlays.

    The grid (for coordinates and the global mean) is chosen from each field's
    first-axis length: 24 -> 'atmo', 48 -> 'ocean', 180 -> 'atmo_1deg'. Any other
    length raises ValueError.

    Main options:
      var, varname, units : field shown with ``pcolor`` and its labels.
      lim_l, lim_u        : colour limits (default: data min/max).
      cbar_delta, cbar_tick_freq, extend : colour-bar spacing and extension.
      projection          : 'rec' (plain lon/lat axes) or 'moll' (Basemap Mollweide).
      continents, continents_timeslice, continents_array : coastline overlay
                            (always drawn for 'atmo' fields).
      rotate='yes'        : move the first ``len(var)`` columns of the plotted
                            fields to the end (grid coordinates are not moved).
      cont / var_cont     : line-contour overlay, on var_cont's own grid.
      vec / vectorx, vectory : quiver overlay on var's T grid.
      scat / var_scat(2)  : scatter points given as (lon, lat) columns.
      hatches / var_hatch : hatched contourf overlay, on var_hatch's own grid.
      fig_handle          : optional (figure, axes) pair to draw into.

    Optional array arguments use the string 'None' as their "not given" default.

    Returns
    -------
    (a, fig) : the ``pcolor`` artist and the figure.

    Raises
    ------
    ValueError : unsupported grid size, or a shifted colormap without both limits.
    """
    # --- grids and global means (computed before any rotation) ---------------
    xo, yo = grid_gen(region='ocean', cell='T')[2:4]  # grid of the continental mask
    region = _grid_region(var, 'var')
    x, y = grid_gen(region=region, cell='p')[2:4]     # coordinates given to pcolor
    xt, yt = grid_gen(region=region, cell='T')[2:4]   # coordinates for vectors
    varmean = jan.global_mean(var, grid=region)
    vel_unit = 'm/s' if region == 'atmo' else 'cm/s'

    if cont == 'on':
        cont_region = _grid_region(var_cont, 'var_cont')
        xtc, ytc = grid_gen(region=cont_region, cell='T')[2:4]
        var_cont_mean = jan.global_mean(var_cont, grid=cont_region)
    if hatches == 'on':
        xth, yth = grid_gen(region=_grid_region(var_hatch, 'var_hatch'), cell='T')[2:4]

    if rotate == 'yes':
        var = _roll_columns(var, len(var))
        # Only rotate the overlays that were actually supplied.
        # NOTE(review): var_hatch is not rotated (as before) - confirm intent.
        if not _is_unset(var_cont):
            var_cont = _roll_columns(var_cont, len(var_cont))
        if not _is_unset(vectorx):
            vectorx = _roll_columns(vectorx, len(var))
        if not _is_unset(vectory):
            vectory = _roll_columns(vectory, len(var))

    # --- annotation strings ----------------------------------------------------
    varmin = np.min(var)
    varmax = np.max(var)
    var_disp_string = _stats_label(varname, varmin, varmean, varmax, var_digits, units)
    if cont == 'on':
        var_cont_disp_string = _stats_label(var_cont_name, np.min(var_cont), var_cont_mean,
                                            np.max(var_cont), var_cont_digits, var_cont_unit)

    # --- colour limits and colour-bar boundaries -------------------------------
    vmin = varmin if _is_unset(lim_l) else lim_l
    vmax = varmax if _is_unset(lim_u) else lim_u
    if extend != 'neither' and not _is_unset(cbar_delta):
        # Widen the range by one step on each extended side for the arrow bins.
        if extend in ('both', 'min'):
            vmin = vmin - cbar_delta
        if extend in ('both', 'max'):
            vmax = vmax + cbar_delta

    if _is_unset(cbar_delta):
        boundaries = np.arange(vmin, vmax, (vmax - vmin) / 50)
        ticks = boundaries[::cbar_tick_freq]
    elif extend == 'both':
        boundaries = np.arange(math.floor(vmin / cbar_delta) * cbar_delta,
                               math.ceil(vmax / cbar_delta) * cbar_delta, cbar_delta)
        ticks = boundaries[1:-1:cbar_tick_freq]
    elif extend == 'min':
        boundaries = np.arange(math.floor(vmin / cbar_delta) * cbar_delta,
                               math.ceil(vmax / cbar_delta) * cbar_delta + cbar_delta, cbar_delta)
        ticks = boundaries[1::cbar_tick_freq]
    elif extend == 'max':
        boundaries = np.arange(math.floor(vmin / cbar_delta) * cbar_delta,
                               math.ceil(vmax / cbar_delta) * cbar_delta + cbar_delta, cbar_delta)
        ticks = boundaries[::cbar_tick_freq]
    else:
        boundaries = np.arange(vmin, vmax + cbar_delta, cbar_delta)
        ticks = boundaries[::cbar_tick_freq]

    # Only read the mask file when the coastline will actually be drawn.
    draw_continents = region == 'atmo' or continents == 'on'
    if draw_continents:
        continental_mask = _load_continental_mask(continents_timeslice, continents_array)

    # --- figure and map projection -------------------------------------------
    if _is_unset(fig_handle):
        fig = plt.figure(figsize=(8, 5))
        ax = fig.add_subplot(111)
    else:
        fig, ax = fig_handle[0], fig_handle[1]

    if projection == 'moll':
        # See http://matplotlib.org/basemap/users/mapsetup.html for other projections.
        m = Basemap(projection='moll', llcrnrlat=-90, urcrnrlat=90, llcrnrlon=-180, urcrnrlon=180,
                    resolution='c', lon_0=0)
        m.drawmapboundary()
        to_map = m  # Basemap instances map (lon, lat) -> plot coordinates when called
    else:
        def to_map(lon, lat):
            return lon, lat

    x, y = to_map(x, y)
    xo, yo = to_map(xo, yo)
    xtm, ytm = to_map(xt, yt)
    if cont == 'on':
        xtcm, ytcm = to_map(xtc, ytc)
    if hatches == 'on':
        xthm, ythm = to_map(xth, yth)
    if scat == 'on':
        x_scat, y_scat = to_map(var_scat[:, 0], var_scat[:, 1])
        if not _is_unset(var_scat2):
            x_scat2, y_scat2 = to_map(var_scat2[:, 0], var_scat2[:, 1])

    # --- colormap ------------------------------------------------------------
    if colourmap in ('Spectral_shifted', 'Spectral_shifted_r'):
        if _is_unset(lim_l) or _is_unset(lim_u):
            raise ValueError("colourmap %r requires both lim_l and lim_u" % colourmap)
        base = matplotlib.cm.Spectral_r if colourmap == 'Spectral_shifted' else matplotlib.cm.Spectral
        if abs(lim_u) > abs(lim_l):
            mycmap = shiftedColorMap(base, start=(lim_l + lim_u) / (2 * lim_u))
        elif abs(lim_u) < abs(lim_l):
            mycmap = shiftedColorMap(base, stop=(abs(lim_l) + lim_u) / (2 * abs(lim_l)))
        else:
            mycmap = base  # symmetric limits: the unshifted map is already centred
    else:
        # pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
        mycmap = plt.get_cmap(colourmap, len(boundaries))

    # --- main field and overlays ---------------------------------------------
    a = ax.pcolor(x, y, var, cmap=mycmap, vmin=vmin, vmax=vmax)
    if draw_continents:
        ax.contour(xo, yo, continental_mask, levels=[0.5], linewidths=1.75)

    if cont == 'on':
        cont_set = ax.contour(xtcm, ytcm, var_cont, levels=cont_levs, colors=cont_col, linewidths=1.1,
                              linestyles=linestyle, antialiased=True)
        if cont_labels != 'off':
            ax.clabel(cont_set, fontsize=11, inline=1, fmt=cont_label_fmt)

    if vec == 'on':
        vec_mean = np.ma.mean(np.sqrt(vectorx ** 2 + vectory ** 2))
        thin = slice(None, None, vec_freq)
        if projection == 'rec':
            q = ax.quiver(xtm[thin, thin], ytm[thin, thin], vectorx[thin, thin], vectory[thin, thin],
                          width=0.002, color=vec_col)
            if vec_anno == 'on':
                ax.quiverkey(q, 0.80, 0.96, vec_mean, r'mean: %.2f ' % vec_mean + vel_unit, labelpos='E',
                             coordinates='figure', fontproperties={'size': 12})
        if projection == 'moll':
            q = m.quiver(xtm[thin, thin], ytm[thin, thin], vectorx[thin, thin], vectory[thin, thin],
                         width=0.0015, color=vec_col)
            if vec_anno == 'on':
                ax.quiverkey(q, 0.12, 0.205, vec_mean, r'mean: %.2f ' % vec_mean + vel_unit, labelpos='E',
                             coordinates='figure', fontproperties={'size': 12})

    if scat == 'on':
        if scat_colored == 'on':
            ax.scatter(x_scat, y_scat, edgecolors='k', zorder=5, c=scat_colors, cmap=mycmap, vmin=vmin, vmax=vmax)
        else:
            ax.scatter(x_scat, y_scat, edgecolors='k', c=var_scat_c, zorder=5, s=scat_size)
        if not _is_unset(var_scat2):
            ax.scatter(x_scat2, y_scat2, edgecolors='k', c=var_scat2_c, zorder=5, s=scat_size)

    if hatches == 'on':
        # Global rcParams change: it also affects hatches drawn later in the session.
        matplotlib.rcParams['hatch.color'] = 'whitesmoke'
        levels1 = None if _is_unset(hatch_levels) else hatch_levels
        ax.contourf(xthm, ythm, var_hatch, hatches=['...', None], levels=levels1, colors='none')
        if not _is_unset(hatch_levels2):
            ax.contourf(xthm, ythm, var_hatch, hatches=['xxx', None], levels=hatch_levels2, colors='none')

    # --- title, axes and colour bar --------------------------------------------
    if title_opt != 'off':
        if title_opt == 'bottom':
            ax.set_title(title, fontsize=14, pad=-250)
        elif title == 'default':
            ax.set_title("%s in %s" % (varname, time), fontsize=14)
        else:
            ax.set_title(title, fontsize=14)

    if projection == 'rec':
        ax.set_xticks(np.arange(-180, 240, 60))
        ax.set_yticks(np.arange(-90, 120, 30))
        if axes == 'on':
            cbar = fig.colorbar(a, ax=ax, boundaries=boundaries, ticks=ticks, orientation='vertical', pad=0.02,
                                extend=extend)
            cbar.set_label("%s (%s)" % (varname, units), fontsize=12)
            cbar.ax.tick_params(labelsize=14)
            ax.set_xlabel(r'Longitude ($^{\circ}$)', fontsize=12)
            ax.set_ylabel(r'Latitude ($^{\circ}$)', fontsize=12)
            ax.tick_params(labelsize=12)
            ax.annotate(var_disp_string, xy=(0.67, -0.12), xycoords='axes fraction', fontsize=12)
            if cont == 'on' and cont_anno != 'off':
                ax.annotate(var_cont_disp_string, xy=(-0.1, -0.12), xycoords='axes fraction', fontsize=12)

    if projection == 'moll':
        m.drawparallels(np.arange(-90., 120., 30.), linewidth=0.5)
        m.drawmeridians(np.arange(0., 420., 60.), linewidth=0.5)
        if axes == 'on':
            if cbar_orient == 'vert':
                cbar = fig.colorbar(a, ax=ax, orientation='vertical', boundaries=boundaries, ticks=ticks,
                                    pad=0.01, shrink=0.65, extend=extend)
            else:
                # 'hor' (default); any other value also falls back to horizontal.
                cbar = fig.colorbar(a, ax=ax, orientation='horizontal', boundaries=boundaries, ticks=ticks,
                                    shrink=0.7, pad=0.06, extend=extend)
            cbar.set_label("%s (%s)" % (varname, units), fontsize=12)
            cbar.ax.tick_params(labelsize=12)
            if anno == 'on':
                ax.annotate(var_disp_string, xy=(0.53, -0.058), xycoords='axes fraction', fontsize=12)
                if cont == 'on' and cont_anno != 'off':
                    ax.annotate(var_cont_disp_string, xy=(0.02, -0.058), xycoords='axes fraction', fontsize=12)

    fig.patch.set_facecolor('w')
    plt.tight_layout()
    return a, fig


def plot_cross(var, varname='var', colourmap='Spectral_r', lim_l='None', lim_u='None', cbar_delta='None', cbar_tick_freq=2,
               units='degC', time='year', title='default', fig_handle='None', crosssec_loc='None', axes='on',
               title_opt='on', linestyle='None', var_digits=2,
               cont='off', var_cont='None', var_cont_name='None', var_cont_unit='None', cont_levs='None', cont_anno='on',
               cont_label_fmt='%.2f', cont_labels='on', var_cont_digits=2,
               vec='off', vectorx='None', vectory='None', vec_freq=1):
    """Plot an ocean latitude-depth section with optional contour and vector overlays.

    The main field is drawn with ``pcolor`` on a mesh of the (reversed) ocean
    'p'-cell latitudes and ``-zw_k`` depths read from the topography file.
    Contour (``var_cont``) and vector (``vectorx``/``vectory``) fields are placed
    on the reversed T latitudes. Shape (25, 48) uses ``-zw_k`` depths and shape
    (24, 48) uses ``-zt_k`` depths. Other shapes raise ValueError.

    Optional array arguments use the string 'None' as their "not given" default.
    Contour statistics are annotated only when ``cont == 'on'`` and
    ``cont_anno != 'off'``.

    Returns
    -------
    (a, fig) : the ``pcolor`` artist and the figure.
    """

    def is_unset(value):
        # Check the type first: ``ndarray != 'None'`` is evaluated elementwise.
        return value is None or (isinstance(value, str) and value == 'None')

    def stats_label(name, fmin, fmean, fmax, digits, unit):
        # ``name [min, mean, max] unit``; digits == 0 shows integers.
        if digits == 0:
            parts = [str(int(round(v))) for v in (fmin, fmean, fmax)]
        else:
            parts = [str(np.round(v, digits)) for v in (fmin, fmean, fmax)]
        return name + ' [' + ', '.join(parts) + '] ' + unit

    vel_unit = 'cm/s'

    # Read the vertical coordinates once; the context manager closes the file.
    with Dataset('../climber/c3beta_tria_200Ma_1500ppm/topog.dta.nc', 'r') as nc_fid:
        zw_k = nc_fid.variables['zw_k'][:]
        zt_k = nc_fid.variables['zt_k'][:]
    y_p = grid_gen(region='ocean', cell='p')[0]
    y_t = grid_gen(region='ocean', cell='T')[0]

    depth = -1 * zw_k
    x, y = np.meshgrid(y_p[::-1], depth)

    def secondary_grid(field, argname):
        # Latitude-depth mesh for a contour/vector field, chosen by its shape.
        if field.shape == (25, 48):
            return np.meshgrid(y_t[::-1], depth)
        if field.shape == (24, 48):
            return np.meshgrid(y_t[::-1], -1 * zt_k)
        raise ValueError('%s has shape %s; expected (25, 48) or (24, 48)' % (argname, field.shape))

    # --- statistics for the annotations --------------------------------------
    varmin = np.min(var)
    varmax = np.max(var)
    varmean = jan.global_mean(var, region='crosssec', grid='ocean')
    var_disp_string = stats_label(varname, varmin, varmean, varmax, var_digits, units)
    show_cont_stats = cont == 'on' and cont_anno != 'off'
    if show_cont_stats:
        var_cont_mean = jan.global_mean(var_cont, region='crosssec', grid='ocean')
        var_cont_disp_string = stats_label(var_cont_name, np.min(var_cont), var_cont_mean, np.max(var_cont),
                                           var_cont_digits, var_cont_unit)

    # --- colour limits and colour-bar boundaries -------------------------------
    vmin = varmin if is_unset(lim_l) else lim_l
    vmax = varmax if is_unset(lim_u) else lim_u
    if is_unset(cbar_delta):
        boundaries = np.arange(vmin, vmax, (vmax - vmin) / 50)
    else:
        boundaries = np.arange(math.floor(vmin / cbar_delta) * cbar_delta,
                               math.ceil(vmax / cbar_delta) * cbar_delta + cbar_delta, cbar_delta)

    if is_unset(fig_handle):
        fig = plt.figure(figsize=(7.5, 5.15))
        ax = fig.add_subplot(111)
    else:
        fig, ax = fig_handle[0], fig_handle[1]

    if colourmap == 'Spectral_shifted':
        # shiftedColorMap samples the colormap by calling it, so pass a Colormap
        # object rather than its name.
        mycmap = shiftedColorMap(plt.get_cmap('Spectral_r'), start=(abs(lim_u) - abs(lim_l)) / (2 * lim_u),
                                 name='shifted')
    else:
        # pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
        mycmap = plt.get_cmap(colourmap, len(boundaries))

    a = ax.pcolor(x, y, var, cmap=mycmap, vmin=vmin, vmax=vmax)
    ax.set_ylim([-3500, 0])
    ax.set_xlabel(r'Latitude ($^{\circ}$)', fontsize=12)
    ax.set_ylabel('Depth (m)', fontsize=12)

    if cont == 'on':
        x_c, y_c = secondary_grid(var_cont, 'var_cont')
        if is_unset(cont_levs):
            cont_set = ax.contour(x_c, y_c, var_cont, colors='k', linewidths=0.9, linestyles=linestyle)
        else:
            cont_set = ax.contour(x_c, y_c, var_cont, levels=cont_levs, colors='k', linewidths=0.9)
        if cont_labels != 'off':
            ax.clabel(cont_set, fontsize=11, inline=1, fmt=cont_label_fmt)

    if vec == 'on':
        x_v, y_v = secondary_grid(vectorx, 'vectorx')
        vec_mean = np.ma.mean(np.sqrt(vectorx ** 2 + vectory ** 2))
        thin = slice(None, None, vec_freq)
        q = ax.quiver(x_v[thin, thin], y_v[thin, thin], vectorx[thin, thin], vectory[thin, thin], width=0.002)
        ax.quiverkey(q, 0.80, 0.96, vec_mean, r'mean: %.2f ' % vec_mean + vel_unit, labelpos='E',
                     coordinates='figure', fontproperties={'size': 12})

    if title_opt != 'off':
        if title == 'default':
            ax.set_title("%s in %s" % (varname, time), fontsize=13)
        else:
            ax.set_title(title, fontsize=14)

    ax.tick_params(labelsize=12)
    ax.set_xticks(np.arange(-90, 120, 30))
    ax.set_xticklabels([r'90$^{\circ}$S', r'60$^{\circ}$S', r'30$^{\circ}$S',
                        r'0$^{\circ}$', r'30$^{\circ}$N', r'60$^{\circ}$N', r'90$^{\circ}$N'])
    if axes == 'on':
        cbar = fig.colorbar(a, ax=ax, boundaries=boundaries, ticks=boundaries[::cbar_tick_freq],
                            orientation='vertical', pad=0.02)
        cbar.set_label("%s (%s)" % (varname, units), fontsize=12)
        cbar.ax.tick_params(labelsize=14)
        ax.annotate(var_disp_string, xy=(0.67, -0.12), xycoords='axes fraction', fontsize=12)
        if show_cont_stats:
            ax.annotate(var_cont_disp_string, xy=(-0.16, -0.12), xycoords='axes fraction', fontsize=12)
        if not is_unset(crosssec_loc):
            ax.annotate('[' + crosssec_loc + ']', xy=(1.02, 1.02), xycoords='axes fraction', fontsize=12)
    plt.tight_layout()
    return a, fig
    
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered, or the name of a
          registered colormap (e.g. 'Spectral_r').
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name of the new colormap. It is also registered under this name,
          replacing any earlier colormap registered with the same name.

    Returns
    -------
      The new LinearSegmentedColormap.
    '''
    # Accept a colormap name as well as an object: the colormap is sampled
    # below by calling it, and a str is not callable.
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # regular index to compute the colors (257 samples of [start, stop])
    reg_index = np.linspace(start, stop, 257)

    # shifted index to match the data: 128 points below midpoint, 129 from
    # midpoint to 1.0, so the two halves are stretched independently
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # pyplot.register_cmap was removed in Matplotlib 3.9. The colormap registry
    # (Matplotlib >= 3.6) is used when available. force=True lets repeated
    # calls with the same name replace the previous map instead of raising.
    register = getattr(getattr(matplotlib, 'colormaps', None), 'register', None)
    if register is not None:
        register(newcmap, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap
    
