#!/usr/bin/env python

# This file is part of chelsa_isimip3b_ba_1km.
#
# chelsa_isimip3b_ba_1km is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# chelsa_isimip3b_ba_1km is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with chelsa_isimip3b_ba_1km.  If not, see <https://www.gnu.org/licenses/>.

from PySAGA import saga_api
from . import saga_functions
# Import-time side effect: load the SAGA tool libraries once when this module
# is imported, presumably so the saga_functions tool wrappers used below can
# run. NOTE(review): the meaning of the `True` flag is defined in
# saga_functions.Load_Tool_Libraries and is not visible here — confirm.
saga_functions.Load_Tool_Libraries(True)


def calculate_windeffect(Coarse, Dem):
    """Derive a wind-effect grid from the CMIP6 u/v near-surface wind grids.

    Workflow:

    1. Load ``cmip6_uas`` (u) and ``cmip6_vas`` (v) from ``Coarse``.
    2. Tag both grids as lat/long and convert them to point shapes.
    3. Reproject the points with ``reproject_shape``. The earlier comment
       describes this as a Mercator projection.
    4. Interpolate the points onto ``Dem.demproj`` with multilevel B-splines
       (level 14).
    5. Combine the two rasters with ``polar_coords``.
    6. Feed the result to ``windeffect`` together with ``Dem.demproj``.
    7. Project the wind effect back onto ``Dem.dem_latlong3``.

    Parameters
    ----------
    Coarse : container exposing coarse grids via ``set(name)``.
    Dem : container exposing DEM grids via ``set(name)`` / ``delete(name)``.

    Returns
    -------
    The wind-effect grid returned by ``proj_2_latlong``.

    NOTE: ``Coarse.cmip6_uas``, ``Coarse.cmip6_vas`` and ``Dem.demproj`` are
    not released here. Earlier ``_delete_grid_list_`` calls for the wind
    grids were commented out.
    """
    component_names = ('cmip6_uas', 'cmip6_vas')
    for name in component_names:
        Coarse.set(name)
    components = [getattr(Coarse, name) for name in component_names]

    # Give both wind components a geographic (lat/long) reference system.
    for grid in components:
        saga_functions.set_2_latlong(grid)

    # Grid cells -> point shapes, then reproject the points.
    points = [saga_functions.gridvalues_to_points(grid) for grid in components]
    points = [saga_functions.reproject_shape(shapes) for shapes in points]

    # Interpolate each component back to a raster on the projected DEM.
    Dem.set('demproj')
    rasters = [saga_functions.multilevel_B_spline(shapes, Dem.demproj, 14)
               for shapes in points]
    u_raster, v_raster = rasters

    direction = saga_functions.polar_coords(u_raster, v_raster)
    windef = saga_functions.windeffect(direction, Dem.demproj)

    # Back to the lat/long DEM grid.
    Dem.set('dem_latlong3')
    windef1 = saga_functions.proj_2_latlong(windef, Dem.dem_latlong3)
    Dem.delete('dem_latlong3')

    # Release intermediates: rasters (u, v), point shapes (u, v), then windef.
    manager = saga_api.SG_Get_Data_Manager()
    for data_object in (*rasters, *points, windef):
        manager.Delete(data_object)

    return windef1


def correct_windeffect(windef1, Coarse, Dem, Aux):
    """Correct a wind-effect grid for boundary-layer height and exposure.

    Workflow:

    1. Build a boundary-layer height grid ``pblh`` as
       ``dem_low + field / 9.80665``, where ``field`` is ``cmip6_lcl``
       passed through ``calc_geopotential`` and B-spline interpolated onto
       ``Dem.dem_low``.
    2. Compute the distance of ``Dem.dem_high`` to ``pblh`` with
       ``calc_dist2bound``, resample it to ``pblh`` (method 7) and invert it.
    3. Scale ``windef1`` by ``b/(1-a/9000)``, where ``a`` is the inverted
       distance and ``b`` is ``windef1``.
    4. Multiply by ``Aux.expocor`` and patch with ``Aux.patch``.

    Parameters
    ----------
    windef1 : wind-effect grid (e.g. the result of ``calculate_windeffect``).
    Coarse, Dem, Aux : containers exposing the named grids via
        ``set(name)`` / ``delete(name)``.

    Returns
    -------
    tuple ``(wind_cor, wind_coarse25, wind_coarse)``:

    * ``wind_cor``: the corrected, patched wind effect.
    * ``wind_coarse25``: ``wind_cor`` resampled (method 4) onto the ``pblh``
      grid.
    * ``wind_coarse``: ``wind_cor`` resampled (method 4) onto
      ``Aux.dummy_W5E5``, with gaps closed.

    NOTE(review): ``cblev``, ``dist2bound``, ``maxdist2bound`` and
    ``maxdist2bound2`` are still not released. Whether their producing
    helpers return new objects is not visible here.
    """
    Coarse.set('cmip6_lcl')
    cblev = saga_functions.calc_geopotential(Coarse.cmip6_lcl)
    Coarse.delete('cmip6_lcl')
    saga_functions.set_2_latlong(cblev)
    cblev_shp = saga_functions.gridvalues_to_points(cblev)

    Dem.set('dem_low')
    cblev_ras = saga_functions.multilevel_B_spline(cblev_shp,
                                                   Dem.dem_low, 14)
    # The point shapes are only needed for the interpolation above.
    saga_api.SG_Get_Data_Manager().Delete(cblev_shp)

    # Boundary-layer height: dem_low plus the interpolated field divided by
    # standard gravity (9.80665 m s-2).
    pblh = saga_functions.grid_calculatorX(Dem.dem_low,
                                           cblev_ras,
                                           'a+(b/9.80665)')
    saga_api.SG_Get_Data_Manager().Delete(cblev_ras)

    Dem.delete('dem_low')
    Dem.set('dem_high')

    dist2bound = saga_functions.calc_dist2bound(Dem.dem_high,
                                                pblh)
    Dem.delete('dem_high')

    maxdist2bound = saga_functions.resample_up(dist2bound,
                                               pblh, 7)

    maxdist2bound2 = saga_functions.invert_dist2bound(dist2bound,
                                                      maxdist2bound)

    # Correct the wind effect by the (inverted) distance to the boundary layer.
    wind_raw = saga_functions.grid_calculatorX(maxdist2bound2,
                                               windef1,
                                               '(b/(1-a/9000))')

    # Apply the exposure index; grid_calculatorX yields a new grid, so the
    # uncorrected stage can be released afterwards.
    Aux.set('expocor')
    wind_exposed = saga_functions.grid_calculatorX(Aux.expocor,
                                                   wind_raw,
                                                   'a*b')
    Aux.delete('expocor')
    saga_api.SG_Get_Data_Manager().Delete(wind_raw)

    Aux.set('patch')
    wind_cor = saga_functions.patching(wind_exposed,
                                       Aux.patch)
    Aux.delete('patch')

    Aux.set('dummy_W5E5')
    wind_coarse = saga_functions.resample_up(wind_cor,
                                             Aux.dummy_W5E5, 4)
    wind_coarse = saga_functions.closegaps(wind_coarse)

    # Resample the corrected wind effect onto the pblh grid.
    wind_coarse25 = saga_functions.resample_up(wind_cor,
                                               pblh, 4)
    # pblh was last needed as the resampling target above.
    saga_api.SG_Get_Data_Manager().Delete(pblh)

    Aux.delete('dummy_W5E5')

    return wind_cor, wind_coarse25, wind_coarse


def precipitation(wind_cor, wind_coarse25, wind_coarse, Coarse, Aux):
    """Downscale ISIMIP3b precipitation using the three wind-effect grids.

    Workflow:

    1. Multiply ``Coarse.isimip3b_pr`` by 86400.
    2. Resample it onto ``Aux.dummy_W5E5`` and close gaps.
    3. Downscale in two ``downscale_precip`` stages: first with
       ``wind_coarse25`` / ``wind_coarse``, then with ``wind_cor`` /
       ``wind_coarse25``.
    4. Convert the result with ``convert2uinteger10``.

    Parameters
    ----------
    wind_cor, wind_coarse25, wind_coarse : wind-effect grids, in the order
        returned by ``correct_windeffect``.
    Coarse, Aux : containers exposing the named grids via
        ``set(name)`` / ``delete(name)``.

    Returns
    -------
    The grid produced by ``convert2uinteger10`` named 'Precipitation'.
    """
    Coarse.set('isimip3b_pr')
    # 86400 s/day — presumably a flux (kg m-2 s-1) to daily totals; confirm.
    daily = saga_functions.grid_calculator_simple(Coarse.isimip3b_pr,
                                                  'a*86400')
    Coarse.delete('isimip3b_pr')

    Aux.set('dummy_W5E5')
    on_w5e5 = saga_functions.resample(daily, Aux.dummy_W5E5)
    Aux.delete('dummy_W5E5')

    filled = saga_functions.closegaps(on_w5e5)

    # Stage 1: coarse wind grid -> intermediate wind grid.
    intermediate = saga_functions.downscale_precip(
        wind_coarse25, wind_coarse, filled, 'prec', 1)
    # Stage 2: intermediate wind grid -> full-resolution wind grid.
    fine = saga_functions.downscale_precip(
        wind_cor, wind_coarse25, intermediate,
        'total surface precipitation', 3)

    result = saga_functions.convert2uinteger10(fine, 'Precipitation')

    manager = saga_api.SG_Get_Data_Manager()
    for grid in (filled, intermediate, fine):
        manager.Delete(grid)

    return result


def cloudcover(Coarse, wind_cor):
    """Downscale CMIP6 total cloud cover (``cmip6_clt``) with the wind effect.

    Workflow:

    1. Clamp ``cmip6_clt`` and ``wind_cor`` to [0.001, 0.999].
    2. Resample the clamped wind grid onto the clt grid (``resample_up``,
       method 4).
    3. Interpolate both coarse fields back to the wind grid with multilevel
       B-splines (level 14).
    4. Correct clt with ``(b*(a+(c-a)*(1-(c)))/c)``, where ``a`` is the
       interpolated clt, ``b`` the clamped wind and ``c`` the interpolated
       coarse wind.
    5. Clamp the corrected clt to [0, 1].

    Parameters
    ----------
    Coarse : container exposing ``cmip6_clt`` via ``set(name)``.
    wind_cor : corrected wind-effect grid (e.g. from ``correct_windeffect``).

    Returns
    -------
    The corrected cloud-cover grid, clamped to [0, 1].

    NOTE: ``Coarse.cmip6_clt`` itself is not released here.
    """
    Coarse.set('cmip6_clt')

    # Clamp clt to [0.001, 0.999]. The `a-a+<const>` form presumably keeps
    # NoData cells as NoData instead of writing a constant — confirm.
    clt_floor = saga_functions.grid_calculator_simple(
        Coarse.cmip6_clt, 'ifelse(a<0.001,a-a+0.001,a*1)'
    )
    clt = saga_functions.grid_calculator_simple(
        clt_floor, 'ifelse(a>0.999,a-a+0.999,a*1)'
    )
    saga_api.SG_Get_Data_Manager().Delete(clt_floor)

    # Clamp the wind effect to the same [0.001, 0.999] range.
    wind_floor = saga_functions.grid_calculator_simple(
        wind_cor, 'ifelse(a<0.001,a-a+0.001,a*1.0)'
    )
    windnorm = saga_functions.grid_calculator_simple(
        wind_floor, 'ifelse(a>0.999,a-a+0.999,a*1.0)'
    )
    saga_api.SG_Get_Data_Manager().Delete(wind_floor)

    # Resample the clamped wind effect onto the clt grid.
    windnorm_coarse = saga_functions.resample_up(windnorm, clt, 4)
    windnorm_coarse_shp = saga_functions.gridvalues_to_points(windnorm_coarse)
    # Interpolate the coarse wind effect back to the wind grid.
    windnorm_coarse_high = saga_functions.multilevel_B_spline(windnorm_coarse_shp,
                                                              windnorm.asGrid(), 14)
    # Interpolate clt to the wind grid.
    clt_shp = saga_functions.gridvalues_to_points(clt)
    clt_highres = saga_functions.multilevel_B_spline(clt_shp,
                                                     windnorm.asGrid(), 14)
    # Wind-effect correction of clt.
    clt_cor = saga_functions.grid_calculator3(clt_highres, windnorm.asGrid(), windnorm_coarse_high,
                                              '(b*(a+(c-a)*(1-(c)))/c)')
    # Clamp the corrected clt to [0, 1].
    clt_capped = saga_functions.grid_calculator_simple(clt_cor, 'ifelse(a>1,a-a+1,a*1)')
    clt_out = saga_functions.grid_calculator_simple(clt_capped, 'ifelse(a<0,a-a,a*1)')

    # clean up memory
    saga_api.SG_Get_Data_Manager().Delete(windnorm)
    saga_api.SG_Get_Data_Manager().Delete(windnorm_coarse_high)
    saga_api.SG_Get_Data_Manager().Delete(windnorm_coarse_shp)
    saga_api.SG_Get_Data_Manager().Delete(clt_shp)
    saga_api.SG_Get_Data_Manager().Delete(clt_highres)
    saga_api.SG_Get_Data_Manager().Delete(clt_cor)
    saga_api.SG_Get_Data_Manager().Delete(windnorm_coarse)
    saga_api.SG_Get_Data_Manager().Delete(clt)
    saga_api.SG_Get_Data_Manager().Delete(clt_capped)

    return clt_out


def temperature(Coarse, Dem, Aux, var):
    """Downscale a daily ISIMIP3b temperature variable with a lapse-rate model.

    Workflow:

    1. Downscale ``Coarse.isimip3b_<var>`` with
       ``lapse_rate_based_downscaling``, using ``Dem.dem_high``,
       ``Coarse.cmip6_tlapse`` and ``Dem.dem_low``.
    2. Multiply the result by ``Aux.landseamask``.
    3. Build ocean values from ``Aux.oceans * temp_in`` and close their gaps.
    4. Combine land and ocean with ``patchingBspline``.
    5. Convert the result with ``convert2uinteger10``.

    Parameters
    ----------
    Coarse, Dem, Aux : containers exposing the named grids via
        ``set(name)`` / ``delete(name)``.
    var : str
        One of ``'tas'``, ``'tasmax'`` or ``'tasmin'``. Selects the
        ``isimip3b_<var>`` input grid and the output long name.

    Returns
    -------
    The grid produced by ``convert2uinteger10``, named after the variable's
    long name.

    Raises
    ------
    ValueError
        If ``var`` is not a supported temperature variable. This is checked
        before any grid is loaded.
    """
    long_names = {
        'tas': 'Daily Near-Surface Air Temperature',
        'tasmax': 'Daily Maximum Near-Surface Air Temperature',
        'tasmin': 'Daily Minimum Near-Surface Air Temperature',
    }
    # Fail fast: previously an unknown var loaded several grids and then
    # crashed with UnboundLocalError on temp_in.
    if var not in long_names:
        raise ValueError(
            "unsupported temperature variable %r; expected one of %s"
            % (var, ', '.join(sorted(long_names)))
        )
    varname = long_names[var]
    layer = 'isimip3b_' + var

    Coarse.set(layer)
    temp_in = getattr(Coarse, layer)

    Coarse.set('cmip6_tlapse')
    Dem.set('dem_high')
    Dem.set('dem_low')
    temp_lapse = saga_functions.lapse_rate_based_downscaling(Dem.dem_high,
                                                             Coarse.cmip6_tlapse,
                                                             Dem.dem_low,
                                                             temp_in)
    Coarse.delete('cmip6_tlapse')
    Dem.delete('dem_high')
    Dem.delete('dem_low')

    # Land values: lapse-rate result multiplied by the land-sea mask.
    Aux.set('landseamask')
    temp_land = saga_functions.grid_calculator(temp_lapse,
                                               Aux.landseamask,
                                               'a*b')
    Aux.delete('landseamask')
    # The unmasked lapse-rate grid is not needed any more.
    saga_api.SG_Get_Data_Manager().Delete(temp_lapse)

    # Ocean values come from the coarse input itself.
    Aux.set('oceans')
    temp_ocean = saga_functions.grid_calculatorX(Aux.oceans,
                                                 temp_in,
                                                 'a*b')
    Aux.delete('oceans')

    temp_ocean = saga_functions.closegaps(temp_ocean)

    temp_highres = saga_functions.patchingBspline(temp_land,
                                                  temp_ocean)

    saga_api.SG_Get_Data_Manager().Delete(temp_ocean)

    temp_out = saga_functions.convert2uinteger10(temp_highres, varname)
    Coarse.delete(layer)

    # clean memory
    saga_api.SG_Get_Data_Manager().Delete(temp_highres)

    return temp_out


def solar_radiation(Srad, Coarse, cc_fin):
    """Compute surface shortwave radiation and bias-correct it to ISIMIP3b.

    Workflow:

    1. Pass ``Srad.rsds_clim`` and the cloud cover ``cc_fin`` through
       ``surface_radiation``.
    2. Change the result's storage with ``change_data_storage2(..., 3)``.
    3. Resample it onto ``Coarse.isimip3b_rsds`` (``resample_up``, method 5).
    4. Form the ratio ``(rsds+1)/(resampled+1)`` on that grid.
    5. Multiply the ratio back onto the radiation grid with
       ``grid_calculatorX``.
    6. Scale by 0.0864 and convert the storage again.

    Parameters
    ----------
    Srad, Coarse : containers exposing the named grids via
        ``set(name)`` / ``delete(name)``.
    cc_fin : cloud-cover grid (e.g. the result of ``cloudcover``).

    Returns
    -------
    The bias-corrected, scaled radiation grid.
    """
    Srad.set('rsds_clim')

    srad_sur = saga_functions.surface_radiation(Srad.rsds_clim,
                                                cc_fin,
                                                'Surface Downwelling Shortwave Radiation')

    Srad.delete('rsds_clim')

    srad_cor = saga_functions.change_data_storage2(srad_sur, 3)

    Coarse.set('isimip3b_rsds')

    srad_resamp = saga_functions.resample_up(srad_cor,
                                             Coarse.isimip3b_rsds, 5)

    # Ratio of ISIMIP3b rsds to the resampled model radiation; the +1 terms
    # keep the ratio defined where the resampled value is 0.
    srad_bias = saga_functions.grid_calculator(Coarse.isimip3b_rsds,
                                               srad_resamp,
                                               '(a+1)/(b+1)')

    srad_biascor = saga_functions.grid_calculatorX(srad_cor,
                                                   srad_bias,
                                                   'a*b')

    # 0.0864 = 86400 / 1e6 — presumably W m-2 -> MJ m-2 day-1; confirm.
    srad_scaled = saga_functions.grid_calculator_simple(srad_biascor, 'a*0.0864')
    # The unscaled product was previously leaked when srad_cor2 was rebound.
    saga_api.SG_Get_Data_Manager().Delete(srad_biascor)

    Coarse.delete('isimip3b_rsds')

    srad_cor2 = saga_functions.change_data_storage2(srad_scaled, 3)
    # change_data_storage2 yields a new grid (srad_sur/srad_cor are deleted
    # separately below), so the pre-conversion grid can be released.
    saga_api.SG_Get_Data_Manager().Delete(srad_scaled)

    # clean memory
    saga_api.SG_Get_Data_Manager().Delete(srad_sur)
    saga_api.SG_Get_Data_Manager().Delete(srad_cor)
    saga_api.SG_Get_Data_Manager().Delete(srad_resamp)
    saga_api.SG_Get_Data_Manager().Delete(srad_bias)

    return srad_cor2
