import tools
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt 
from scipy.stats.stats import pearsonr
import pdb
import xarray as xr
import new_tools 

'Delta tempe land/sea as a function of precip'

'----------------------------------------------------------------------------'
'SELECT DOMAIN'
'----------------------------------------------------------------------------'
# Analysis box in degrees.  The latitude bounds are written as (50, 0), so
# "lat_min" holds the northern edge; presumably this matches the latitude
# ordering tools.select_domain expects for the atmosphere grid -- TODO confirm
# (the ocean-grid call further down passes the two latitudes swapped).
lat_min = 50
lat_max = 0
# Previously used western bound: lon_min = 40
lon_min = 60
lon_max = 130
'----------------------------------------------------------------------------'

'----------------------------------------------------------------------------'
'DATA RECOVERY'
'----------------------------------------------------------------------------'
# Climatology files are located relative to the parent directory; the working
# directory is switched back to ./codes at the end of the data-recovery part.
os.chdir(os.pardir)

list_simu = tools.list_simu_08

# Atmosphere, 'jjs' climatology: surface precipitation and temperature, plus
# the atmosphere grid coordinates taken from the first simulation's file.
nc_files = new_tools.get_clfiles(list_simu, 'jjs')
precip = new_tools.get_var2d('precip_mm_srf', nc_files)
temp = new_tools.get_var2d('temp_mm_srf', nc_files)
atm_grid = nc_files[0].variables
lats = atm_grid['latitude'][:]
lons = atm_grid['longitude'][:]
new_tools.close_files(nc_files)

# Atmosphere, single-month climatologies (read in May -> Aug order).
atm_by_month = {}
for month in ('may', 'jun', 'jul', 'aug'):
    month_files = new_tools.get_clfiles(list_simu, month)
    month_precip = new_tools.get_var2d('precip_mm_srf', month_files)
    month_temp = new_tools.get_var2d('temp_mm_srf', month_files)
    new_tools.close_files(month_files)
    atm_by_month[month] = (month_precip, month_temp)

precip_may, temp_may = atm_by_month['may']
precip_jun, temp_jun = atm_by_month['jun']
precip_jul, temp_jul = atm_by_month['jul']
precip_aug, temp_aug = atm_by_month['aug']

# Ocean, 'jjs' climatology: 'temp_mm_uo' (used as SST below) plus the ocean
# grid coordinates taken from the first simulation's file.
nc_ofiles = new_tools.get_clofiles(list_simu, 'jjs')
SST = new_tools.get_var2d('temp_mm_uo', nc_ofiles)
ocean_grid = nc_ofiles[0].variables
latso = ocean_grid['latitude'][:]
lonso = ocean_grid['longitude'][:]
new_tools.close_files(nc_ofiles)

# Ocean, single-month climatologies (same May -> Aug order).
sst_by_month = {}
for month in ('may', 'jun', 'jul', 'aug'):
    month_ofiles = new_tools.get_clofiles(list_simu, month)
    sst_by_month[month] = new_tools.get_var2d('temp_mm_uo', month_ofiles)
    new_tools.close_files(month_ofiles)

SST_may = sst_by_month['may']
SST_jun = sst_by_month['jun']
SST_jul = sst_by_month['jul']
SST_aug = sst_by_month['aug']

# Land-sea mask ('lsm') of every simulation, read from its initial-condition
# file.  Slicing a netCDF4 variable with [:] reads the values into memory, so
# the datasets can be closed as soon as the masks are extracted.
nc_masks = [
    nc.Dataset('/user/home/ggpjv/swsvalde/ummodel/data_olig/' + simu + '/inidata/' + simu + '.qrparm.mask.nc')
    for simu in list_simu
]
masks = [mask_file.variables['lsm'][:] for mask_file in nc_masks]
new_tools.close_files(nc_masks)

'''
nc_monsoon_msk = []

for simu in list_simu :
	nc_monsoon_msk.append(nc.Dataset('/user/work/xk22684/work/data/msk_monsoon_bbc8/' + simu + '_monsoon_msk.nc'))
	

monsoon_msk    = []

for i in range(len(nc_monsoon_msk)): 

	monsoon_msk.append(nc_monsoon_msk[i].variables['precip_monsoon'][:])
'''


# Single regridded monsoon-domain mask, applied to every simulation later on.
# The context manager closes the file once the data have been read.
with nc.Dataset('/user/work/xk22684/work/data/monsoon_msk_regridded.nc') as nc_file_monsoon:
    precip_2mm_msk = nc_file_monsoon.variables['precip_monsoon'][:]


os.chdir('./codes')
'----------------------------------------------------------------------------'

#pdb.set_trace()
'----------------------------------------------------------------------------'
'CUTTING OF DATA'
'----------------------------------------------------------------------------'

# Index bounds of the analysis box on the atmosphere grid and on the ocean
# grid.  The latitude limits are passed in opposite order for the ocean grid
# (lat_max, lat_min); presumably the two grids store latitude in opposite
# directions -- TODO confirm against tools.select_domain.
id_lat_min,  id_lat_max,  id_lon_min,  id_lon_max  = tools.select_domain(lats, lons, lat_min, lat_max, lon_min, lon_max)
#id_lat_min,  id_lat_max,  id_lon_min,  id_lon_max  = tools.select_domain(lats, lons, 30, lat_max, lon_min, lon_max)
id_lat_mino, id_lat_maxo, id_lon_mino, id_lon_maxo = tools.select_domain(latso, lonso, lat_max, lat_min, lon_min, lon_max)

# Bundle the bounds once so every cut below uses the same argument order.
atm_box = (id_lat_min, id_lat_max, id_lon_min, id_lon_max)
ocn_box = (id_lat_mino, id_lat_maxo, id_lon_mino, id_lon_maxo)

# Atmosphere fields ('jjs' climatology, then May -> Aug).
precip_cut     = new_tools.cut(precip,     *atm_box)
precip_cut_may = new_tools.cut(precip_may, *atm_box)
precip_cut_jun = new_tools.cut(precip_jun, *atm_box)
precip_cut_jul = new_tools.cut(precip_jul, *atm_box)
precip_cut_aug = new_tools.cut(precip_aug, *atm_box)

temp_cut     = new_tools.cut(temp,     *atm_box)
temp_cut_may = new_tools.cut(temp_may, *atm_box)
temp_cut_jun = new_tools.cut(temp_jun, *atm_box)
temp_cut_jul = new_tools.cut(temp_jul, *atm_box)
temp_cut_aug = new_tools.cut(temp_aug, *atm_box)

# The land-sea masks are cut with tools.cut (not new_tools.cut).
masks_cut = tools.cut(masks, *atm_box)

# Ocean fields use the ocean-grid indices.
SST_cut     = new_tools.cut(SST,     *ocn_box)
SST_cut_may = new_tools.cut(SST_may, *ocn_box)
SST_cut_jun = new_tools.cut(SST_jun, *ocn_box)
SST_cut_jul = new_tools.cut(SST_jul, *ocn_box)
SST_cut_aug = new_tools.cut(SST_aug, *ocn_box)

# Monsoon mask restricted to the atmosphere box; the +1 makes the upper
# indices inclusive (presumably matching new_tools.cut -- TODO confirm).
monsoon_rows = slice(id_lat_min, id_lat_max + 1)
monsoon_cols = slice(id_lon_min, id_lon_max + 1)
precip_2mm_msk_cut = precip_2mm_msk[monsoon_rows, monsoon_cols]
#precip_2mm_msk_cut = np.where(precip_2mm_msk_cut == 2, 1, np.nan)
# Normalised copy: cells whose mask code is 2 become 1, all others are masked.
not_monsoon = precip_2mm_msk_cut != 2
precip_2mm_msk_cut_masked = np.ma.masked_where(not_monsoon, precip_2mm_msk_cut)
precip_2mm_msk_cut_masked[~precip_2mm_msk_cut_masked.mask] = 1
'''
monsoon_msk_cut    = new_tools.cut(monsoon_msk, id_lat_min,  id_lat_max,  id_lon_min,  id_lon_max)
monsoon_msk_cut    = [np.ma.masked_where(monsoon_msk_cut[i] != 2, precip_2mm_msk_cut) for i in range(len(monsoon_msk_cut))]

for i in range(len(monsoon_msk_cut)):

	monsoon_msk_cut[i][~monsoon_msk_cut[i].mask] = 1
'''
'----------------------------------------------------------------------------'

#pdb.set_trace()

'----------------------------------------------------------------------------'
'MASKING AND MEANS'
'----------------------------------------------------------------------------'
# Keep land points only, using the cut land-sea masks.
precip_land      = tools.mask_application(precip_cut,     masks_cut, 'land')
precip_land_may  = tools.mask_application(precip_cut_may, masks_cut, 'land')
precip_land_jun  = tools.mask_application(precip_cut_jun, masks_cut, 'land')
precip_land_jul  = tools.mask_application(precip_cut_jul, masks_cut, 'land')
precip_land_aug  = tools.mask_application(precip_cut_aug, masks_cut, 'land')

# Restrict land precipitation to the monsoon domain.  Multiply by the
# normalised mask precip_2mm_msk_cut_masked (1 where the regridded mask code
# is 2, masked elsewhere).  The raw precip_2mm_msk_cut still carries the mask
# codes, so multiplying by it doubled the monsoon cells and let any unmasked
# non-monsoon cells enter the averages.
precip_land2     = [field * precip_2mm_msk_cut_masked for field in precip_land]
precip_land_may2 = [field * precip_2mm_msk_cut_masked for field in precip_land_may]
'''
precip_land3     = [precip_land[i]     * monsoon_msk_cut[i] for i in range(len(precip_land))]
precip_land_may3 = [precip_land_may[i] * monsoon_msk_cut[i] for i in range(len(precip_land))]
'''


temp_land        = tools.mask_application(temp_cut,       masks_cut, 'land')
temp_land_may    = tools.mask_application(temp_cut_may,   masks_cut, 'land')
temp_land_jun    = tools.mask_application(temp_cut_jun,   masks_cut, 'land')
temp_land_jul    = tools.mask_application(temp_cut_jul,   masks_cut, 'land')
temp_land_aug    = tools.mask_application(temp_cut_aug,   masks_cut, 'land')

#temp_land2       = [temp_land[i] * precip_2mm_msk_cut for i in range(len(temp_land))]

# conv_1 = 86400 presumably converts a per-second precipitation rate to
# mm/day (the unit on the plot's y-axis) -- TODO confirm in tools.averages.
precip_land_av   = tools.averages(precip_land,     conv_1 = 86400)

precip_av        = tools.averages(precip_cut,      conv_1 = 86400)
precip_av_may    = tools.averages(precip_cut_may,  conv_1 = 86400)
precip_av_jun    = tools.averages(precip_cut_jun,  conv_1 = 86400)
precip_av_jul    = tools.averages(precip_cut_jul,  conv_1 = 86400)
precip_av_aug    = tools.averages(precip_cut_aug,  conv_1 = 86400)

precip_av2       = tools.averages(precip_land2,      conv_1 = 86400)
precip_av_may2   = tools.averages(precip_land_may2,  conv_1 = 86400)
'''
precip_av3       = tools.averages(precip_land3,      conv_1 = 86400)
precip_av_may3   = tools.averages(precip_land_may3,  conv_1 = 86400)
'''

# Seasonal means per simulation.  MJJAS is built as (4 * 'jjs' + May) / 5,
# which assumes the 'jjs' climatology is the June-September mean and weights
# every month equally -- TODO confirm the 'jjs' period definition.
precip_av_JJA    = [(precip_av_jun[i] + precip_av_jul[i] + precip_av_aug[i])/3 for i in range(len(list_simu))]
precip_av_MJJS   = [(4*precip_av[i]   + precip_av_may[i])/5 for i in range(len(list_simu))]
precip_av_MJJS2  = [(4*precip_av2[i]  + precip_av_may2[i])/5 for i in range(len(list_simu))]
#precip_av_MJJS3  = [(4*precip_av3[i]  + precip_av_may3[i])/5 for i in range(len(list_simu))]


# conv_2 = -273.15 presumably converts Kelvin to degrees C (the x-axis unit).
temp_land_av     = tools.averages(temp_land,       conv_2 = -273.15)
temp_land_av_may = tools.averages(temp_land_may,   conv_2 = -273.15)
temp_land_av_jun = tools.averages(temp_land_jun,   conv_2 = -273.15)
temp_land_av_jul = tools.averages(temp_land_jul,   conv_2 = -273.15)
temp_land_av_aug = tools.averages(temp_land_aug,   conv_2 = -273.15)

temp_av_JJA      = [(temp_land_av_jun[i] + temp_land_av_jul[i] + temp_land_av_aug[i])/3 for i in range(len(list_simu))]
temp_av_MJJS     = [(4*temp_land_av[i] + temp_land_av_may[i])/5 for i in range(len(list_simu))]


#temp_land_av2     = tools.averages(temp_land2,       conv_2 = -273.15)


# SST is averaged without a unit conversion; presumably already in degrees C
# -- TODO confirm, otherwise the land-sea difference below is offset.
SST_av           = tools.averages(SST_cut)
SST_av_may       = tools.averages(SST_cut_may)
SST_av_jun       = tools.averages(SST_cut_jun)
SST_av_jul       = tools.averages(SST_cut_jul)
SST_av_aug       = tools.averages(SST_cut_aug)

SST_av_MJJS      = [(4*SST_av[i] + SST_av_may[i])/5 for i in range(len(SST_av))]
SST_av_JJA       = [(SST_av_jun[i] + SST_av_jul[i] + SST_av_aug[i])/3 for i in range(len(SST_av))]

#pdb.set_trace()


# Land-minus-sea temperature contrast per simulation ('jjs' and MJJAS).
deltaT          = np.asarray(temp_land_av) - np.asarray(SST_av)
#deltaT2         = np.asarray(temp_land_av2) - np.asarray(SST_av)
deltaT3         = np.asarray(temp_av_MJJS) - np.asarray(SST_av_MJJS)
#deltaT4         = np.asarray(temp_av_JJA) - np.asarray(SST_av_JJA)
'----------------------------------------------------------------------------'

# Correlation across simulations between the MJJAS land-sea temperature
# contrast and MJJAS monsoon-domain land precipitation.  pearsonr returns the
# Pearson correlation coefficient and its p-value (not a t-score); both old
# tuple results and newer PearsonRResult objects unpack this way.
r_value, p_value = pearsonr(deltaT3, precip_av_MJJS2)
print('Pearson r = ', r_value, ', p-value = ', p_value)


'----------------------------------------------------------------------------'
'PLOTS'
'----------------------------------------------------------------------------'
plt.scatter(deltaT3, precip_av_MJJS2, label = 'Simulations between 800k to 0k')
plt.grid()
# Raw string / escaped backslash: '\D' is not a valid Python escape sequence
# (SyntaxWarning on Python 3.12+).  The title keeps its '\n' line breaks, so
# only the backslash of '\Delta' is escaped there.
plt.xlabel(r'$\Delta$T (land-sea) degree C')
plt.ylabel('Precipitation in mm/day')
plt.title('$\\Delta$T (land-sea) as a function of precip (MJJAS) \n precip only monsoon area defined with GPCC data \n - All Forcings')
plt.legend(loc = 'lower right')
plt.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/corr_rev.pdf')
plt.show()
plt.close()


'----------------------------------------------------------------------------'

# Deliberate final breakpoint: drops into the debugger after the figure is
# produced so the computed fields can be inspected interactively.  Comment it
# out for non-interactive (batch) runs.
#
# The former hand-written netCDF loader that followed here as a string
# literal was never executed; it has been removed because it was superseded
# by the new_tools-based loading above and contained bugs (temp_jun/jul/aug
# were read from the May files, precip_jun/jul/aug were never initialised).
pdb.set_trace()
