from __future__ import division

import tools 
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt
import pandas as pd 
import pdb
import pycwt as wavelet
from pycwt.helpers import find
import new_tools
import xarray as xr

#import pyleoclim as pyleo
#pyleo.set_style('web')

'----------------------------------------------------------'
'FUNCTIONS'
'----------------------------------------------------------'

def get_clfiles(list_simu, season, data_dir = '/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the atmospheric climatology file of one season for every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers; each one names a sub-directory of `data_dir`
		and is also the prefix of the file name.
	season : str
		Season/month tag inserted in the file name (e.g. 'jun', 'jul', 'aug').
	data_dir : str, optional
		Root directory holding the simulations (defaults to the path that was
		previously hard-coded).

	Returns
	-------
	list of netCDF4.Dataset
		One open dataset per simulation, in the order of `list_simu`. The
		caller is responsible for closing them.

	If a file cannot be opened, the datasets opened so far are closed before
	the exception is re-raised.
	"""
	nc_files = []

	try:
		for simu in list_simu:
			path = os.path.join(data_dir, simu, 'climate', simu + 'a.pdcl' + season + '.nc')
			nc_files.append(nc.Dataset(path))
	except Exception:
		# Do not leak the handles opened before the failing one.
		for f in nc_files:
			f.close()
		raise

	return nc_files


def get_clofiles(list_simu, data_dir = '/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ocean climatology file (`<simu>o.pfcljjs.nc`) of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers; each one names a sub-directory of `data_dir`
		and is also the prefix of the file name.
	data_dir : str, optional
		Root directory holding the simulations (defaults to the path that was
		previously hard-coded).

	Returns
	-------
	list of netCDF4.Dataset
		One open dataset per simulation, in the order of `list_simu`. The
		caller is responsible for closing them.

	If a file cannot be opened, the datasets opened so far are closed before
	the exception is re-raised.
	"""
	nc_files = []

	try:
		for simu in list_simu:
			path = os.path.join(data_dir, simu, 'climate', simu + 'o.pfcljjs.nc')
			nc_files.append(nc.Dataset(path))
	except Exception:
		# Do not leak the handles opened before the failing one.
		for f in nc_files:
			f.close()
		raise

	return nc_files


def get_mskfiles(list_simu, data_dir = '/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the land-sea mask file (`inidata/<simu>.qrparm.mask.nc`) of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers; each one names a sub-directory of `data_dir`
		and is also the prefix of the file name.
	data_dir : str, optional
		Root directory holding the simulations (defaults to the path that was
		previously hard-coded).

	Returns
	-------
	list of netCDF4.Dataset
		One open dataset per simulation, in the order of `list_simu`. The
		caller is responsible for closing them.

	If a file cannot be opened, the datasets opened so far are closed before
	the exception is re-raised.
	"""
	nc_files = []

	try:
		for simu in list_simu:
			path = os.path.join(data_dir, simu, 'inidata', simu + '.qrparm.mask.nc')
			nc_files.append(nc.Dataset(path))
	except Exception:
		# Do not leak the handles opened before the failing one.
		for f in nc_files:
			f.close()
		raise

	return nc_files
	

def get_var(nomvar, nc_files):
	"""Read the full content of variable `nomvar` from each open dataset.

	Returns a list holding one entry per element of `nc_files`, in the same
	order.
	"""
	return [nc_file.variables[nomvar][:] for nc_file in nc_files]

# ----------------------------------------------------------


# ----------------------------------------------------------
# DOMAIN
# ----------------------------------------------------------
lat_min = np.arange(15, 40, 5)    # lower latitude of each 5-degree band: 15, 20, ..., 35 (N)
lat_max = lat_min + 5             # matching upper latitude: 20, 25, ..., 40 (N)
lon_min, lon_max = 105, 120       # longitude range (E) shared by every band
# ----------------------------------------------------------


# ----------------------------------------------------------
# DATA RECOVERY
# ----------------------------------------------------------
os.chdir(os.pardir)

# NOTE: older loading code, disabled (kept in a string literal, never executed).
'''
#list_simu = [tools.list_simu_04, tools.list_simu_05, tools.list_simu_06, tools.list_simu_03]
#list_simu = [tools.list_simu_06, tools.list_simu_05, tools.list_simu_04, tools.list_simu_03]
#list_simu = [tools.list_simu_04, tools.list_simu_05, tools.list_simu_06, tools.list_simu_03, tools.list_triff5]

#list_simu = [tools.list_simu_08, tools.list_simu_09, tools.list_simu_10, tools.list_simu_03]
list_simu = [tools.list_simu_10, tools.list_simu_09, tools.list_simu_08]
list_time = tools.list_time

nc_files  = [get_clfiles(list_simu[i])  for i in range(len(list_simu))]  
precip    = [get_var('precip_mm_srf', nc_files[i]) for i in range(len(list_simu))]

lat, lon  = nc_files[0][0].variables['latitude'][:], nc_files[0][0].variables['longitude'][:]
'''
list_simu_14      = new_tools.list_simu_14
list_simu_14_800k = list_simu_14[700:]


list_simu10 = tools.list_simu_10
list_simu09 = tools.list_simu_09
list_simu08 = tools.list_simu_08

list_time   = tools.list_time


def _read_season(list_simu, season, nomvar = 'precip_mm_srf'):
	"""Read `nomvar` from the `season` climatology file of every simulation.

	Returns (values, (lat, lon)): `values` holds one entry per simulation and
	(lat, lon) is the grid of the first file ((None, None) if `list_simu` is
	empty). The files are closed even if reading fails.
	"""
	nc_files = get_clfiles(list_simu, season)
	try:
		values = get_var(nomvar, nc_files)
		if nc_files:
			grid = (nc_files[0].variables['latitude'][:], nc_files[0].variables['longitude'][:])
		else:
			grid = (None, None)
	finally:
		new_tools.close_files(nc_files)
	return values, grid


def _jja_mean(list_simu, nomvar = 'precip_mm_srf'):
	"""Return the June-July-August mean of `nomvar` for each simulation, plus the grid.

	Months are read in the order jun, jul, aug. The result is indexed over
	`list_time` (one entry per time slice), so `list_simu` must contain at
	least len(list_time) simulations; any extra simulation is ignored.
	"""
	jun, grid = _read_season(list_simu, 'jun', nomvar)
	jul, _    = _read_season(list_simu, 'jul', nomvar)
	aug, _    = _read_season(list_simu, 'aug', nomvar)

	jja = [(np.asarray(jun[i]) + np.asarray(jul[i]) + np.asarray(aug[i])) / 3 for i in range(len(list_time))]

	return jja, grid


# Read in the same order as before: 10, 09, 08, then 14; the grid comes from set 10.
precip10, (lat, lon) = _jja_mean(list_simu10)
precip09, _          = _jja_mean(list_simu09)
precip08, _          = _jja_mean(list_simu08)
precip14, _          = _jja_mean(list_simu_14_800k)

# Simulation sets in plotting order.
precip      = [precip10, precip09, precip14, precip08]

print('DATA RECOVERED')
# ----------------------------------------------------------


# ----------------------------------------------------------
# CUT
# ----------------------------------------------------------
# One index box per latitude band; every band shares the same longitude range.
idx = [tools.select_domain(lat, lon, upper, lower, lon_min, lon_max)
       for lower, upper in zip(lat_min, lat_max)]

# precip_cut[i][j] is simulation set i restricted to latitude band j.
precip_cut = [
	[tools.cut(series, box[0], box[1], box[2], box[3]) for box in idx]
	for series in precip
]


print('DATA CUT')
# ----------------------------------------------------------


# ----------------------------------------------------------
# AVERAGES
# ----------------------------------------------------------
# precip_mean[i][j]: result of tools.averages for simulation set i, band j.
# conv_1 = 86400 is the number of seconds per day; presumably it converts a
# per-second precipitation rate to mm/day (the figures are labelled mm/day)
# -- TODO confirm against tools.averages.
precip_mean = [
	[tools.averages(band, conv_1 = 86400) for band in bands]
	for bands in precip_cut
]


print('DATA AVERAGED')
# ----------------------------------------------------------

# ----------------------------------------------------------
# CUT AT 140KY
# ----------------------------------------------------------
idx_140      = tools.find_nearest(np.asarray(tools.list_time), 140)

list_time140 = tools.list_time[idx_140 :]

# The first four simulation sets are truncated from idx_140 onwards; any
# further set (none at present) is kept whole.
precip_mean140 = [
	[precip_mean[i][j][idx_140:] if i < 4 else precip_mean[i][j]
	 for j in range(len(precip_mean[0]))]
	for i in range(len(precip))
]


# ----------------------------------------------------------

'----------------------------------------------------------'
'WAVELETS'
'----------------------------------------------------------'
# NOTE: this section is disabled -- it is kept inside a string literal and never
# runs. It relies on pyleoclim (`pyleo`), whose import is commented out at the
# top of the file. It also defines a function named `wavelet` that would shadow
# the pycwt module imported as `wavelet` if it were re-enabled.
'''
ts_precip_mean = []

t          = np.asarray(list_time)
dt         = np.abs(t[1] - t[0])   # timestep (make sure it's positive)
mother     = "morlet"
wavenumber = 6     # ?0=6 in Torrence & Compo

s0         = 2 * dt        # smallest scale
dj         = 1/12          # 12 sub-octaves per octave
J          = int(7/dj)     # ~7 powers of two



for i in range(len(precip_mean)):

	ts_precip_lat = []

	for j in range(len(precip_mean[0])):
	
		ts = pyleo.Series(time = np.asarray(list_time), value = precip_mean[i][j], time_name = 'kyrs', value_name = 'Precip')
		ts_precip_lat.append(ts)
	
	ts_precip_mean.append(ts_precip_lat)	
	
	
def wavelet(ts):
	
	ts_std   = ts.detrend().interp().standardize()
	scal     = ts_std.wavelet (method = 'cwt')
	scal_sig = scal.signif_test(method = 'ar1asym')

	return scal_sig
	
	
def psd_sig(ts):

	ts_std   = ts.detrend().interp().standardize()	
	psd_sig  = ts_std.spectral(method = 'mtm').signif_test()
	
	return psd_sig
	
	

wavelets_precip = []
psd_precip      = []

for i in range(len(precip_mean)):
	
	wave_lat = []
	psd_lat  = []
	
	for j in range(len(precip_mean[0])):
	
		wave_lat.append(wavelet(ts_precip_mean[i][j]))
		psd_lat .append(psd_sig(ts_precip_mean[i][j]))
		
	wavelets_precip.append(wave_lat)
	psd_precip     .append(psd_lat )

		
#pdb.set_trace()
'''
'----------------------------------------------------------'


'----------------------------------------------------------'
'TS AND SPRECTRAL ANALYSIS'
'----------------------------------------------------------'
# NOTE: disabled (kept inside a string literal, never executed). This is an
# earlier 3-column version of the wavelet figure. It plots the pyleoclim
# `wavelets_precip` objects built in the disabled WAVELETS section, and it uses
# `list_simu`, which is only defined in the disabled loading code.
'''
#columns_titles = ['All', 'OrbGhG', 'Orb', 'All_oldice']
columns_titles = ['Orb \n \n Precip (mm/day)', 'OrbGhG \n \n Precip (mm/day)', 'All \n \n Precip (mm/day)']

lines_titles   = ['15-20N \n \n Period (kyr)', '20-25N \n \n Period (kyr)', '25-30N \n \n Period (kyr)', '30-35N \n \n Period (kyr)', '35-40N \n \n Period (kyr)']



fig, ax        = plt.subplots(nrows = 5, ncols = 3, figsize = (20, 15))

for i in range(len(precip_mean)):

	for j in range(len(precip_mean[0])):
	
		wavelets_precip[i][j].plot(ax = ax[j, i], xlabel = ' ', ylabel = ' ', title = ' ', contourf_style={'cmap':'viridis'})


		if j == len(list_simu) - 1:

			ax[i, j].set_xlabel('Time (kyrs)')
		

		if i == 0:
			
			ax[j, i].set_ylabel(lines_titles[j], size = 'large')
			
			
for ax, col in zip(ax[0], columns_titles):
	
	ax.set_title(col, size = 'large')
			

fig.tight_layout()


plt.show()
plt.close()


pdb.set_trace()
'''

columns_titles = ['Orb \n \n Precip (mm/day)', 'Orb_GhG \n \n Precip (mm/day)', 'Orb_Ice  \n \n Precip (mm/day)','All \n \n Precip (mm/day)']

lines_titles   = ['15-20N \n \n Period (kyr)', '20-25N \n \n Period (kyr)', '25-30N \n \n Period (kyr)', '30-35N \n \n Period (kyr)', '35-40N \n \n Period (kyr)']


# Wavelet settings shared by every panel (identical for all series, so set once).
t      = np.asarray(tools.list_time)
dt     = np.abs(t[0] - t[1])      # pycwt needs a positive time step, whatever the ordering of list_time
mother = wavelet.Morlet(6)        # Morlet mother wavelet with omega0 = 6
s0     = 2 * dt                   # starting scale
dj     = 1 / 12                   # twelve sub-octaves per octave
J      = int(7 / dj)              # seven powers of two with dj sub-octaves each
levels = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8, 16]   # power contour levels (powers of two)


def _wavelet_power(dat):
	"""Continuous Morlet wavelet transform of one series sampled on `t`.

	The series is linearly detrended and divided by its standard deviation
	before the transform. The red-noise background is an AR(1) process fitted
	to the raw series.

	Returns (power, period, coi, sig95): wavelet power, the period of each scale,
	the cone of influence, and power divided by the 95 % significance level
	(values > 1 are significant).
	"""
	dat         = np.asarray(dat)
	trend       = np.polyfit(t, dat, 1)
	dat_notrend = dat - np.polyval(trend, t)
	dat_norm    = dat_notrend / dat_notrend.std()
	alpha, _, _ = wavelet.ar1(dat)            # lag-1 autocorrelation for the red-noise test

	wave, scales, freqs, coi, _, _ = wavelet.cwt(dat_norm, dt, dj, s0, J, mother)
	power  = np.abs(wave) ** 2
	period = 1 / freqs

	# Variance 1.0 because dat_norm has unit variance.
	signif, _ = wavelet.significance(1.0, dt, scales, 0, alpha, significance_level = 0.95, wavelet = mother)
	sig95     = power / (np.ones([1, dat.size]) * signif[:, None])

	return power, period, coi, sig95


n_cols = len(precip_mean)       # one column per simulation set
n_rows = len(precip_mean[0])    # one row per latitude band

fig, ax = plt.subplots(nrows = n_rows, ncols = n_cols, figsize = (20, 15), squeeze = False)

for i in range(n_cols):

	for j in range(n_rows):

		power, period, coi, sig95 = _wavelet_power(precip_mean[i][j])
		panel = ax[j, i]

		panel.contourf(t, np.log2(period), np.log2(power), np.log2(levels), extend = 'both', cmap = plt.cm.viridis)
		# Black contours enclose the regions significant at 95 % against red noise.
		panel.contour(t, np.log2(period), sig95, [-99, 1], colors = 'k', linewidths = 2)

		# Hatch the cone of influence, from the coi curve down to the longest period.
		bottom = np.log2(period.max())
		tutu   = np.concatenate([t, t[::-1]])
		toto   = np.concatenate([np.log2(np.clip(coi, period.min(), period.max())), bottom * np.ones_like(t)])
		panel.fill(tutu, toto, color = 'k', alpha = 0.3, hatch = 'x')

		Yticks = 2 ** np.arange(np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))
		panel.set_yticks(np.log2(Yticks))
		panel.set_yticklabels(Yticks)

		# Only the bottom row of panels carries the time label.
		if j == n_rows - 1:
			panel.set_xlabel('Time (kyrs)')

		if i == 0:
			panel.set_ylabel(lines_titles[j], size = 'large')


# Use a separate name so the `ax` array is not overwritten by the loop variable.
for col_ax, col in zip(ax[0], columns_titles):
	col_ax.set_title(col, size = 'large')

fig.savefig('/user/home/xk22684/work/fig/fig_paper_monsoon/lyu_spectre_rev.png')

plt.show()
plt.close()

pdb.set_trace()
# NOTE: disabled (kept inside a string literal, never executed). This is an
# earlier version of the time-series figure that follows. It plots precip_mean
# against the full tools.list_time axis, without the 140 kyr cut.
'''
legend = ['All', 'OrbGhG', 'Orb', 'All_oldice']
colors = ['black', 'green', 'blue', 'red']
dom    = ['15-20N', '20-25N', '25-30N', '30-35N', '35-40N']


fig, ax = plt.subplots(5, 1, figsize = (20, 15))

for j in range(len(precip_mean[0])):

	for i in range((len(precip_mean))):

		ax[j].plot(tools.list_time, precip_mean[i][j], 'o-', color = colors[i], label = legend[i])
		ax[j].grid(True)
		ax[j].legend()
		ax[j].set_xlabel('Kyr')
		ax[j].set_ylabel(dom[j])
		ax[j].set_title('JJS precipitation (mm/day)')


plt.show()
plt.close()
'''

# Legend labels follow the order of `precip` (precip10, precip09, precip14, precip08),
# matching the column titles of the wavelet figure. The fifth entry is only used
# if an extra (triffid) simulation set is appended to `precip`.
legend = ['Orb', 'Orb_GhG', 'Orb_Ice', 'All', 'triffid']
colors = ['black', 'green', 'blue', 'red', 'purple']
dom    = ['15-20N', '20-25N', '25-30N', '30-35N', '35-40N']


fig, ax = plt.subplots(5, 1, figsize = (20, 15))

for j in range(len(precip_mean[0])):

	for i in range(len(precip_mean)):

		# The first four sets were cut at 140 kyr and share list_time140; a fifth
		# set keeps its own time axis (tools.list_timetriff).
		time_axis = list_time140 if i < 4 else tools.list_timetriff
		ax[j].plot(time_axis, precip_mean140[i][j], 'o-', color = colors[i], label = legend[i])

	ax[j].grid(True)
	ax[j].set_ylabel(dom[j])
	ax[j].set_xlabel('Kyr')

	if j == 0:
		ax[j].set_title('JJA precipitation (mm/day)')
		ax[j].legend()


plt.show()
plt.close()


# ----------------------------------------------------------

pdb.set_trace()
	
