#!/usr/bin/python
# libraries ####################################################################
import os
import numpy as np
import xarray as xr

import warnings

import sys
# add the parent of this script's directory to the module search path so that
# the package import below can be resolved
sys.path.append(os.path.dirname(os.path.abspath(os.path.join(sys.argv[0], '..'))))
from chelsa_isimip3b_ba_1km.functions import auxiliary
################################################################################

# functions ####################################################################
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


def convert_calendar(data=None, dates_in=None, dates_out=None):
    """Map ``data`` from the time axis ``dates_in`` onto the time axis ``dates_out``.

    For every output date, the positions of the corresponding input time
    steps are obtained from ``auxiliary.index_engine(...).get_index`` and the
    selected input slices are averaged along the first axis, ignoring NaNs.

    Parameters
    ----------
    data : numpy.ndarray
        Input values; the first axis is indexed by the positions returned
        for ``dates_in``.
    dates_in : sequence of date objects
        Dates of the input time steps. Each element must expose ``year``,
        ``month`` and ``day``; the type of the first element selects the
        input calendar via ``auxiliary.get_calendar_from_type``.
    dates_out : sequence of date objects
        Dates of the requested output time steps (same requirements as
        ``dates_in``).

    Returns
    -------
    numpy.ndarray
        Array in NumPy's default float dtype whose first axis has length
        ``len(dates_out)`` and whose remaining axes match ``data``.
    """
    years_in = np.array([date.year for date in dates_in])
    months_in = np.array([date.month for date in dates_in])
    days_in = np.array([date.day for date in dates_in])
    years_out = np.array([date.year for date in dates_out])
    months_out = np.array([date.month for date in dates_out])
    days_out = np.array([date.day for date in dates_out])

    # output has the layout of the input, except for the length of axis 0
    shape = list(data.shape)
    shape[0] = len(dates_out)
    out = np.empty(shape)

    calendar_in = auxiliary.get_calendar_from_type(type=type(dates_in[0]))
    calendar_out = auxiliary.get_calendar_from_type(type=type(dates_out[0]))
    index_converter = auxiliary.index_engine(years=years_in, months=months_in, days=days_in, calendar=calendar_in)

    # one index selection into the input time axis per output date
    # NOTE(review): presumably each entry is a sequence of input positions,
    # since the selection is reduced over axis 0 below - confirm in auxiliary
    tindex_out = [index_converter.get_index(year=year, month=month, day=day, calendar=calendar_out)
                  for year, month, day in zip(years_out, months_out, days_out)]

    # calculate mean, filter RuntimeWarnings (lazy solution for handling
    # grid points with only nans, e.g. masked data); the filter is installed
    # once around the whole loop instead of once per time step
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        for i, tindex in enumerate(tindex_out):
            out[i, ...] = np.nanmean(data[tindex, ...], axis=0)

    return out
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


def get_date_range(date_start=None, date_stop=None, calendar_out=None):
    """Return the range of dates from ``date_start`` to ``date_stop`` in ``calendar_out``.

    Both end points are first converted with ``auxiliary.date_converter`` to
    the date type that ``auxiliary.get_type_from_calendar`` returns for
    ``calendar_out``.

    Parameters
    ----------
    date_start : date object
        First date of the range.
    date_stop : date object
        Last date of the range; a converted stop date of December 30 is
        moved to December 31 (same year and time of day).
    calendar_out : str
        Calendar name, passed on to ``auxiliary.get_type_from_calendar`` and
        to xarray.

    Returns
    -------
    xarray.CFTimeIndex
        Dates between the converted end points, using xarray's default
        frequency for a range given by start and end only.
    """
    date_type_out = auxiliary.get_type_from_calendar(calendar=calendar_out)
    date_start_out = auxiliary.date_converter(date=date_start, date_type=date_type_out)
    date_stop_out = auxiliary.date_converter(date=date_stop, date_type=date_type_out)

    # NOTE(review): presumably this extends a stop date that originates from
    # a 360-day calendar (where the year ends on Dec 30) to the real end of
    # the year - confirm against the callers
    if date_stop_out.month == 12 and date_stop_out.day == 30:
        date_stop_out = date_type_out(date_stop_out.year, 12, 31, date_stop_out.hour,
                                      date_stop_out.minute, date_stop_out.second, date_stop_out.microsecond)

    # xr.cftime_range is deprecated in recent xarray releases; xr.date_range
    # with use_cftime=True is its replacement and also yields a CFTimeIndex
    dates_out = xr.date_range(date_start_out, date_stop_out, calendar=calendar_out, use_cftime=True)

    return dates_out
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# user configuration ###########################################################


# command line arguments: input file, output file, name of the target calendar
file_in = sys.argv[1]
file_out = sys.argv[2]

calendar_out = sys.argv[3]
################################################################################

# open the input inside a context manager so that the file is closed even if
# the conversion or the write fails part-way through
with xr.open_dataset(file_in, decode_times=xr.coders.CFDatetimeCoder(use_cftime=True)) as data_in:
    data_out = data_in.copy()

    if 'time' in data_in.coords:
        times_in = data_in['time'].to_numpy()
        time_attrs = data_in['time'].attrs

        # get new time values
        times_out = auxiliary.get_date_list(year_start=times_in[0].year, year_stop=times_in[-1].year,
                                            calendar_type=calendar_out, ldates=True)

        # remove everything that depends on the old time axis
        data_out = data_out.drop_dims('time')

        # add new time and time_bnds
        data_out = data_out.merge({'time': xr.DataArray(times_out, coords={'time': times_out}, dims='time',
                                                        attrs=time_attrs)})
        if 'time_bnds' in data_in.variables:
            # date_start and date_stop assume daily values!!
            first, last = times_out[0], times_out[-1]
            date_start = type(first)(first.year, first.month, first.day, 0, 0, 0)
            date_stop = type(last)(last.year, last.month, last.day, 23, 59, 59)
            # NOTE(review): the lower and the upper bound are filled from the
            # same date range, so both columns of time_bnds are identical
            # (zero-width cells). This reproduces the previous behaviour;
            # confirm whether the upper bound should be shifted.
            bnds = get_date_range(date_start=date_start, date_stop=date_stop, calendar_out=calendar_out)
            time_bnds = np.array([bnds, bnds]).T
            data_out = data_out.merge({'time_bnds': xr.DataArray(time_bnds, dims=('time', 'bnds'),
                                                                 attrs=data_in['time_bnds'].attrs)})

        # convert calendar of each variable with time dimension
        for variable in data_in.variables:
            if variable in ('time', 'time_bnds') or 'time' not in data_in[variable].dims:
                continue
            print('  %Status:     converting calendar of {0}'.format(variable))
            dims = data_in[variable].dims
            # reuse the coordinates already present in the output (incl. new time)
            coords = {coord: data_out[coord] for coord in data_out.coords if coord in dims}
            data_aux = convert_calendar(data=data_in[variable].to_numpy(), dates_in=times_in, dates_out=times_out)
            data_out = data_out.merge({variable: xr.DataArray(data_aux, coords=coords, dims=dims,
                                                              attrs=data_in[variable].attrs)})

    # written while the input is still open, since data_out may still refer
    # to data of the input file that has not been loaded yet
    data_out.to_netcdf(file_out, unlimited_dims=['time'],
                       encoding={'time': {'units': 'days since 1860-01-01 00:00:00'}})

