#undef DEBUG_PIK_DIAG

module ocean_pik_diag_mod
  !  
  !<CONTACT EMAIL="Stefan.Petri@pik-potsdam"> Stefan Petri
  !</CONTACT>
  !
  !<OVERVIEW>
  ! Meridional Overturning and related diagnostics. 
  !</OVERVIEW>
  !
  !<DESCRIPTION>
  ! Meridional Overturning and related diagnostics.
  ! This module computes the meridional overturning stream function for the
  ! mass transport, and sends it as vertical 2-dimensional variable
  ! to the diagnostics manager.
  ! This approach saves significant disk space compared to
  ! saving the 3-D variables and calculating the overturning functions in a
  ! postprocessing / visualization step.
  !
  ! Also, this module can calculate overturning according to ocean basins.
  ! Currently that is implemented for the atlantic and the pacific basin.
  ! The selection of basin cells is based on the file INPUT/basin_mask.
  !  The basin-mask convention used at GFDL has 
  !  Southern=1.0,Atlantic=2.0,Pacific=3.0,Arctic=4.0,Indian=5.0
  !
  !
  ! First, the vertical integral for the Eulerian part is defined over
  ! the range from -H (bottom) to the level Z. However, FERRET can not
  ! deal with a vertically upward integration. Therefore, the vertical
  ! integral is split into the integral over the whole water column
  ! minus the integral from the surface down to the level k:
  ! -1.*val_atl[i=@sum,k=@sum] + val_atl[i=@sum,k=@rsum]. This part of
  ! the calculation is correct. Second, in the GM part, there is no
  ! need for a numerical integration over Z, because it is already
  ! done analytically. We have only to perform the zonal integral:
  ! val_atl_gm[i=@sum].
  ! The corresponding post-processing script in ferret looks like this:
  !
  ! let atl_mask if basin_mask[D=2] EQ 2 then 1.0 else 0.0
  !
  ! let val_atl    = ty_trans[D=1]*atl_mask
  ! let val_atl_gm = ty_trans_gm[D=1]*atl_mask
  !
  ! let psi_atl=-1.*val_atl[i=@sum,k=@sum]+val_atl[i=@sum,k=@rsum]+val_atl_gm[i=@sum]
  !
  ! Unfortunately, the values for ty_trans and ty_trans_gm are calculated
  ! in very different modules of mom4, and at different times in the ocean
  ! model main loop.
  !
  !   subroutine update_ocean_model()
  !     [..]
  !     call neutral_physics()
  !         call neutral_physicsX() -> ty_trans_gm ! X = A,B,C
  !     [..]
  !     call update_ocean_tracer() -> tracer overturning diag with basin masks
  !     [..]
  !     call ocean_diagnostics()
  !        call ocean_adv_vel_diagnostics()
  !           call transport_on_s() -> ty_trans
  !     [..]
  !   end subroutine
  !
  ! To avoid copying of 3D arrays, this module stores pointers to
  ! the arrays which are allocated in their respective modules.
  ! This module provides an init subroutine to be called from each of the data-contributing 
  ! modules, so that the pointers can be initialised.
  !
  !
  ! Sigh. The ty_trans is calculated in ocean_adv_vel_diag_mod::transport_on_s()
  ! _only_ if the corresponding diagnostic output is enabled in the diag_table.
  ! _And_ it is stored in a temporary field only.
  ! Thus, we copy that ty_trans calculation into here also.
  !
  ! Sigh Deeply. ty_trans_gm is calculated only if neutral_physicsA,
  ! neutral_physicsB or neutral_physicsC is
  ! enabled in file input.nml ,
  ! sections &ocean_nphysics_nml and &ocean_nphysicsB_nml resp. ocean_nphysicsC_nml .
  !
  !</DESCRIPTION>
  !
  !<NAMELIST NAME="ocean_pik_diag_nml">
  !  <DATA NAME="do_pik_diag" TYPE="logical">
  !  Do Meridional Overturning Circulation diagnostics.
  !  </DATA> 
  !</NAMELIST>
  !
  use constants_mod,       only: epsln
  use diag_manager_mod,    only: register_diag_field, send_data, need_data
  use fms_mod,             only: open_namelist_file, check_nml_error, close_file, write_version_number
  use fms_mod,             only: FATAL, stdout, stdlog
  use fms_mod,             only: read_data
  use mpp_mod,             only: input_nml_file, mpp_error, mpp_max, mpp_pe
  use mpp_mod,             only: mpp_clock_id, mpp_clock_begin, mpp_clock_end, CLOCK_ROUTINE
  use mpp_domains_mod,     only: mpp_global_field, XUPDATE
  use time_manager_mod,    only: time_type, increment_time
  
  use ocean_domains_mod,    only: get_local_indices
  use ocean_parameters_mod, only: missing_value, rho0r 
  use ocean_types_mod,      only: ocean_domain_type, ocean_grid_type
  use ocean_types_mod,      only: ocean_adv_vel_type, ocean_density_type
  use ocean_types_mod,      only: ocean_prog_tracer_type, ocean_thickness_type
  use ocean_types_mod,      only: ocean_time_type, ocean_time_steps_type
  use ocean_workspace_mod,  only: wrk1 ! , wrk2, wrk1_2d
  
  implicit none
  
  private

  public ocean_pik_diag_init
  public ocean_pik_diag_init_ty_trans_gm
  public ocean_pik_diagnostics 
  public do_pik_diag

#include <ocean_memory.h>

  ! for diagnostics clocks - done by the calling module
  !integer :: id_pik_diag_clock
  
  ! Diag-manager handles for the three overturning stream functions
  ! (global, Atlantic, Pacific). They keep the value -1 until
  ! ocean_pik_diag_init registers the corresponding fields.
  integer :: id_pik_diag_gmoc  = -1
  integer :: id_pik_diag_amoc  = -1
  integer :: id_pik_diag_pmoc  = -1
#ifdef DEBUG_PIK_DIAG
  ! Diag-manager handles for the intermediate fields that are only
  ! written in a debug build.
  integer :: id_pik_ty_trans = -1
  integer :: id_pik_ty_trans_globalfield = -1
  integer :: id_pik_ty_trans_atl = -1
  integer :: id_pik_ty_trans_atl_sum = -1
  integer :: id_pik_ty_trans_atl_rsum = -1
  integer :: id_pik_ty_trans_atl_upwardrsum = -1
  integer :: id_pik_ty_trans_pac = -1
  integer :: id_pik_ty_trans_pac_sum = -1
  integer :: id_pik_ty_trans_pac_rsum = -1
  integer :: id_pik_ty_trans_pac_upwardrsum = -1
  integer :: id_pik_ty_trans_gm = -1
  integer :: id_pik_ty_trans_gm_globalfield = -1
  integer :: id_pik_ty_trans_gm_atl = -1
  ! Passed to diag_basin for the pacific basin by ocean_pik_diagnostics.
  ! It was referenced there without being declared, which breaks a debug
  ! build under "implicit none". No diag field is registered for it in this
  ! module, so it stays at -1 and diag_basin skips the matching send_data.
  integer :: id_pik_ty_trans_gm_pac = -1
  integer :: id_pik_ty_trans_gm_atl_sum = -1
  integer :: id_pik_ty_trans_gm_atl_sum_ijk = -1
  integer :: id_pik_diag_amoc_upward  = -1
  integer :: id_pik_diag_pmoc_upward  = -1
#endif
  
  ! for specifying transport units
  ! can either be Sv or mks
  character(len=32) :: transport_dims ='Sv (10^9 kg/s)' 
  real              :: transport_convert=1.0e-9 
  
  ! Grid and domain of the ocean model; associated in ocean_pik_diag_init.
  type(ocean_grid_type), pointer   :: Grd =>NULL()
  type(ocean_domain_type), pointer :: Dom =>NULL()
  
  
  ! Work arrays. The "global_" arrays span the full zonal extent (ni) but
  ! only the local compute rows jsc:jec.
#ifdef MOM_STATIC_ARRAYS
  real, dimension(jsc:jec,nk)  :: psi
  real, dimension(isc:iec,jsc:jec)  :: basin_mask
  real, dimension(ni, jsc:jec) :: global_basin_mask
  real, dimension(ni, jsc:jec, nk) :: global_tmask
  !real, dimension(ni, jsc:jec, nk) :: global_ty_trans
  !real, dimension(ni, jsc:jec, nk) :: global_ty_trans_gm
#else
  real, dimension(:,:),   allocatable :: psi
  real, dimension(:,:),   allocatable :: basin_mask
  real, dimension(:,:),   allocatable :: global_basin_mask
  real, dimension(:,:,:), allocatable :: global_tmask
  !real, dimension(:,:,:), allocatable :: global_ty_trans
  !real, dimension(:,:,:), allocatable :: global_ty_trans_gm
#endif


  ! Here we store pointers to the data arrays that are allocated in other modules.
  ! At init time, the pointer is assigned.
  real, dimension(:,:,:), pointer :: ty_trans_gm =>NULL()

  logical :: module_is_initialized = .FALSE.
  
  character(len=128) :: version=&
       '$Id: ocean_pik_diag.F90 $'
  character (len=128) :: tagname = &
       '$Name: mom5_siena_08jun2012_smg $'
  
  ! sigh. we need this in other modules to ensure that the required
  ! calculations are actually done.
  ! Notably ocean_nphysicsX_mod::fz_terms()
  logical :: do_pik_diag   = .false.
  
  namelist /ocean_pik_diag_nml/ &
       do_pik_diag


contains

  !#######################################################################
  ! <SUBROUTINE NAME="ocean_pik_diag_init">
  !
  ! <DESCRIPTION>
  ! Initialize the ocean_pik_diag module
  ! </DESCRIPTION>
  !
  subroutine ocean_pik_diag_init(Grid, Domain, Time)
    ! Set up the overturning diagnostics.
    !
    ! Steps, in order:
    !   1. read namelist ocean_pik_diag_nml and echo it to stdlog/stdout;
    !   2. return immediately if do_pik_diag is .false.;
    !   3. keep pointers to Grid and Domain, size the work arrays
    !      (non-static build only);
    !   4. read INPUT/basin_mask and gather it, and Grid%tmask, along x;
    !   5. register the diagnostic fields with the diag manager.
    !
    ! Grid   : ocean grid; a pointer to it is kept in Grd
    ! Domain : ocean domain decomposition; a pointer to it is kept in Dom
    ! Time   : Time%model_time is handed to register_diag_field
    !
    ! A second call is a no-op (guarded by module_is_initialized).
    
    type(ocean_grid_type),    target, intent(in) :: Grid
    type(ocean_domain_type),  target, intent(in) :: Domain
    type(ocean_time_type),            intent(in) :: Time
    
    integer :: ioun, io_status, ierr
    integer :: stdoutunit,stdlogunit 
    stdoutunit=stdout();stdlogunit=stdlog() 
    
    if (module_is_initialized) return
    
    module_is_initialized = .TRUE.
    
    call write_version_number(version, tagname)

#ifdef INTERNAL_FILE_NML
    ! The closing parenthesis of this read statement was missing, which made
    ! the INTERNAL_FILE_NML build fail to compile.
    read (input_nml_file, nml=ocean_pik_diag_nml, iostat=io_status)
    ierr = check_nml_error(io_status,'ocean_pik_diag_nml')
#else
    ioun = open_namelist_file()
    read(ioun, ocean_pik_diag_nml, iostat=io_status)
    ierr = check_nml_error(io_status,'ocean_pik_diag_nml')
    call close_file(ioun)
#endif
    write (stdlogunit, ocean_pik_diag_nml)
    write (stdoutunit,'(/)')
    write (stdoutunit, ocean_pik_diag_nml)

    ! Nothing below is set up when the diagnostics are switched off.
    if (.not. do_pik_diag) return

    Dom => Domain
    Grd => Grid
    
#ifndef MOM_STATIC_ARRAYS
    call get_local_indices(Domain, isd, ied, jsd, jed, isc, iec, jsc, jec)
    nk = Grid%nk
    ni = Grid%ni
    allocate (psi(jsc:jec,nk))
    allocate (basin_mask(isc:iec,jsc:jec))
    allocate (global_basin_mask(Grd%ni,jsc:jec))
    allocate (global_tmask(Grid%ni, jsc:jec, nk))
    !allocate (global_ty_trans(Grid%ni, jsc:jec, nk))
    !allocate (global_ty_trans_gm(Grid%ni, jsc:jec, nk))
#endif
    
    !  For reading in a mask that selects regions of the domain 
    !  for performing gyre and overturning diagnostics. 
    !  The basin-mask convention used at GFDL has 
    !  Southern=1.0,Atlantic=2.0,Pacific=3.0,Arctic=4.0,Indian=5.0
    !  Mediterranean=6.0,BlackSea=7.0,HudsonBay=8.0,Baltic=9.0,RedSea=10.0
    call read_data('INPUT/basin_mask','basin_mask',basin_mask,Domain%domain2d)
    call mpp_global_field(Dom%domain2d, basin_mask, global_basin_mask, flags=XUPDATE)
    ! for masking-out non-ocean cells
    call mpp_global_field(Dom%domain2d, Grd%tmask, global_tmask, flags=XUPDATE)

    ! register fields for diagnostic output
    ! The stream functions live on the (y, z) axes, hence axes 2:3.
    id_pik_diag_gmoc = register_diag_field ('ocean_model','pik_gmoc',              &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'T-cell j-mass transport global',trim(transport_dims),   &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='global_ocean_meridional_overturning_mass_streamfunction')
    id_pik_diag_amoc = register_diag_field ('ocean_model','pik_amoc',              &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'T-cell j-mass transport atlantic',trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='atlantic_ocean_meridional_overturning_mass_streamfunction')
    id_pik_diag_pmoc = register_diag_field ('ocean_model','pik_pmoc',              &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'T-cell j-mass transport pacific',trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pacific_ocean_meridional_overturning_mass_streamfunction')

#ifdef DEBUG_PIK_DIAG
    ! Debug build only: register the intermediate fields as well.
    write (*,*) 'pe', mpp_pe(), ': isc ',isc,' iec ',iec
    id_pik_ty_trans = register_diag_field ('ocean_model','pik_ty_trans',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'pik version should be identical to ty_trans', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans')
    id_pik_ty_trans_globalfield = register_diag_field ('ocean_model','pik_ty_trans_globalfield',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'pik global field version should be identical to pik_ty_trans', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_globalfield')

    id_pik_ty_trans_atl = register_diag_field ('ocean_model','pik_ty_trans_atl',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'pik_ty_trans*atl_mask', trim(transport_dims),           &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_atl')
    id_pik_ty_trans_atl_sum = register_diag_field ('ocean_model','pik_ty_trans_atl_sum',   &
         Grid%tracer_axes_flux_y(2:2), & ! must be 1-D array of length 1, not scalar
         Time%model_time, 'pik_ty_trans_atl[i=@sum,k=@sum]', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_atl_sum')
    id_pik_ty_trans_atl_rsum = register_diag_field ('ocean_model','pik_ty_trans_atl_rsum',   &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'pik_ty_trans_atl[i=@sum,k=@rsum]', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_atl_rsum')
    id_pik_ty_trans_atl_upwardrsum = register_diag_field ('ocean_model','pik_ty_trans_atl_upwardrsum',   &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'pik_ty_trans_atl[i=@sum,k=@upwardrsum]', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_atl_upwardrsum')

    id_pik_ty_trans_pac = register_diag_field ('ocean_model','pik_ty_trans_pac',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'pik_ty_trans*pac_mask', trim(transport_dims),           &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_pac')
    id_pik_ty_trans_pac_sum = register_diag_field ('ocean_model','pik_ty_trans_pac_sum',   &
         Grid%tracer_axes_flux_y(2:2), & ! must be 1-D array of length 1, not scalar
         Time%model_time, 'pik_ty_trans_pac[i=@sum,k=@sum]', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_pac_sum')
    id_pik_ty_trans_pac_rsum = register_diag_field ('ocean_model','pik_ty_trans_pac_rsum',   &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'pik_ty_trans_pac[i=@sum,k=@rsum]', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_pac_rsum')
    id_pik_ty_trans_pac_upwardrsum = register_diag_field ('ocean_model','pik_ty_trans_pac_upwardrsum',   &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'pik_ty_trans_pac[i=@sum,k=@upwardrsum]', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_pac_upwardrsum')

    id_pik_ty_trans_gm = register_diag_field ('ocean_model','pik_ty_trans_gm',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'pik version should be identical to ty_trans_gm', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_gm')
    id_pik_ty_trans_gm_globalfield = register_diag_field ('ocean_model','pik_ty_trans_gm_globalfield',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'pik global field version should be identical to pik_ty_trans_gm', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_gm_globalfield')
    id_pik_ty_trans_gm_atl = register_diag_field ('ocean_model','pik_ty_trans_gm_atl',   &
         Grid%tracer_axes_flux_y(1:3),                                             &
         Time%model_time, 'ty_trans_gm*atl_mask', trim(transport_dims),            &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_gm_atl')
    id_pik_ty_trans_gm_atl_sum = register_diag_field ('ocean_model','pik_ty_trans_gm_atl_sum',   &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'pik_ty_trans_gm_atl[i=@sum] via sum()', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_gm_atl_sum')
    id_pik_ty_trans_gm_atl_sum_ijk = register_diag_field ('ocean_model','pik_ty_trans_gm_atl_sum_ijk',   &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'pik_ty_trans_gm_atl[i=@sum] via loops', trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pik_ty_trans_gm_atl_sum_ijk')

    id_pik_diag_amoc_upward = register_diag_field ('ocean_model','pik_amoc_upward',              &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'T-cell j-mass transport atlantic calculated direct upward',trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='atlantic_ocean_meridional_overturning_mass_streamfunction_upward')

    id_pik_diag_pmoc_upward = register_diag_field ('ocean_model','pik_pmoc_upward',              &
         Grid%tracer_axes_flux_y(2:3),                                             &
         Time%model_time, 'T-cell j-mass transport pacific calculated direct upward',trim(transport_dims), &
         missing_value=missing_value, range=(/-1e9,1e9/),                          &
         standard_name='pacific_ocean_meridional_overturning_mass_streamfunction_upward')
#endif
   
    ! set ids for clocks - done by the calling module
    !id_pik_diag_clock     = mpp_clock_id('(Ocean pik_diag)'    ,grain=CLOCK_ROUTINE)
    
  end subroutine ocean_pik_diag_init
  ! </SUBROUTINE>



  !#######################################################################
  ! <SUBROUTINE NAME="ocean_pik_diag_init_ty_trans_gm">
  !
  ! <DESCRIPTION>
  ! Initialize the pointer to ty_trans_gm
  ! </DESCRIPTION>
  !
  subroutine ocean_pik_diag_init_ty_trans_gm(ty_trans_gm_in)
    ! Bind the module pointer ty_trans_gm to the array owned by the caller,
    ! so that ocean_pik_diagnostics can read it later without a copy.
    !
    ! ty_trans_gm_in : array to point at; the pointer stays associated with
    !                  it after this routine returns.
    !
    ! Binding a second time is treated as a fatal error.
    ! The binding is done regardless of do_pik_diag.
    real, dimension(:,:,:), target, intent(in) :: ty_trans_gm_in

    logical :: already_bound

    already_bound = associated(ty_trans_gm)
    if (already_bound) call mpp_error(FATAL, &
         '==>Error from ocean_pik_diag.F90::ocean_pik_diag_init_ty_trans_gm : ty_trans_gm already assigned')

    ty_trans_gm => ty_trans_gm_in
  end subroutine ocean_pik_diag_init_ty_trans_gm
  ! </SUBROUTINE>
  

  !#######################################################################
  ! <SUBROUTINE NAME="calc_psi">
  !
  ! <DESCRIPTION>
  ! First, the vertical integral for the Eulerian part is defined over
  ! the range from -H (bottom) to the level Z. However, FERRET can not
  ! deal with a vertically upward integration. Therefore, the vertical
  ! integral is split into the integral over the whole water column
  ! minus the integral from the surface down to the level k:
  ! -1.*val_atl[i=@sum,k=@sum] + val_atl[i=@sum,k=@rsum]. This part of
  ! the calculation is correct. Second, in the GM part, there is no
  ! need for a numerical integration over Z, because it is already
  ! done analytically. We have only to perform the zonal integral:
  ! val_atl_gm[i=@sum].
  !
  !  let psi = -1.*ty_trans[i=@sum,k=@sum] + ty_trans[i=@sum,k=@rsum] + ty_trans_gm[i=@sum] 
  ! </DESCRIPTION>
  subroutine calc_psi(global_ty_trans, global_ty_trans_gm, &
#ifdef DEBUG_PIK_DIAG
       Time, &
#endif
       psi)
    ! Compute the overturning stream function on the local (j,k) slab:
    !
    !   psi(j,k) = - sum_{i,k'} T(i,j,k')          (whole-column sum)
    !              + sum_{i, k'<=k} T(i,j,k')      (running sum, top down)
    !              + sum_i T_gm(i,j,k)             (zonal sum only)
    !
    ! global_ty_trans    : transport T, full zonal extent, local rows
    ! global_ty_trans_gm : GM transport T_gm, same layout
    ! Time               : debug build only; time stamp for send_data
    ! psi                : result, dimension (jsc:jec, nk)
    !
    ! In a debug build the intermediate sums are also sent to the diag
    ! manager, always under the "atl" handles regardless of which basin
    ! the caller passed in.
    real, dimension(ni,jsc:jec,nk), intent(in) :: global_ty_trans, global_ty_trans_gm
#ifdef DEBUG_PIK_DIAG
    type(ocean_time_type),        intent(in) :: Time    
#endif
    real, dimension(jsc:jec,nk), intent(out) :: psi

    integer ::  i, j, k

    ! for the following three arrays, the j-dimension is needed only for
    ! code readability and debugging output...
    real, dimension(jsc:jec) :: ty_trans_sum          ! ty_trans[i=@sum,k=@sum]
    ! Level 0 is an always-zero start value for the running sum.
    real, dimension(jsc:jec,0:nk):: ty_trans_rsum   ! ty_trans[i=@sum,k=@rsum]
    real, dimension(jsc:jec,nk) :: ty_trans_gm_sum    ! ty_trans_gm[i=@sum]
#ifdef DEBUG_PIK_DIAG
    logical :: used
#endif

    ty_trans_rsum = 0.0
    ty_trans_sum = 0.0
    ty_trans_gm_sum(:,:) = sum(global_ty_trans_gm,1) ! sum over 1st dimension
    do j=jsc,jec
       do k=1,nk
          ! zonal sum of level k, accumulated into both totals
          do i=1,ni
             ty_trans_sum(j) = ty_trans_sum(j) + global_ty_trans(i,j,k)
             ty_trans_rsum(j,k) = ty_trans_rsum(j,k) + global_ty_trans(i,j,k)
          end do
          ! add everything above level k to get the running sum
          ty_trans_rsum(j,k) = ty_trans_rsum(j,k) + ty_trans_rsum(j,k-1)
       end do
       do k=1,nk
          psi(j,k) = - ty_trans_sum(j) + ty_trans_rsum(j,k) + ty_trans_gm_sum(j,k)
       end do
    end do

#ifdef DEBUG_PIK_DIAG
    used = send_data(id_pik_ty_trans_atl_sum, ty_trans_sum, Time%model_time)
    used = send_data(id_pik_ty_trans_gm_atl_sum, ty_trans_gm_sum, Time%model_time)
    used = send_data(id_pik_ty_trans_atl_rsum, ty_trans_rsum(:,1:nk), Time%model_time)

    ! ferret does the rsum top-to-bottom, thus we need to calculate and subtract
    ! the 2-D-sum term ty_trans[i=@sum,k=@sum] . That is mimicked above.
    ! Here we try to do the bottom-to-top running sum, negated, and avoid the additional 2-D-sum.
    ! One major problem seems to be the influence of missing-values.
    ! Ferret obeys those, apparently missing+x = missing.
    ! That makes top-to-bottom rsum non-commutative with bottom-to-top reverse-rsum,
    ! because the missing-values propagate along the summing direction.
    ! And there are more missing values near the bottom of the ocean, than at top.
    ! sigh.
    !
    ! Target: R(j,k) = - sum_{i, k'>k} T(i,j,k'), which equals
    ! -total + (top-down running sum) used for psi above.
    ! Recurrence: R(j,nk) = 0, R(j,k) = - sum_i T(i,j,k+1) + R(j,k+1).
    ty_trans_rsum = 0.0
    do j=jsc,jec
       do k=nk-1,1,-1
          do i=1,ni
             ty_trans_rsum(j,k) = ty_trans_rsum(j,k) - global_ty_trans(i,j,k+1)
          end do
          ! R(j,k+1) is already negated, so it has to be ADDED here;
          ! subtracting it would flip the sign at every level.
          ty_trans_rsum(j,k) = ty_trans_rsum(j,k) + ty_trans_rsum(j,k+1)
       end do
    end do
    ! ty_trans_gm_sum still holds sum(global_ty_trans_gm,1) from above.
    !psi = ty_trans_rsum(:,1:nk) + ty_trans_gm_sum

    used = send_data(id_pik_ty_trans_atl_upwardrsum, ty_trans_rsum(:,1:nk), Time%model_time)
    !used = send_data(id_pik_ty_trans_gm_atl_sum_ijk, ty_trans_gm_sum, Time%model_time)
    used = send_data(id_pik_diag_amoc_upward, ty_trans_rsum(:,1:nk) + ty_trans_gm_sum, Time%model_time)
#endif
  end subroutine calc_psi
  ! </SUBROUTINE>


  !#######################################################################
  ! <SUBROUTINE NAME="diag_basin">
  !
  ! <DESCRIPTION>
  ! calculate and send to diag_manager the overturning diagnostics for one or two ocean basins
  ! </DESCRIPTION>
  subroutine diag_basin(Time, basinnr, basinnr2, id_pik_diag_basin, &
#ifdef DEBUG_PIK_DIAG
       id_pik_ty_trans_basin, id_pik_ty_trans_gm_basin, &
#endif
       global_ty_trans, global_ty_trans_gm)
    ! Restrict the transports to the water columns whose basin-mask value
    ! equals basinnr or basinnr2, compute the stream function of that
    ! subset with calc_psi and send it to the diag manager.
    !
    ! Time               : time stamp handed to send_data
    ! basinnr, basinnr2  : basin-mask values to keep (pass the same value
    !                      twice to select a single basin)
    ! id_pik_diag_basin  : diag handle of the stream function; nothing is
    !                      done when it is <= 0
    ! id_pik_ty_trans_basin, id_pik_ty_trans_gm_basin :
    !                      debug build only; handles of the masked transports
    ! global_ty_trans, global_ty_trans_gm : transports to be masked
    type(ocean_time_type),            intent(in) :: Time
    real,                             intent(in) :: basinnr, basinnr2
    integer,                          intent(in) :: id_pik_diag_basin
#ifdef DEBUG_PIK_DIAG
    integer,                          intent(in) :: id_pik_ty_trans_basin, id_pik_ty_trans_gm_basin
#endif
    real, dimension(ni, jsc:jec, nk), intent(in) :: global_ty_trans, global_ty_trans_gm

    integer :: lev
    logical :: used
    logical, dimension(ni, jsc:jec)  :: outside_basin
    real, dimension(ni, jsc:jec, nk) :: basin_ty_trans, basin_ty_trans_gm

    if (id_pik_diag_basin <= 0) return

    ! .true. for every column that belongs to neither of the two basins
    outside_basin = (global_basin_mask .ne. basinnr) .and. &
                    (global_basin_mask .ne. basinnr2)

    ! copy the columns inside the basin(s), zero all others
    do lev = 1, nk
       where (outside_basin)
          basin_ty_trans(:,:,lev)    = 0.0
          basin_ty_trans_gm(:,:,lev) = 0.0
       elsewhere
          basin_ty_trans(:,:,lev)    = global_ty_trans(:,:,lev)
          basin_ty_trans_gm(:,:,lev) = global_ty_trans_gm(:,:,lev)
       end where
    end do
#ifdef DEBUG_PIK_DIAG
    if (id_pik_ty_trans_basin > 0) then
       used = send_data (id_pik_ty_trans_basin, basin_ty_trans(isc:iec,:,:),     &
            Time%model_time, rmask=global_tmask(isc:iec,:,:))
    end if
    if (id_pik_ty_trans_gm_basin > 0) then
       used = send_data (id_pik_ty_trans_gm_basin, basin_ty_trans_gm(isc:iec,:,:), &
            Time%model_time, rmask=global_tmask(isc:iec,:,:))
    end if
#endif
    call calc_psi(basin_ty_trans, basin_ty_trans_gm, &
#ifdef DEBUG_PIK_DIAG
         Time, &
#endif
         psi)
    used = send_data(id_pik_diag_basin, psi, Time%model_time)
  end subroutine diag_basin
  ! </SUBROUTINE>
  

  !#######################################################################
  ! <SUBROUTINE NAME="ocean_pik_diagnostics">
  !
  ! <DESCRIPTION>
  ! Call PIK diagnostics 
  ! </DESCRIPTION>
  subroutine ocean_pik_diagnostics(Time, Adv_vel)
    ! Compute the global, atlantic(+arctic) and pacific overturning stream
    ! functions for the current time step and send them to the diag manager.
    !
    ! Time    : Time%model_time is the time stamp for send_data
    ! Adv_vel : Adv_vel%vhrho_nt is the source of the meridional transport
    !
    ! Returns without doing anything when do_pik_diag is .false. or when all
    ! three stream-function handles still have their initial value -1.
    ! Aborts (FATAL) when the module is not initialised or when ty_trans_gm
    ! has not been bound via ocean_pik_diag_init_ty_trans_gm.
    type(ocean_time_type),        intent(in) :: Time
    type(ocean_adv_vel_type),     intent(in) :: Adv_vel

    integer :: i, j, k
    logical :: used

    ! Transports gathered along x: full zonal extent, local rows only.
    ! Allocatable rather than automatic: automatic arrays would be sized on
    ! entry, i.e. before the guards below, from ni/jsc/jec/nk. In the
    ! non-static build ocean_pik_diag_init sets those only when do_pik_diag
    ! is .true.. Allocation on the heap also keeps these full-width arrays
    ! off the stack.
    real, dimension(:,:,:), allocatable :: global_ty_trans
    real, dimension(:,:,:), allocatable :: global_ty_trans_gm


    if (.not. do_pik_diag) return
    if (-1 == id_pik_diag_amoc .and. -1 == id_pik_diag_gmoc .and. -1 == id_pik_diag_pmoc) return

    !call mpp_clock_begin(id_pik_diag_clock)

    if (.not. module_is_initialized) then 
       ! message now names this routine (it used to name
       ! ocean_pik_diag_init_ty_trans_gm)
       call mpp_error(FATAL, &
            '==>Error from ocean_pik_diag.F90::ocean_pik_diagnostics : needs initialization')
    endif
    if (.not. ASSOCIATED(ty_trans_gm)) then
       write (stdout(),*) 'Error: ocean_pik_diag.F90::ocean_pik_diagnostics : ty_trans_gm not assigned.', &
            ' Check that neutral_physicsA, neutral_physicsB or neutral_physicsC is enabled in file input.nml ,', &
            ' sections &ocean_nphysics_nml and &neutral_physicsA_nml resp. &ocean_nphysicsB_nml resp. ocean_nphysicsC_nml'
       call mpp_error(FATAL, &
            '==>Error from ocean_pik_diag.F90::ocean_pik_diagnostics : ty_trans_gm not assigned')
    endif

    allocate (global_ty_trans(ni, jsc:jec, nk))
    allocate (global_ty_trans_gm(ni, jsc:jec, nk))

    ! calculation of ty_trans
    ! copied from ocean_adv_vel_diag.F90 :: transport_on_s()
    ! but here mask out non-ocean cells with Grd%tmask
    do k=1,nk
       do j=jsc,jec
          do i=isc,iec
             wrk1(i,j,k) = Adv_vel%vhrho_nt(i,j,k)*Grd%dxtn(i,j)*transport_convert*Grd%tmask(i,j,k)
          enddo
       enddo
    enddo
#ifdef DEBUG_PIK_DIAG
    if (id_pik_ty_trans > 0) &
         used = send_data (id_pik_ty_trans, wrk1(:,:,:), &
         Time%model_time, rmask=Grd%tmask(:,:,:), &
         is_in=isc, js_in=jsc, ks_in=1, ie_in=iec, je_in=jec, ke_in=nk)
    if (id_pik_ty_trans_gm > 0) &
         used = send_data (id_pik_ty_trans_gm, ty_trans_gm(:,:,:), &
         Time%model_time, rmask=Grd%tmask(:,:,:), &
         is_in=isc, js_in=jsc, ks_in=1, ie_in=iec, je_in=jec, ke_in=nk)
#endif

    ! Sigh. I think we really need this on the root PE only, but here all PEs get it.
    ! On the other hand, mppnccombine seems to work fine, maybe even requires,
    ! with having redundant slices in x direction.
    call mpp_global_field(Dom%domain2d, wrk1,        global_ty_trans,    flags=XUPDATE) 
    call mpp_global_field(Dom%domain2d, ty_trans_gm, global_ty_trans_gm, flags=XUPDATE)

    ! mask out any non-ocean cells. Dont do that on the local ty_trans_gm,
    ! because that is only ``borrowed'' via pointer association from neutral physics module
    where (global_tmask .eq. 0.0) global_ty_trans_gm = 0.0

#ifdef DEBUG_PIK_DIAG
    if (id_pik_ty_trans_globalfield > 0) &
         used = send_data (id_pik_ty_trans_globalfield, global_ty_trans(isc:iec,:,:), &
         Time%model_time, rmask=global_tmask(isc:iec,:,:))
    if (id_pik_ty_trans_gm_globalfield > 0) &
         used = send_data (id_pik_ty_trans_gm_globalfield, global_ty_trans_gm(isc:iec,:,:), &
         Time%model_time, rmask=global_tmask(isc:iec,:,:))
#else
    ! for debugging, call calc_psi() only once, for atlantic basin
    ! (that is why the global overturning is skipped in a debug build)
    ! global overturning
    if (id_pik_diag_gmoc > 0) then
       call calc_psi(global_ty_trans, global_ty_trans_gm, psi)
       used = send_data(id_pik_diag_gmoc, psi(:,:), Time%model_time) !, is_in=jsc, js_in=1, ie_in=jec, je_in=nk)
    end if
#endif

    ! atlantic overturning:
    ! Willem suggested to include also the arctic region
    ! (basin-mask values 2.0 = Atlantic, 4.0 = Arctic)
    call diag_basin(Time, 2.0, 4.0, id_pik_diag_amoc, &
#ifdef DEBUG_PIK_DIAG
         id_pik_ty_trans_atl, id_pik_ty_trans_gm_atl, &
#endif
         global_ty_trans, global_ty_trans_gm)

    ! pacific overturning (basin-mask value 3.0, passed twice to select
    ! a single basin)
    call diag_basin(Time, 3.0, 3.0, id_pik_diag_pmoc, &
#ifdef DEBUG_PIK_DIAG
         id_pik_ty_trans_pac, id_pik_ty_trans_gm_pac, &
#endif
         global_ty_trans, global_ty_trans_gm)

    deallocate (global_ty_trans)
    deallocate (global_ty_trans_gm)

    !call mpp_clock_end(id_pik_diag_clock)
  end subroutine ocean_pik_diagnostics
  ! </SUBROUTINE>

end module ocean_pik_diag_mod
