#!/usr/bin/env python

# This file is part of chelsa_isimip3b_ba_1km.
#
# chelsa_isimip3b_ba_1km is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# chelsa_isimip3b_ba_1km is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with chelsa_isimip3b_ba_1km.  If not, see <https://www.gnu.org/licenses/>.

import numpy as np
import xarray as xr
import os

from . import auxiliary

class ingest_isimip:
    ''' ingest isimip class

    Reads one variable (``parm``) for one model / experiment / realization
    from the NetCDF files found below ``inputdir``, selects the requested
    date, optionally derives ``tlapse`` (from ta, zg, ps) or ``lcl`` (from
    tas, hurs), optionally remaps longitudes and interpolates to a target
    grid, and writes the result to
    ``<outputdir>/<file_suffix_out>_<parm>.nc``.
    '''

    def __init__(self,
                 inputdir=None,
                 outputdir=None,
                 model=None,
                 experiment=None,
                 realization=None,
                 parm=None,
                 day=None,
                 month=None,
                 year=None,
                 interpol=None,
                 lshiftlong=False,
                 file_suffix_out=None):
        '''
        Parameters
        ----------
        inputdir : str
            Directory passed to ``auxiliary.search_file_list``.
        outputdir : str
            Directory the preprocessed NetCDF file is written to.
        model, experiment, realization : str
            Patterns passed to ``auxiliary.search_file_list`` (presumably
            matched against file names - confirm). ``experiment`` defaults
            to 'historical', ``realization`` to 'r1i1p1f1'.
        parm : str
            Variable name. 'tlapse' and 'lcl' are derived quantities.
        day, month, year : int, optional
            Date to select. None leaves that component unrestricted.
        interpol : dict, optional
            Target coordinates ``{'lat': ..., 'lon': ...}`` for
            ``Dataset.interp``. None disables interpolation.
        lshiftlong : bool
            If True, longitudes are remapped to [-180, 180).
        file_suffix_out : str, optional
            Prefix of the output file name. Defaults to 'isimip3b'.
        '''
        self.inputdir = inputdir
        self.outputdir = outputdir
        self.model = model
        self.experiment = 'historical' if experiment is None else experiment
        self.realization = 'r1i1p1f1' if realization is None else realization
        self.parm = parm
        self.day = day
        self.month = month
        self.year = year
        self.interpol = interpol
        self.lshiftlong = lshiftlong
        self.file_suffix_out = 'isimip3b' if file_suffix_out is None else file_suffix_out

    def _absPaths_(self, directory):
        ''' Yield the absolute path of every file below ``directory`` (recursively). '''
        for dirpath, _, filenames in os.walk(directory):
            for f in filenames:
                yield os.path.abspath(os.path.join(dirpath, f))

    def get_ds(self):
        ''' Read, preprocess and write ``self.parm``.

        Returns
        -------
        xarray.Dataset or None
            The preprocessed dataset, which is also written to
            ``<outputdir>/<file_suffix_out>_<parm>.nc``.
            None when no input files were found.
        '''

        print('  %Preprocessing:     {0} - {1}{2}{3}'.format(self.parm,
                                                             '0000' if self.year is None else '{0:04d}'.format(self.year),
                                                             '00' if self.month is None else '{0:02d}'.format(self.month),
                                                             '00' if self.day is None else '{0:02d}'.format(self.day)
                                                             ))

        # get file list for the given variable ---------------------------------
        # derived quantities need several input variables
        include = {'tlapse': ['ta_|zg_|ps_', self.model, self.experiment, self.realization],
                   'lcl': ['tas_|hurs_', self.model, self.experiment, self.realization]}.get(
                  self.parm, [self.parm + '_', self.model, self.experiment, self.realization])
        # to speedup the i/o we can refine the file list with the given year, assuming one file per year
        if self.year is not None:
            include += ['{0:04d}'.format(self.year)]
        file_list = auxiliary.search_file_list(dir=self.inputdir, include=include)

        # do nothing if no file was found
        if len(file_list) < 1:
            print('  %Preprocessing:     no files found for')
            print('  %Preprocessing:         dir = {0}'.format(self.inputdir))
            print('  %Preprocessing:         pattern = {0}'.format(include))
            print('  %Preprocessing:     skipping this variable')
            return None

        # read data ---------------------------------------------------------
        ds = xr.open_mfdataset(file_list)

        # select the requested date with a partial date string built from the
        # components that are set (e.g. 'YYYY', 'YYYY-MM' or 'YYYY-MM-DD')
        tstr = '-'.join([_ for _ in (None if self.year is None else '{0:04d}'.format(self.year),
                                     None if self.month is None else '{0:02d}'.format(self.month),
                                     None if self.day is None else '{0:02d}'.format(self.day)) if _ is not None])
        # with no date component at all the string is empty, which is not a
        # meaningful time label -> keep the full time axis instead
        if tstr:
            ds = ds.sel(time=tstr)

        # average over time
        if any('time' in ds.data_vars[_].dims for _ in ds.data_vars):
            ds = ds.groupby('time').mean('time')

        # calculate temperature lapse rate from ta and zg fields on the lowest pressure
        # levels above surface
        if self.parm == 'tlapse':
            ds = calculate_tlapse(ds=ds)
        # calculate lifted condensation level from tas and hurs
        if self.parm == 'lcl':
            ds = calculate_lcl(ds=ds)

        # remove trivial time dimensions from data (i.e., time = 1)
        if any('time' in ds.data_vars[_].dims for _ in ds.data_vars):
            ds = ds.squeeze(dim=['time'])

        # preprocess data ---------------------------------------------------
        if self.lshiftlong:
            # remap longitudes to [-180, 180)
            ds = ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180))

        if self.interpol is not None:
            ds = ds.interp(lat=self.interpol['lat'], lon=self.interpol['lon'])

        for variable in ['lat_bnds', 'lon_bnds', 'time', 'time_bnds', 'plev']:
            if variable in ds:
                print('  %Preprocessing:         dimension {0} exist... deleting'.format(variable))
                # drop_vars replaces the deprecated Dataset.drop for variable names
                ds = ds.drop_vars(variable)
            else:
                print('  %Preprocessing:         dimension {0} does not exist... doing nothing'.format(variable))

        # determine fname_out -----------------------------------------------
        fname_out = os.path.join(self.outputdir, self.file_suffix_out + '_' + self.parm + '.nc')

        # write preprocessed data to fname_out ------------------------------
        print('  %Preprocessing:     writing {0}'.format(fname_out))
        ds.to_netcdf(fname_out)

        return ds


class ingest:
    ''' ingest data class

    Runs ``ingest_isimip`` for the full set of ISIMIP3b and CMIP6 inputs of
    one model / experiment / realization and date. Each variable is written
    to ``dir_temp``.
    '''

    def __init__(self,
                 dir_isimip=None,
                 dir_cmip=None,
                 dir_temp=None,
                 model=None,
                 realization='r1i1p1f1',
                 experiment='historical',
                 day=1,
                 month=1,
                 year=1981):
        '''
        Parameters
        ----------
        dir_isimip : str
            Input directory of the ISIMIP3b files.
        dir_cmip : str
            Input directory of the CMIP6 files.
        dir_temp : str
            Output directory for the preprocessed files.
        model : str
            Model name. It is lower-cased for the ISIMIP3b search and used
            as given for CMIP6.
        realization, experiment : str
            Ensemble member and experiment id.
        day, month, year : int
            Date to ingest.
        '''

        self.dir_isimip = dir_isimip
        self.dir_cmip = dir_cmip
        self.dir_temp = dir_temp
        self.model = model
        self.experiment = experiment
        self.realization = realization
        self.day = day
        self.month = month
        self.year = year

    def ingest_data(self):
        ''' Preprocess all ISIMIP3b and CMIP6 inputs for the configured date.

        Returns
        -------
        dict
            Maps '<source>_<variable>' (e.g. 'isimip3b_tas', 'cmip6_tlapse')
            to the dataset returned by ``ingest_isimip.get_ds``. That value
            is None when no input files were found.
        '''

        def _get_ds(inputdir, parm, model, file_suffix_out, **extra):
            ''' Run ingest_isimip for one variable with this instance's date/run settings. '''
            return ingest_isimip(inputdir=inputdir,
                                 outputdir=self.dir_temp,
                                 model=model,
                                 experiment=self.experiment,
                                 realization=self.realization,
                                 parm=parm,
                                 day=self.day,
                                 month=self.month,
                                 year=self.year,
                                 file_suffix_out=file_suffix_out,
                                 **extra).get_ds()

        ds = {}

        # ingest the isimip data
        for variable in ['tas', 'tasmin', 'tasmax', 'rsds']:
            ds['isimip3b_' + variable] = _get_ds(self.dir_isimip, variable,
                                                 self.model.lower(), 'isimip3b')

        # everything else is interpolated onto the isimip3b tas grid
        # (fails if no tas files were found, since get_ds then returns None)
        target_grid = {'lat': ds['isimip3b_tas']['lat'],
                       'lon': ds['isimip3b_tas']['lon']}

        ds['isimip3b_pr'] = _get_ds(self.dir_isimip, 'pr', self.model.lower(),
                                    'isimip3b', interpol=target_grid)

        # ingest cmip6 data (longitudes remapped to [-180, 180) before interpolation)
        for variable in ['clt', 'ps', 'uas', 'vas', 'lcl', 'tlapse']:
            ds['cmip6_' + variable] = _get_ds(self.dir_cmip, variable, self.model,
                                              'cmip6', interpol=target_grid,
                                              lshiftlong=True)

        # previously the collected datasets were discarded
        return ds


def calculate_tlapse(ds=None):
    ''' Replace ta, zg, ps and plev in ``ds`` by a near-surface temperature lapse rate.

    For every horizontal grid cell, the first pressure level with
    plev < ps is located. This assumes plev is ordered from the surface
    upward (decreasing pressure) - confirm for new inputs. If no level
    qualifies, argmax yields index 0.

    The lapse rate is computed between that level and the next one:

        tlapse = (ta_lower - ta_upper) / (zg_upper - zg_lower)

    If any cell is NaN, both levels are moved up by one for all cells and
    the computation is repeated. Level indices are clamped to the top level.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``ta``, ``zg``, ``ps`` and ``plev``. ``ta`` and ``zg``
        are assumed to be (plev, lat, lon) - confirm.

    Returns
    -------
    xarray.Dataset
        ``ds`` without ta, zg, ps, plev and with a 2-D ``tlapse`` (lat, lon)
        variable in K/m.

    Raises
    ------
    ValueError
        If no level pair without NaNs is found.
    '''

    shape = ds['ta'].shape
    nplev = ds['plev'].size

    # get level index (pindex) of the first pressure level above surface pressure
    ps = ds['ps'].to_numpy()
    pindex = np.argmax(np.array([_ < ps for _ in ds['plev'].to_numpy()]), axis=0).flatten()
    ngrid = pindex.size
    cells = np.arange(ngrid)

    # convert and flatten the horizontal dims ONCE; the per-cell extraction is
    # then a single vectorised (integer-array) lookup instead of ngrid
    # full-array conversions per level and field
    ta = ds['ta'].to_numpy().reshape(*shape[:-2], -1)
    zg = ds['zg'].to_numpy().reshape(*shape[:-2], -1)

    ishift = 0
    lfound = False
    # NOTE(review): with this bound the topmost level pair (nplev-2, nplev-1)
    # is never used for cells with pindex == 0, and nplev == 2 always fails.
    # Kept unchanged to preserve results - confirm whether ishift + 1 < nplev
    # was intended.
    while not lfound and (ishift + 2) < nplev:
        # level indices of the two layers above surface (clamped to the top level)
        lower = np.minimum(pindex + ishift, nplev - 1)
        upper = np.minimum(pindex + ishift + 1, nplev - 1)
        # calculate lapse rate
        tlapse = ((ta[lower, cells] - ta[upper, cells])
                  / (zg[upper, cells] - zg[lower, cells])).reshape(shape[-2:])
        if not np.any(np.isnan(tlapse)):
            lfound = True
        else:
            print('  %Warning:     still found nan-values on these layers for ta and zg, continue with one level above.')
            ishift += 1
    if not lfound:
        raise ValueError('  %Error:     no levels found without NaNs for ta and zg, unable to calculate temperature lapse rate (check your inputs)!')

    # remove ta, zg, ps and plev from dataset (ds)
    ds = ds.drop_vars(['ta', 'zg', 'ps', 'plev'])
    # add tlapse to dataset and add attributes
    ds = ds.assign(tlapse=(['lat', 'lon'], tlapse))
    ds['tlapse'].attrs = {'standard_name': 'temperature_lapse_rate',
                          'long_name': 'Near surface temperature lapse rate',
                          'units': 'K/m'}

    return ds


def calculate_lcl(ds=None):
    ''' Replace tas and hurs in ``ds`` by the lifted condensation level.

    Uses the approximation

        lcl = (20 + T / 5) * (100 - RH)

    with T = tas - 273.15 (tas in K, so T in degC) and RH = hurs. The
    ``100 - hurs`` term implies hurs is in percent. This appears to be the
    Lawrence (2005) approximation - confirm.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``tas`` and ``hurs`` on a 2-D (lat, lon) grid.

    Returns
    -------
    xarray.Dataset
        ``ds`` without tas and hurs and with an ``lcl`` (lat, lon) variable
        in m.
    '''

    lcl = (20.0 + ((ds['tas']-273.15)/5.0)) * (100-ds['hurs'])
    # drop_vars replaces the deprecated Dataset.drop for variable names
    ds = ds.drop_vars(['tas', 'hurs'])
    # add lcl to dataset and add attributes
    ds = ds.assign(lcl=(['lat', 'lon'], lcl.to_numpy()))
    ds['lcl'].attrs = {'standard_name': 'lifted_condensation_level',
                       'long_name': 'Lifted condensation level above surface',
                       'units': 'm'}

    return ds
