import tools 
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt
import pandas as pd 
import pdb
import new_tools

def get_clfiles(list_simu):
	"""Open the `<simu>a.pdcljjs.nc` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is used both as the simulation
		directory and as the file-name prefix.

	Returns
	-------
	list
		One open netCDF4 Dataset per simulation, in input order.
		The caller is responsible for closing them.
	"""
	return [
		nc.Dataset('/user/home/ggpjv/swsvalde/ummodel/data_olig/' + name + '/climate/' + name + 'a.pdcljjs.nc')
		for name in list_simu
	]


def get_clfiles_may(list_simu):
	"""Open the `<simu>a.pdclmay.nc` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is used both as the simulation
		directory and as the file-name prefix.

	Returns
	-------
	list
		One open netCDF4 Dataset per simulation, in input order.
		The caller is responsible for closing them.
	"""
	return [
		nc.Dataset('/user/home/ggpjv/swsvalde/ummodel/data_olig/' + name + '/climate/' + name + 'a.pdclmay.nc')
		for name in list_simu
	]


def get_clofiles(list_simu):
	"""Open the `<simu>o.pfcljjs.nc` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is used both as the simulation
		directory and as the file-name prefix.

	Returns
	-------
	list
		One open netCDF4 Dataset per simulation, in input order.
		The caller is responsible for closing them.
	"""
	return [
		nc.Dataset('/user/home/ggpjv/swsvalde/ummodel/data_olig/' + name + '/climate/' + name + 'o.pfcljjs.nc')
		for name in list_simu
	]


def get_clofiles_may(list_simu):
	"""Open the `<simu>o.pfclmay.nc` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is used both as the simulation
		directory and as the file-name prefix.

	Returns
	-------
	list
		One open netCDF4 Dataset per simulation, in input order.
		The caller is responsible for closing them.
	"""
	return [
		nc.Dataset('/user/home/ggpjv/swsvalde/ummodel/data_olig/' + name + '/climate/' + name + 'o.pfclmay.nc')
		for name in list_simu
	]


def get_mskfiles(list_simu):
	"""Open the `<simu>.qrparm.mask.nc` file of every simulation.

	The files are looked up in the `inidata` sub-directory of each
	simulation (the climate getters use `climate` instead).

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is used both as the simulation
		directory and as the file-name prefix.

	Returns
	-------
	list
		One open netCDF4 Dataset per simulation, in input order.
		The caller is responsible for closing them.
	"""
	return [
		nc.Dataset('/user/home/ggpjv/swsvalde/ummodel/data_olig/' + name + '/inidata/' + name + '.qrparm.mask.nc')
		for name in list_simu
	]
	


def get_var(nomvar, nc_files):
	"""Read one variable, in full, from each dataset of a list.

	Parameters
	----------
	nomvar : str
		Key looked up in the `variables` mapping of every dataset.
	nc_files : iterable
		Objects exposing a `variables` mapping (here: open datasets).

	Returns
	-------
	list
		`dataset.variables[nomvar][:]` for every dataset, in input order.
	"""
	return [dataset.variables[nomvar][:] for dataset in nc_files]


'-----------------------------------------------------------------'
'DOMAIN DEMARCATION'
'-----------------------------------------------------------------'
# Latitude bounds of the study domain.
# NOTE: lat_min holds the larger value (50) and lat_max the smaller one (0);
# the CUT section passes them in both orders to tools.select_domain.
lat_min = 50
lat_max = 0

# Longitude bounds of the study domain (earlier choice: lon_min, lon_max = 40, 130).
lon_min = 60
lon_max = 130

# Split point used to divide the domain into four sectors.
lon_mid = 90
lat_mid = 25
'-----------------------------------------------------------------'

'------------------------------------------------------------------'
'DATA'
'------------------------------------------------------------------'
# Move the working directory to the parent of the launch directory.
os.chdir(os.pardir)


def _close_all(groups):
	"""Close every dataset in a list of per-experiment dataset lists.

	Each inner list is walked over its own length, so experiments that do
	not contain the same number of simulations are neither left open nor
	indexed out of range.
	"""
	for group in groups:
		for dataset in group:
			dataset.close()


list_simu_14       = new_tools.list_simu_14
list_simu_14_800k  = list_simu_14[700:]

# One list of simulation names per experiment. The position in this list is
# the index later used to pick the entry of `labels` / `colors` when plotting.
list_simu          = [tools.list_simu_08, tools.list_simu_09, list_simu_14_800k, tools.list_simu_10]

# Ocean files ('o.pfcljjs'): variable 'temp_mm_uo' plus the ocean grid,
# the latter taken from the first file of the first experiment.
nc_ofiles          = [get_clofiles(simus) for simus in list_simu]
sst                = [get_var('temp_mm_uo', files) for files in nc_ofiles]
lato, lono         = nc_ofiles[0][0].variables['latitude'][:], nc_ofiles[0][0].variables['longitude'][:]
_close_all(nc_ofiles)

# Ocean files ('o.pfclmay'): same variable.
nc_ofiles_may = [get_clofiles_may(simus) for simus in list_simu]
sst_may       = [get_var('temp_mm_uo', files) for files in nc_ofiles_may]
_close_all(nc_ofiles_may)

# Atmosphere files ('a.pdcljjs'): variable 'precip_mm_srf' plus the atmosphere grid.
nc_files           = [get_clfiles(simus) for simus in list_simu]
precip             = [get_var('precip_mm_srf', files) for files in nc_files]
lat, lon           = nc_files[0][0].variables['latitude'][:], nc_files[0][0].variables['longitude'][:]
_close_all(nc_files)

# Atmosphere files ('a.pdclmay'): same variable.
nc_files_may = [get_clfiles_may(simus) for simus in list_simu]
precip_may   = [get_var('precip_mm_srf', files) for files in nc_files_may]
_close_all(nc_files_may)

# Mask files: variable 'lsm'. The values are read with [:] before closing,
# exactly like the climate files above, so the handles can be released here.
nc_masks           = [get_mskfiles(simus) for simus in list_simu]
msk                = [get_var('lsm', files) for files in nc_masks]
_close_all(nc_masks)

# Single regridded mask file shared by all experiments; closed once read.
nc_file_monsoon    = nc.Dataset('/user/work/xk22684/work/data/monsoon_msk_regridded.nc')
precip_2mm_msk     = nc_file_monsoon.variables['precip_monsoon'][:]
nc_file_monsoon.close()

# Disabled alternative: one mask file per simulation of the first experiment.
'''
nc_monsoon_msk     = []

for simu in list_simu[0] :
	nc_monsoon_msk.append(nc.Dataset('/user/work/xk22684/work/data/msk_monsoon_bbc8/' + simu + '_monsoon_msk.nc'))

monsoon_msk        = []

for i in range(len(nc_monsoon_msk)): 

	monsoon_msk.append(nc_monsoon_msk[i].variables['precip_monsoon'][:])
'''

print('DATA RECOVERED')
'----------------------------------------------------------'


'----------------------------------------------------------'
'CUT'
'----------------------------------------------------------'
# Index bounds of the study domain on the atmosphere grid (idx) and on the
# ocean grid (idxo). Both are kept as one-element lists, read as idx[0][k].
# NOTE(review): lat_min / lat_max are passed in opposite order for the two
# grids - presumably because the grids store latitude in opposite directions;
# confirm against tools.select_domain.
idx  = [tools.select_domain(lat, lon, lat_min, lat_max, lon_min, lon_max)]
idxo = [tools.select_domain(lato, lono, lat_max, lat_min, lon_min, lon_max)]

idx_lon_mid    = tools.find_nearest(lon, lon_mid)
idx_lat_mid    = tools.find_nearest(lat, lat_mid)

# Short names for the four atmosphere-grid bounds and the four ocean-grid bounds.
_lat_a, _lat_b, _lon_a, _lon_b = idx[0][0], idx[0][1], idx[0][2], idx[0][3]
_olat_a, _olat_b, _olon_a, _olon_b = idxo[0][0], idxo[0][1], idxo[0][2], idxo[0][3]


def _cut_each(per_experiment, first_lat, last_lat, first_lon, last_lon):
	"""Apply tools.cut with the same four bounds to every experiment entry."""
	return [tools.cut(entry, first_lat, last_lat, first_lon, last_lon) for entry in per_experiment]


def _flag_monsoon(sub_mask):
	"""Return a masked copy of sub_mask: cells != 2 are masked, the others are set to 1."""
	flagged = np.ma.masked_where(sub_mask != 2, sub_mask)
	flagged[~flagged.mask] = 1
	return flagged


# Whole domain, then the four sectors obtained by splitting at (idx_lat_mid, idx_lon_mid).
# NOTE(review): the '*east' sectors take the index range [lon_min bound, idx_lon_mid]
# and the '*west' sectors [idx_lon_mid, lon_max bound]. If longitude increases eastward
# along the grid, the East/West names (and the matching plot titles) are swapped - confirm.
precip_cut       = _cut_each(precip, _lat_a, _lat_b, _lon_a, _lon_b)

precip_cut_Neast = _cut_each(precip, _lat_a, idx_lat_mid, _lon_a, idx_lon_mid)
precip_cut_Nwest = _cut_each(precip, _lat_a, idx_lat_mid, idx_lon_mid, _lon_b)

precip_cut_Seast = _cut_each(precip, idx_lat_mid, _lat_b, _lon_a, idx_lon_mid)
precip_cut_Swest = _cut_each(precip, idx_lat_mid, _lat_b, idx_lon_mid, _lon_b)

# Ocean field on the ocean-grid bounds; land/sea mask on the atmosphere-grid bounds.
sst_cut          = _cut_each(sst, _olat_a, _olat_b, _olon_a, _olon_b)
masks_cut        = _cut_each(msk, _lat_a, _lat_b, _lon_a, _lon_b)

# Same cuts for the May fields.
precip_cut_may       = _cut_each(precip_may, _lat_a, _lat_b, _lon_a, _lon_b)

precip_cut_may_Neast = _cut_each(precip_may, _lat_a, idx_lat_mid, _lon_a, idx_lon_mid)
precip_cut_may_Nwest = _cut_each(precip_may, _lat_a, idx_lat_mid, idx_lon_mid, _lon_b)

precip_cut_may_Seast = _cut_each(precip_may, idx_lat_mid, _lat_b, _lon_a, idx_lon_mid)
precip_cut_may_Swest = _cut_each(precip_may, idx_lat_mid, _lat_b, idx_lon_mid, _lon_b)

sst_cut_may          = _cut_each(sst_may, _olat_a, _olat_b, _olon_a, _olon_b)

# The shared mask is sliced directly, upper bounds included (hence the + 1),
# then turned into a 1 / masked field for the whole domain and for each sector.
precip_2mm_msk_cut = precip_2mm_msk[_lat_a : _lat_b + 1, _lon_a : _lon_b + 1]
precip_2mm_msk_cut_masked = _flag_monsoon(precip_2mm_msk_cut)

precip_2mm_msk_cut_Neast = precip_2mm_msk[_lat_a : idx_lat_mid + 1, _lon_a : idx_lon_mid + 1]
precip_2mm_msk_cut_masked_Neast = _flag_monsoon(precip_2mm_msk_cut_Neast)

precip_2mm_msk_cut_Nwest = precip_2mm_msk[_lat_a : idx_lat_mid + 1, idx_lon_mid : _lon_b + 1]
precip_2mm_msk_cut_masked_Nwest = _flag_monsoon(precip_2mm_msk_cut_Nwest)

precip_2mm_msk_cut_Seast = precip_2mm_msk[idx_lat_mid : _lat_b + 1, _lon_a : idx_lon_mid + 1]
precip_2mm_msk_cut_masked_Seast = _flag_monsoon(precip_2mm_msk_cut_Seast)

precip_2mm_msk_cut_Swest = precip_2mm_msk[idx_lat_mid : _lat_b + 1, idx_lon_mid : _lon_b + 1]
precip_2mm_msk_cut_masked_Swest = _flag_monsoon(precip_2mm_msk_cut_Swest)

# Disabled alternative based on the per-simulation masks.
'''
monsoon_msk_cut    = new_tools.cut(monsoon_msk, idx[0][0],  idx[0][1],  idx[0][2],  idx[0][3])
monsoon_msk_cut    = [np.ma.masked_where(monsoon_msk_cut[i] != 2, precip_2mm_msk_cut) for i in range(len(monsoon_msk_cut))]

for i in range(len(monsoon_msk_cut)):

	monsoon_msk_cut[i][~monsoon_msk_cut[i].mask] = 1
'''
print('DATA CUT')
'----------------------------------------------------------'

#pdb.set_trace()

'----------------------------------------------------------'
'MSK APPLICATION'
'----------------------------------------------------------'
# Keep the 'land' part of each experiment, using that experiment's own cut mask.
precip_land      = [tools.mask_application(data, land_sea, 'land') for data, land_sea in zip(precip_cut, masks_cut)]
precip_land_may  = [tools.mask_application(data, land_sea, 'land') for data, land_sea in zip(precip_cut_may, masks_cut)]

# Restrict to the cells flagged in the shared mask.
# FIX: the product previously used the raw `precip_2mm_msk_cut`, i.e. the
# original mask codes, which scales the data by those codes instead of
# selecting cells. `precip_2mm_msk_cut_masked` (1 where the code is 2, masked
# elsewhere) is the whole-domain counterpart of the sector masks used below.
precip_land2     = [data * precip_2mm_msk_cut_masked for data in precip_land]
precip_land_may2 = [data * precip_2mm_msk_cut_masked for data in precip_land_may]

# Sector products. These start from the cut fields, not from the land-only ones.
precip_land2_Neast     = [data * precip_2mm_msk_cut_masked_Neast for data in precip_cut_Neast]
precip_land_may2_Neast = [data * precip_2mm_msk_cut_masked_Neast for data in precip_cut_may_Neast]

precip_land2_Seast     = [data * precip_2mm_msk_cut_masked_Seast for data in precip_cut_Seast]
precip_land_may2_Seast = [data * precip_2mm_msk_cut_masked_Seast for data in precip_cut_may_Seast]

precip_land2_Nwest     = [data * precip_2mm_msk_cut_masked_Nwest for data in precip_cut_Nwest]
precip_land_may2_Nwest = [data * precip_2mm_msk_cut_masked_Nwest for data in precip_cut_may_Nwest]

precip_land2_Swest     = [data * precip_2mm_msk_cut_masked_Swest for data in precip_cut_Swest]
precip_land_may2_Swest = [data * precip_2mm_msk_cut_masked_Swest for data in precip_cut_may_Swest]

# Disabled alternative based on the per-simulation masks.
'''
precip_land3     = [precip_land[i]     * monsoon_msk_cut[i] for i in range(len(precip_land))]
precip_land_may3 = [precip_land_may[i] * monsoon_msk_cut[i] for i in range(len(precip_land))]
'''
'----------------------------------------------------------'

'----------------------------------------------------------'
'MEAN'
'----------------------------------------------------------'
def _averages_x86400(per_experiment):
	"""Apply tools.averages with conv_1 = 86400 to every experiment entry."""
	# presumably 86400 converts a per-second rate to a per-day one (plots are labelled mm/day) - confirm in tools.averages
	return [tools.averages(entry, conv_1 = 86400) for entry in per_experiment]


def _five_month_mean(four_month, may):
	"""Combine, experiment by experiment, a 4-month mean (weight 4) with the May mean (weight 1)."""
	return [(np.asarray(main) * 4 + np.asarray(extra)) / 5 for main, extra in zip(four_month, may)]


sst_mean          = [tools.averages(entry) for entry in sst_cut]
precip_lmean      = _averages_x86400(precip_land)

precip_lmean_may  = _averages_x86400(precip_land_may)
sst_mean_may      = [tools.averages(entry) for entry in sst_cut_may]

precip_lmean_MJJS = _five_month_mean(precip_lmean, precip_lmean_may)
sst_mean_MJJS     = _five_month_mean(sst_mean, sst_mean_may)

precip_av2        = _averages_x86400(precip_land2)
precip_av_may2    = _averages_x86400(precip_land_may2)

precip_av2_Neast        = _averages_x86400(precip_land2_Neast)
precip_av_may2_Neast    = _averages_x86400(precip_land_may2_Neast)

precip_av2_Seast        = _averages_x86400(precip_land2_Seast)
precip_av_may2_Seast    = _averages_x86400(precip_land_may2_Seast)

precip_av2_Nwest        = _averages_x86400(precip_land2_Nwest)
precip_av_may2_Nwest    = _averages_x86400(precip_land_may2_Nwest)

precip_av2_Swest        = _averages_x86400(precip_land2_Swest)
precip_av_may2_Swest    = _averages_x86400(precip_land_may2_Swest)

# FIX: the second term used to be the 4-month series again (precip_av2*),
# which reduces the expression to the 4-month mean and ignores May entirely.
# The May series computed just above are now used, as for precip_lmean_MJJS.
precip_lmean_MJJS2 = _five_month_mean(precip_av2, precip_av_may2)

precip_lmean_MJJS2_Neast = _five_month_mean(precip_av2_Neast, precip_av_may2_Neast)
precip_lmean_MJJS2_Nwest = _five_month_mean(precip_av2_Nwest, precip_av_may2_Nwest)

precip_lmean_MJJS2_Seast = _five_month_mean(precip_av2_Seast, precip_av_may2_Seast)
precip_lmean_MJJS2_Swest = _five_month_mean(precip_av2_Swest, precip_av_may2_Swest)

# Disabled alternative based on the per-simulation masks.
# NOTE(review): if re-enabled, the last line has the same slip - its second
# term should be precip_av_may3.
'''
precip_av3        = [tools.averages(precip_land3[i],      conv_1 = 86400) for i in range(len(list_simu))]
precip_av_may3    = [tools.averages(precip_land_may3[i],  conv_1 = 86400) for i in range(len(list_simu))]

precip_lmean_MJJS3 = [(np.asarray(precip_av3[i])*4 + np.asarray(precip_av3[i]))/5 for i in range(len(list_simu))]
'''
'----------------------------------------------------------'
#pdb.set_trace()


# Quick check: overall mean of the first two experiments' sst_mean series.
print(np.asarray(sst_mean[0]).mean())
print(np.asarray(sst_mean[1]).mean())
#pdb.set_trace()

'----------------------------------------------------------'
'PLOT'
'----------------------------------------------------------'
# One label and one colour per experiment, in the order of list_simu.
# Both figures below share them.
labels  = ['All_forcings', 'Orb_Ghg', 'Orb_Ice', 'Orb']
colors  = ['black', 'green', 'purple' ,'blue'] 

# ---- Figure 1: one panel per sector, one curve per experiment ----
fig, ax = plt.subplots(nrows = 2, ncols = 2, figsize = (20, 10))

sector_panels = [
	(ax[0, 0], precip_lmean_MJJS2_Neast, 'Summer monsoon precipitation (MJJAS)- Northern East part'),
	(ax[0, 1], precip_lmean_MJJS2_Nwest, 'Summer monsoon precipitation (MJJAS) - Northern West part'),
	(ax[1, 0], precip_lmean_MJJS2_Seast, 'Summer monsoon precipitation (MJJAS)- Southern East part'),
	(ax[1, 1], precip_lmean_MJJS2_Swest, 'Summer monsoon precipitation (MJJAS) - Southern West part'),
]

for panel, series, title in sector_panels:

	for curve, color, label in zip(series, colors, labels):
		panel.plot(tools.list_time, curve, '-o', color = color, label = label)

	panel.set_ylabel('mm/day')
	panel.set_title(title)
	panel.grid()
	panel.legend()

fig.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/TS_allsimu_sector.png')


plt.show()
plt.close()

# Breakpoint disabled (same convention as the other commented-out ones in this
# file): left active, it stops the script before the second figure is produced.
#pdb.set_trace()

# ---- Figure 2: whole-domain precipitation (top) and SST (bottom) ----
fig, ax = plt.subplots(nrows = 2, ncols = 1, figsize = (13, 10))

for precip_curve, sst_curve, color, label in zip(precip_lmean_MJJS2, sst_mean_MJJS, colors, labels):

	ax[0].plot(tools.list_time, precip_curve, '-o', color = color, label = label)
	ax[1].plot(tools.list_time, sst_curve, '-o', color = color, label = label)

ax[0].set_ylabel('mm/day')
ax[0].set_title('Precipitation over land MJJAS \n only where >2mm and <50% of annual precip based on GPCC dataset')
ax[0].grid()
ax[0].legend()

# Raw string: '\c' is not a valid escape sequence in a normal string literal.
# The resulting text is unchanged.
ax[1].set_ylabel(r'$^\circ$C')
ax[1].set_title('SST MJJAS')
ax[1].grid()
ax[1].set_xlabel('Time (kyr)')

fig.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/TS_allsimu.pdf')

plt.show()
plt.close()
'----------------------------------------------------------'

# Trailing breakpoint disabled so the script can finish unattended.
#pdb.set_trace()




