import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
from nc_dump import ncdump
import math
#from plot_scripts import plot_world
import pickle

##Generate continental mask
#nc_f = './c3beta_tria_200Ma_1500ppm/snapshots.004002.01.01.dta.nc'  
#nc_fid = Dataset(nc_f, 'r')  
#nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)
#temp = np.flipud(nc_fid.variables['temp'][0,0,:,:])
#mask=np.empty([48,96])
#mask[np.where(temp > -1000)]=1
#mask[np.where(temp < -1000)]=0
#with open('continental_mask.pickle', 'wb') as f:  
#    pickle.dump(mask, f)
#with open('continental_mask.pickle', 'rb') as f:  
#    continental_mask = pickle.load(f, encoding='latin1')

###Function for generating the different Grids:
def grid_gen(region='atmo', cell='T'):
    """Build 1-D and 2-D latitude/longitude coordinates for a model grid.

    Parameters
    ----------
    region : str
        'atmo'      -- 7.5 deg (lat) x 22.5 deg (lon) spacing
        'ocean'     -- 3.75 deg x 3.75 deg spacing
        'atmo_1deg' -- 1 deg x 1 deg spacing
    cell : str
        'T' -- cell centres: half a step in from the poles and from +/-180
        'p' -- cell edges: from -90 to 90 and from -180 to 180, inclusive
        'u' -- edges shifted by one step: from -90+dlat to 90 and from
               -180+dlon to 180, inclusive

    Returns
    -------
    lats : ndarray
        Latitudes in degrees, ordered north to south.
    lons : ndarray
        Longitudes in degrees, ordered west to east.
    x, y : ndarray
        ``np.meshgrid(lons, lats)``; both of shape (len(lats), len(lons)).

    Raises
    ------
    ValueError
        If ``region`` or ``cell`` is not one of the values listed above.

    Notes
    -----
    The number of points on each axis is derived from the bounds and the
    spacing, so consecutive points are always exactly dlat/dlon apart.
    (A hand-written count of 14 longitudes for the atmospheric 'u' cell did
    not match the 22.5 deg spacing; the derived count is 16.)
    """
    # (dlat, dlon) in degrees. Floats, so that dlat / 2 is never truncated
    # by integer division (e.g. 1 / 2 == 0 under Python 2).
    spacing = {
        'atmo': (7.5, 22.5),
        'ocean': (3.75, 3.75),
        'atmo_1deg': (1.0, 1.0),
    }
    if region not in spacing:
        raise ValueError('unknown region %r; expected one of %s'
                         % (region, sorted(spacing)))
    dlat, dlon = spacing[region]

    if cell == 'T':
        lat_min, lat_max = -90 + dlat / 2, 90 - dlat / 2
        lon_min, lon_max = -180 + dlon / 2, 180 - dlon / 2
    elif cell == 'p':
        lat_min, lat_max = -90.0, 90.0
        lon_min, lon_max = -180.0, 180.0
    elif cell == 'u':
        lat_min, lat_max = -90 + dlat, 90.0
        lon_min, lon_max = -180 + dlon, 180.0
    else:
        raise ValueError("unknown cell %r; expected 'T', 'p' or 'u'" % (cell,))

    # round() absorbs floating-point error before converting to a count.
    lat_n = int(round((lat_max - lat_min) / dlat)) + 1
    lon_n = int(round((lon_max - lon_min) / dlon)) + 1

    lats = np.linspace(lat_max, lat_min, lat_n)  # north -> south
    lons = np.linspace(lon_min, lon_max, lon_n)  # west -> east
    x, y = np.meshgrid(lons, lats)
    return lats, lons, x, y


#Function for calculating spatial global or regional averages:
def global_mean(var, region='global', grid='atmo', axis='lonlat', timeseries='no',
                restrict_lat='no',
                topog_file='../climber/c3beta_tria_200Ma_1500ppm/topog.dta.nc'):
    """Compute an area-weighted global/hemispheric mean or a cross-section mean.

    Parameters
    ----------
    var : ndarray or numpy.ma.MaskedArray
        For region 'global'/'NH'/'SH': a field whose last two axes are
        (lat, lon), e.g. (lat, lon) or (time, lat, lon). The lat axis must
        match the T-cell latitudes of ``grid`` (ordered north to south, as
        returned by grid_gen). For grid='ocean' it must be a masked array:
        each latitude is also weighted by its number of unmasked cells.
        For region 'crosssec': a 2-D masked array whose first axis has
        len(zw_k) - 1 entries (presumably model layers -- TODO confirm).
    region : str
        'global', 'NH' (latitudes > 0), 'SH' (latitudes < 0) or 'crosssec'.
    grid : str
        'atmo', 'ocean' or 'atmo_1deg'. Not used for region 'crosssec'.
    axis : str
        'zonal' returns the longitude mean for every latitude instead of
        the area mean (global/NH/SH only). Any other value gives the area mean.
    timeseries : str
        Unused. It is kept for backward compatibility, because leading axes
        (e.g. time) are detected from the shape of ``var``.
    restrict_lat : str
        'yes' with grid='ocean' uses only the ocean T latitudes [16:33]. The
        lat axis of ``var`` must then have 17 entries.
    topog_file : str
        netCDF file read for region 'crosssec'. Differences of consecutive
        values of its 'zw_k' variable are used as layer thicknesses.

    Returns
    -------
    The mean: a scalar for a (lat, lon) or cross-section field, or one value
    per leading index for higher-rank fields. With axis='zonal', the
    zonal-mean array is returned instead.

    Raises
    ------
    ValueError
        If ``region`` or ``grid`` is not recognised.
    """
    if region == 'crosssec':
        # Read the interfaces and close the file right away ([:] loads the data).
        with Dataset(topog_file, 'r') as nc_fid:
            zw_k = nc_fid.variables['zw_k'][:]
        thicks = zw_k[1:] - zw_k[:-1]
        n_valid = var.count(axis=1)                 # unmasked cells per row
        row_mean = np.sum(var, axis=1) / n_valid    # mean along axis 1
        return (np.sum(row_mean * n_valid * thicks)
                / np.sum(n_valid * thicks))

    if region not in ('global', 'NH', 'SH'):
        raise ValueError("unknown region %r; expected 'global', 'NH', 'SH' "
                         "or 'crosssec'" % (region,))

    if grid in ('atmo', 'atmo_1deg'):
        lats = grid_gen(region=grid, cell='T')[0]
    elif grid == 'ocean':
        lats = grid_gen(region='ocean', cell='T')[0]
        if restrict_lat == 'yes':
            lats = lats[16:33]
    else:
        raise ValueError("unknown grid %r; expected 'atmo', 'ocean' or "
                         "'atmo_1deg'" % (grid,))

    coslat = np.cos((np.pi / 180) * lats)
    # Latitude is axis -2 and longitude axis -1, whatever leading axes exist.
    if grid == 'ocean':
        # Masked (land) cells are excluded. Each latitude's weight is
        # cos(lat) times its number of valid cells, counted separately for
        # every leading (e.g. time) index.
        zonal = np.ma.mean(var, axis=-1)
        weights = coslat * var.count(axis=-1)
    else:
        # Atmospheric grids are weighted by cos(lat) only.
        zonal = np.mean(var, axis=-1)
        weights = coslat

    if axis == 'zonal':
        return zonal

    # Select hemispheres by latitude sign. This slices the latitude axis
    # (not a leading time axis) and works for odd-length latitude sets.
    if region == 'NH':
        sel = lats > 0
    elif region == 'SH':
        sel = lats < 0
    else:
        sel = slice(None)
    w = weights[..., sel]
    return np.sum(zonal[..., sel] * w, axis=-1) / np.sum(w, axis=-1)
