from __future__ import division

import tools 
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt
import pandas as pd 
import pdb
import pycwt as wavelet
from pycwt.helpers import find
from scipy.stats.stats import pearsonr
import new_tools

'----------------------------------------------------------'
'FUNCTIONS'
'----------------------------------------------------------'
def get_clfiles(list_simu, root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``a.pdcljjs`` climatology file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is a sub-directory of *root* and the
		prefix of the file name.
	root : str, optional
		Directory that contains one sub-directory per simulation.

	Returns
	-------
	list of netCDF4.Dataset
		Open datasets ``<root>/<simu>/climate/<simu>a.pdcljjs.nc``, in the
		order of *list_simu*. The caller is responsible for closing them.

	Raises
	------
	Whatever ``nc.Dataset`` raises for a missing or unreadable file. The
	datasets opened before the failure are closed first so they do not leak.
	"""
	nc_files = []
	try:
		for simu in list_simu:
			nc_files.append(nc.Dataset(os.path.join(root, simu, 'climate', simu + 'a.pdcljjs.nc')))
	except Exception:
		# Release what was already opened before propagating the error.
		for dataset in nc_files:
			dataset.close()
		raise
	return nc_files


def get_clfiles_may(list_simu, root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``a.pdclmay`` (May) climatology file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is a sub-directory of *root* and the
		prefix of the file name.
	root : str, optional
		Directory that contains one sub-directory per simulation.

	Returns
	-------
	list of netCDF4.Dataset
		Open datasets ``<root>/<simu>/climate/<simu>a.pdclmay.nc``, in the
		order of *list_simu*. The caller is responsible for closing them.

	Raises
	------
	Whatever ``nc.Dataset`` raises for a missing or unreadable file. The
	datasets opened before the failure are closed first so they do not leak.
	"""
	nc_files = []
	try:
		for simu in list_simu:
			nc_files.append(nc.Dataset(os.path.join(root, simu, 'climate', simu + 'a.pdclmay.nc')))
	except Exception:
		# Release what was already opened before propagating the error.
		for dataset in nc_files:
			dataset.close()
		raise
	return nc_files


def get_clofiles(list_simu, root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``o.pfcljjs`` climatology file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is a sub-directory of *root* and the
		prefix of the file name.
	root : str, optional
		Directory that contains one sub-directory per simulation.

	Returns
	-------
	list of netCDF4.Dataset
		Open datasets ``<root>/<simu>/climate/<simu>o.pfcljjs.nc``, in the
		order of *list_simu*. The caller is responsible for closing them.

	Raises
	------
	Whatever ``nc.Dataset`` raises for a missing or unreadable file. The
	datasets opened before the failure are closed first so they do not leak.
	"""
	nc_files = []
	try:
		for simu in list_simu:
			nc_files.append(nc.Dataset(os.path.join(root, simu, 'climate', simu + 'o.pfcljjs.nc')))
	except Exception:
		# Release what was already opened before propagating the error.
		for dataset in nc_files:
			dataset.close()
		raise
	return nc_files
	
	
	
def get_clofiles_may(list_simu, root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``o.pfclmay`` (May) climatology file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is a sub-directory of *root* and the
		prefix of the file name.
	root : str, optional
		Directory that contains one sub-directory per simulation.

	Returns
	-------
	list of netCDF4.Dataset
		Open datasets ``<root>/<simu>/climate/<simu>o.pfclmay.nc``, in the
		order of *list_simu*. The caller is responsible for closing them.

	Raises
	------
	Whatever ``nc.Dataset`` raises for a missing or unreadable file. The
	datasets opened before the failure are closed first so they do not leak.
	"""
	nc_files = []
	try:
		for simu in list_simu:
			nc_files.append(nc.Dataset(os.path.join(root, simu, 'climate', simu + 'o.pfclmay.nc')))
	except Exception:
		# Release what was already opened before propagating the error.
		for dataset in nc_files:
			dataset.close()
		raise
	return nc_files


def get_mskfiles(list_simu, root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``qrparm.mask`` file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation names. Each name is a sub-directory of *root* and the
		prefix of the file name.
	root : str, optional
		Directory that contains one sub-directory per simulation.

	Returns
	-------
	list of netCDF4.Dataset
		Open datasets ``<root>/<simu>/inidata/<simu>.qrparm.mask.nc``, in the
		order of *list_simu*. The caller is responsible for closing them.

	Raises
	------
	Whatever ``nc.Dataset`` raises for a missing or unreadable file. The
	datasets opened before the failure are closed first so they do not leak.
	"""
	nc_files = []
	try:
		for simu in list_simu:
			nc_files.append(nc.Dataset(os.path.join(root, simu, 'inidata', simu + '.qrparm.mask.nc')))
	except Exception:
		# Release what was already opened before propagating the error.
		for dataset in nc_files:
			dataset.close()
		raise
	return nc_files
	

def get_var(nomvar, nc_files):
	"""Read variable *nomvar* from every dataset in *nc_files*.

	Returns a list holding ``dataset.variables[nomvar][:]`` for each dataset,
	in the same order as *nc_files*. The script closes the datasets right
	after calling this and keeps using the returned arrays.
	"""
	return [dataset.variables[nomvar][:] for dataset in nc_files]
'----------------------------------------------------------'


'----------------------------------------------------------'
'DOMAIN'
'----------------------------------------------------------'
# Latitude bounds are given north-first (lat_min = 50, lat_max = 0).
# NOTE(review): presumably this matches a north-to-south latitude axis
# expected by tools.select_domain -- confirm before renaming/reordering.
lat_min = 50
lat_max = 0
#lon_min, lon_max = 40, 130
lon_min = 60
lon_max = 130
'----------------------------------------------------------'


'----------------------------------------------------------'
'DATA RECOVERY'
'----------------------------------------------------------'
# NOTE(review): every file path opened or saved in this script is absolute,
# so this chdir does not affect them; kept in case other code relies on the
# working directory.
os.chdir(os.pardir) 


def _close_all(dataset_groups):
	"""Close every dataset in a list of per-experiment dataset lists.

	Walks each inner list itself, so experiments holding a different number
	of simulations than the first one are still closed completely.
	"""
	for group in dataset_groups:
		for dataset in group:
			dataset.close()


list_simu_14      = new_tools.list_simu_14
# Keep simulations from index 700 onwards (named "800k" here; TODO confirm
# the period this slice covers).
list_simu_14_800k = list_simu_14[700:]

# One entry per experiment; each entry is a list of simulation names.
list_simu    = [tools.list_simu_08, tools.list_simu_09, list_simu_14_800k, tools.list_simu_10]
#list_simu  = [tools.list_simu_04, tools.list_simu_05, tools.list_simu_06]
#list_simu  = [tools.list_simu_10]

# 'a.pdcljjs' files: surface precipitation and temperature, plus the grid.
nc_files     = [get_clfiles(simus) for simus in list_simu]
precip       = [get_var('precip_mm_srf', files) for files in nc_files]
tempe        = [get_var('temp_mm_srf', files) for files in nc_files]
lat, lon     = nc_files[0][0].variables['latitude'][:], nc_files[0][0].variables['longitude'][:]
_close_all(nc_files)

# 'a.pdclmay' files: the same surface fields for May.
nc_files_may = [get_clfiles_may(simus) for simus in list_simu]
precip_may   = [get_var('precip_mm_srf', files) for files in nc_files_may]
temp_may     = [get_var('temp_mm_srf', files) for files in nc_files_may]
_close_all(nc_files_may)

# 'o.pfcljjs' files: 'temp_mm_uo' (used as SST below), plus the ocean grid.
nc_ofiles    = [get_clofiles(simus) for simus in list_simu]
sst          = [get_var('temp_mm_uo', files) for files in nc_ofiles]
lato, lono   = nc_ofiles[0][0].variables['latitude'][:], nc_ofiles[0][0].variables['longitude'][:]
_close_all(nc_ofiles)

# 'o.pfclmay' files: the same ocean field for May.
nc_ofiles_may = [get_clofiles_may(simus) for simus in list_simu]
sst_may       = [get_var('temp_mm_uo', files) for files in nc_ofiles_may]
_close_all(nc_ofiles_may)

# 'lsm' mask of every simulation.
nc_masks      = [get_mskfiles(simus) for simus in list_simu]
msk           = [get_var('lsm', files) for files in nc_masks]
_close_all(nc_masks)


# CO2 record: 'time' column scaled by 1e-3 and 'ppmv' values, mapped onto
# tools.list_time by tools.reconstruct.
file_co2_800       = pd.read_csv('/user/work/xk22684/work/data/co2_bbc/regular_CO2.dat', delimiter = r'\s')
time800            = file_co2_800['time']*1e-3
co2_800            = file_co2_800['ppmv']

co2_rec            = tools.reconstruct(np.asarray(tools.list_time), time800, co2_800)

# Regridded monsoon mask; the CUT section treats value 2 as the monsoon flag.
nc_file_monsoon    = nc.Dataset('/user/work/xk22684/work/data/monsoon_msk_regridded.nc')
precip_2mm_msk     = nc_file_monsoon.variables['precip_monsoon'][:]
nc_file_monsoon.close()


'''
nc_monsoon_msk     = []

for simu in list_simu[0] :
	nc_monsoon_msk.append(nc.Dataset('/user/work/xk22684/work/data/msk_monsoon_bbc8/' + simu + '_monsoon_msk.nc'))

monsoon_msk        = []

for i in range(len(nc_monsoon_msk)): 

	monsoon_msk.append(nc_monsoon_msk[i].variables['precip_monsoon'][:])
'''

print('DATA RECOVERED')
'----------------------------------------------------------'
#pdb.set_trace()

'----------------------------------------------------------'
'CUT'
'----------------------------------------------------------'
# Domain index bounds on the atmosphere grid (lat/lon) and on the ocean grid
# (lato/lono). Each list holds one entry; its first four items are the
# bounds handed to tools.cut.
idx  = [tools.select_domain(lat, lon, lat_min, lat_max, lon_min, lon_max)]
# NOTE(review): the latitude bounds are passed as (lat_max, lat_min) here but
# as (lat_min, lat_max) for the atmosphere grid. Possibly deliberate if the
# ocean latitude axis is ordered the other way -- confirm against lato.
idxo = [tools.select_domain(lato, lono, lat_max, lat_min, lon_min, lon_max)]


def _crop(field, bounds):
	"""Crop *field* with tools.cut using the first four items of *bounds*."""
	return tools.cut(field, bounds[0], bounds[1], bounds[2], bounds[3])


precip_cut     = [_crop(field, idx[0])  for field in precip]
tempe_cut      = [_crop(field, idx[0])  for field in tempe]
sst_cut        = [_crop(field, idxo[0]) for field in sst]
masks_cut      = [_crop(field, idx[0])  for field in msk]


precip_cut_may = [_crop(field, idx[0])  for field in precip_may]
tempe_cut_may  = [_crop(field, idx[0])  for field in temp_may]
sst_cut_may    = [_crop(field, idxo[0]) for field in sst_may]


# Crop the 2-D monsoon mask with the same atmosphere-grid bounds (inclusive).
precip_2mm_msk_cut = precip_2mm_msk[idx[0][0]:idx[0][1] + 1, idx[0][2]:idx[0][3] + 1]
# Multiplicative version of that mask: every point whose value is not 2 is
# masked, and the remaining (flag == 2) points are set to 1.
precip_2mm_msk_cut_masked = np.ma.masked_where(precip_2mm_msk_cut != 2, precip_2mm_msk_cut)
precip_2mm_msk_cut_masked[~precip_2mm_msk_cut_masked.mask] = 1

'''
monsoon_msk_cut    = new_tools.cut(monsoon_msk, idx[0][0],  idx[0][1],  idx[0][2],  idx[0][3])
monsoon_msk_cut    = [np.ma.masked_where(monsoon_msk_cut[i] != 2, precip_2mm_msk_cut) for i in range(len(monsoon_msk_cut))]

for i in range(len(monsoon_msk_cut)):

	monsoon_msk_cut[i][~monsoon_msk_cut[i].mask] = 1
'''
print('DATA CUT')
'----------------------------------------------------------'


'----------------------------------------------------------'
'MSK APPLICATION'
'----------------------------------------------------------'
# Split each cropped field into its land and sea parts with the matching
# cropped 'lsm' mask of the same experiment.
precip_land     = [tools.mask_application(field, lsm, 'land') for field, lsm in zip(precip_cut, masks_cut)]
precip_sea      = [tools.mask_application(field, lsm, 'sea') for field, lsm in zip(precip_cut, masks_cut)]

tempe_land      = [tools.mask_application(field, lsm, 'land') for field, lsm in zip(tempe_cut, masks_cut)]
tempe_sea       = [tools.mask_application(field, lsm, 'sea') for field, lsm in zip(tempe_cut, masks_cut)]


precip_land_may = [tools.mask_application(field, lsm, 'land') for field, lsm in zip(precip_cut_may, masks_cut)]
precip_sea_may  = [tools.mask_application(field, lsm, 'sea') for field, lsm in zip(precip_cut_may, masks_cut)]

tempe_land_may  = [tools.mask_application(field, lsm, 'land') for field, lsm in zip(tempe_cut_may, masks_cut)]
tempe_sea_may   = [tools.mask_application(field, lsm, 'sea') for field, lsm in zip(tempe_cut_may, masks_cut)]

# Restrict land precipitation to the GPCC monsoon area. precip_2mm_msk_cut_masked
# (built in the CUT section) is 1 where the raw mask equals 2 and masked
# elsewhere, so multiplying by it selects the monsoon points without rescaling
# them. Multiplying by the raw precip_2mm_msk_cut would instead scale those
# points by the flag value 2 and keep the non-monsoon points.
precip_land2     = [field * precip_2mm_msk_cut_masked for field in precip_land]
precip_land_may2 = [field * precip_2mm_msk_cut_masked for field in precip_land_may]
'''
precip_land3     = [precip_land[i]     * monsoon_msk_cut[i] for i in range(len(precip_land))]
precip_land_may3 = [precip_land_may[i] * monsoon_msk_cut[i] for i in range(len(precip_land))]
'''

print('DATA MASKED')
'----------------------------------------------------------'


'----------------------------------------------------------'
'AVERAGE'
'----------------------------------------------------------'
def _mjjs(jjas, may):
	"""Combine a JJAS mean and a May mean into a May-September mean.

	JJAS counts for four months and May for one: ``(4 * jjas + may) / 5``.
	"""
	return (np.asarray(jjas) * 4 + np.asarray(may)) / 5


# NOTE(review): conv_1 = 86400 / conv_2 = -273.15 presumably convert
# precipitation to mm/day and temperature from K to degC -- confirm in
# tools.averages.
precip_lmean      = [tools.averages(field, conv_1 = 86400) for field in precip_land]
precip_omean      = [tools.averages(field, conv_1 = 86400) for field in precip_sea]
# NOTE(review): uses the uncropped, unmasked field (precip, not precip_cut),
# unlike the other averages here -- confirm a whole-grid mean is intended.
precip_mean       = [tools.averages(field, conv_1 = 86400) for field in precip]

#tempe_mean   = [tools.averages(tempe_land[i],  conv_2 = -273.15) for i in range(len(list_simu))]

tempe_mean        = [tools.averages(field, conv_2 = -273.15) for field in tempe_land]
sst_mean          = [tools.averages(field) for field in sst_cut]


precip_lmean_may  = [tools.averages(field, conv_1 = 86400) for field in precip_land_may]
precip_omean_may  = [tools.averages(field, conv_1 = 86400) for field in precip_sea_may]
# NOTE(review): uncropped, unmasked May field -- see precip_mean above.
precip_mean_may   = [tools.averages(field, conv_1 = 86400) for field in precip_may]

tempe_mean_may    = [tools.averages(field, conv_2 = -273.15) for field in tempe_land_may]
sst_mean_may      = [tools.averages(field) for field in sst_cut_may]


precip_lmean_MJJS = [_mjjs(jjas, may) for jjas, may in zip(precip_lmean, precip_lmean_may)]
precip_omean_MJJS = [_mjjs(jjas, may) for jjas, may in zip(precip_omean, precip_omean_may)]
precip_mean_MJJS  = [_mjjs(jjas, may) for jjas, may in zip(precip_mean, precip_mean_may)]

tempe_mean_MJJS   = [_mjjs(jjas, may) for jjas, may in zip(tempe_mean, tempe_mean_may)]
sst_mean_MJJS     = [_mjjs(jjas, may) for jjas, may in zip(sst_mean, sst_mean_may)]


# Monsoon-area land precipitation (see precip_land2 / precip_land_may2).
precip_av2        = [tools.averages(field, conv_1 = 86400) for field in precip_land2]
precip_av_may2    = [tools.averages(field, conv_1 = 86400) for field in precip_land_may2]

# The May term comes from the May monsoon-area average precip_av_may2.
precip_lmean_MJJS2 = [_mjjs(jjas, may) for jjas, may in zip(precip_av2, precip_av_may2)]
'''
precip_av3        = [tools.averages(precip_land3[i],      conv_1 = 86400) for i in range(len(list_simu))]
precip_av_may3    = [tools.averages(precip_land_may3[i],  conv_1 = 86400) for i in range(len(list_simu))]

precip_lmean_MJJS3 = [(np.asarray(precip_av3[i])*4 + np.asarray(precip_av_may3[i]))/5 for i in range(len(list_simu))]
'''

# Land surface temperature minus SST (JJAS means).
deltaT            = [np.asarray(lst) - np.asarray(sea) for lst, sea in zip(tempe_mean, sst_mean)]

print('DATA AVERAGED')
'----------------------------------------------------------'
#pdb.set_trace()

'----------------------------------------------------------'
'corr'
'----------------------------------------------------------'
# Correlate the reconstructed CO2 record with the MJJS means of the first
# experiment (list_simu[0]). scipy's pearsonr returns the Pearson
# correlation coefficient and its two-sided p-value, not a t-score.
corr_targets = [
	('co2 sst MJJS', sst_mean_MJJS[0]),
	('co2 lst MJJS', tempe_mean_MJJS[0]),
	('co2 Land Precip MJJS', precip_lmean_MJJS[0]),
	('co2 Sea Precip MJJS', precip_omean_MJJS[0]),
]
for k, (label, series) in enumerate(corr_targets):
	if k > 0:
		print('')
	print('Pearson r, p-value ' + label + '= ', pearsonr(co2_rec, series))
'----------------------------------------------------------'
#pdb.set_trace()


'----------------------------------------------------------'
'SPECTRAL ANALYSIS - FIG'
'----------------------------------------------------------'
#list_var       = [sst_mean, tempe_mean, precip_omean, precip_lmean, precip_mean]

# One figure column per variable, one row per experiment in list_simu.
list_var       = [sst_mean_MJJS, tempe_mean_MJJS, precip_lmean_MJJS2, precip_omean_MJJS]

columns_titles = ['SST_MJJAS', 'LST_MJJAS', 'Land Precip_MJJAS \n mousson area from GPCC', 'Sea Precip_MJJAS']


#columns_titles = ['Precip over land (mm/day)', 'SST (degC)' ,'LST (degC)', '$\Delta$(LST-SST) (degC)']
rows_titles    = ['All Forcings \n \n Period (kyr)', 'Orb_GhG \n \n Period (kyr)', 'Orb_Ice \n \n Period (kyr)','Orb_Only \n \n Period (kyr)']

# squeeze=False keeps ax/bx two-dimensional even with a single experiment or
# a single variable, so ax[i, j] indexing always works.
fig, ax        = plt.subplots(nrows = len(list_simu), ncols = len(list_var), figsize = (20, 20), squeeze = False)

fig2, bx       = plt.subplots(nrows = len(list_simu), ncols = len(list_var), figsize = (20, 20), squeeze = False)

t              = np.asarray(tools.list_time)
# NOTE(review): assumes tools.list_time is evenly spaced and ordered so that
# t[0] - t[1] is the (positive) sampling step.
dt             = t[0] - t[1]

# Wavelet parameters, identical for every panel.
mother         = wavelet.Morlet(6)
s0             = 2 * dt                  # starting scale
dj             = 1 / 12                  # twelve sub-octaves per octave
J              = 7 / dj                  # seven powers of two with dj sub-octaves
levels         = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8, 16]


for i in range(len(list_simu)):

	for j in range(len(list_var)):

		dat         = np.asarray(list_var[j][i])
		N           = dat.size
		p           = np.polyfit(t, dat, 1)       # linear trend
		dat_notrend = dat - np.polyval(p, t)      # detrended series
		std         = dat_notrend.std()
		var         = std ** 2
		dat_norm    = dat_notrend / std           # normalised series
		alpha, _, _ = wavelet.ar1(dat)            # lag-1 autocorrelation for red noise

		wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(dat_norm, dt, dj, s0, J, mother)

		power       = np.abs(wave) ** 2
		fft_power   = np.abs(fft) ** 2
		period      = 1 / freqs

		# 95 % significance of the local power; the input was normalised, so
		# the variance passed here is 1.
		signif, fft_theor = wavelet.significance(1.0, dt, scales, 0, alpha, significance_level=0.95, wavelet=mother)
		sig95       = power / (np.ones([1, N]) * signif[:, None])
		glbl_power  = power.mean(axis=1)
		dof         = N - scales                  # correction for padding at edges

		glbl_signif, _ = wavelet.significance(var, dt, scales, 1, alpha, significance_level=0.95, dof = dof, wavelet=mother)

		# --- Wavelet power panel -------------------------------------------
		ax[i, j].contourf(t, np.log2(period), np.log2(power), np.log2(levels), extend='both', cmap=plt.cm.viridis)
		# Black contour: power equal to its 95 % significance level.
		ax[i, j].contour(t, np.log2(period), sig95, [-99, 1], colors='k', linewidths=2)

		# Shade the cone of influence, clipped to the plotted period range.
		bottom = np.log2(period.max())
		tutu   = np.concatenate([t, t[::-1]])
		toto   = np.concatenate([np.log2(np.clip(coi, period.min(), period.max())), bottom * np.ones_like(t)])
		ax[i, j].fill(tutu, toto, 'k', alpha = 0.3, hatch = 'x')

		Yticks = 2 ** np.arange(np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))
		ax[i, j].set_yticks(np.log2(Yticks))
		ax[i, j].set_yticklabels(Yticks)

		if j == 0:
			ax[i, j].set_ylabel(rows_titles[i], size = 'large')

		if i == len(list_simu) - 1:
			ax[i, j].set_xlabel('Time (kyrs)')

		# --- Global wavelet spectrum panel ---------------------------------
		bx[i, j].plot(glbl_signif, np.log2(period), 'k--')
		bx[i, j].plot(var * fft_theor, np.log2(period), '--', color='#cccccc')
		bx[i, j].plot(var * fft_power, np.log2(1./fftfreqs), '-', color='#cccccc', linewidth=1.)
		bx[i, j].plot(var * glbl_power, np.log2(period), 'k-', linewidth=1.5)
		#bx[i, j].set_xlabel(r'Power [({})^2]'.format(units))

		# Power axis limits per column: [0, 1] beyond the third column,
		# [0, 3] for the third, [0, 9] otherwise.
		if j > 2:
			bx[i, j].set_xlim([0, 1])
		elif j == 2:
			bx[i, j].set_xlim([0, 3])
		else:
			bx[i, j].set_xlim([0, 9])

		bx[i, j].set_ylim(np.log2([period.min(), period.max()]))
		bx[i, j].set_yticks(np.log2(Yticks))
		bx[i, j].set_yticklabels(Yticks)
		bx[i, j].grid(alpha = 0.7)

		if j == 0:
			bx[i, j].set_ylabel(rows_titles[i], size = 'large')

		if i == len(list_simu) - 1:
			bx[i, j].set_xlabel('Power')


# Column titles on the top row of each figure (loop names chosen so the
# ax/bx arrays are not overwritten).
for axis, col in zip(ax[0], columns_titles):
	axis.set_title(col, size = 'large')

for axis, col in zip(bx[0], columns_titles):
	axis.set_title(col, size = 'large')


#fig.subplots_adjust(hspace=0)
#fig.tight_layout()

fig.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/wave_spectrum_rev.png')
fig2.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/glob_spectrum_rev.pdf')

plt.show()
plt.close()




'----------------------------------------------------------'
# Drops into the debugger at the end of the run for interactive inspection.
pdb.set_trace()



