#from __future__ import division

import tools 
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt
import pandas as pd 
import pdb
#import pycwt as wavelet
#from pycwt.helpers import find
from scipy.stats.stats import pearsonr
import new_tools 
#import pyleoclim as pyleo
#pyleo.set_style('web')
'----------------------------------------------------------'
'FUNCTIONS'
'----------------------------------------------------------'
def get_clfiles(list_simu, data_root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``<simu>a.pdcljjs.nc`` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers. Each one is used both as the name of the
		sub-directory under ``data_root`` and as the prefix of the file name.
	data_root : str, optional
		Directory that holds one sub-directory per simulation. The default
		is the location that was previously hard-coded in this function.

	Returns
	-------
	list
		One open ``netCDF4.Dataset`` per simulation, in the order of
		``list_simu``. The datasets are not closed here; the caller owns them.
	"""
	return [
		nc.Dataset(os.path.join(data_root, simu, 'climate', simu + 'a.pdcljjs.nc'))
		for simu in list_simu
	]


def get_clfiles_may(list_simu, data_root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``<simu>a.pdclmay.nc`` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers. Each one is used both as the name of the
		sub-directory under ``data_root`` and as the prefix of the file name.
	data_root : str, optional
		Directory that holds one sub-directory per simulation. The default
		is the location that was previously hard-coded in this function.

	Returns
	-------
	list
		One open ``netCDF4.Dataset`` per simulation, in the order of
		``list_simu``. The datasets are not closed here; the caller owns them.
	"""
	return [
		nc.Dataset(os.path.join(data_root, simu, 'climate', simu + 'a.pdclmay.nc'))
		for simu in list_simu
	]


def get_clofiles(list_simu, data_root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``<simu>o.pfcljjs.nc`` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers. Each one is used both as the name of the
		sub-directory under ``data_root`` and as the prefix of the file name.
	data_root : str, optional
		Directory that holds one sub-directory per simulation. The default
		is the location that was previously hard-coded in this function.

	Returns
	-------
	list
		One open ``netCDF4.Dataset`` per simulation, in the order of
		``list_simu``. The datasets are not closed here; the caller owns them.
	"""
	return [
		nc.Dataset(os.path.join(data_root, simu, 'climate', simu + 'o.pfcljjs.nc'))
		for simu in list_simu
	]
	
	
	
def get_clofiles_may(list_simu, data_root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``<simu>o.pfclmay.nc`` climate file of every simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers. Each one is used both as the name of the
		sub-directory under ``data_root`` and as the prefix of the file name.
	data_root : str, optional
		Directory that holds one sub-directory per simulation. The default
		is the location that was previously hard-coded in this function.

	Returns
	-------
	list
		One open ``netCDF4.Dataset`` per simulation, in the order of
		``list_simu``. The datasets are not closed here; the caller owns them.
	"""
	return [
		nc.Dataset(os.path.join(data_root, simu, 'climate', simu + 'o.pfclmay.nc'))
		for simu in list_simu
	]


def get_mskfiles(list_simu, data_root='/user/home/ggpjv/swsvalde/ummodel/data_olig/'):
	"""Open the ``<simu>.qrparm.mask.nc`` file of every simulation.

	Unlike the climate-file loaders, the file is looked up in the
	``inidata`` sub-directory of each simulation.

	Parameters
	----------
	list_simu : iterable of str
		Simulation identifiers. Each one is used both as the name of the
		sub-directory under ``data_root`` and as the prefix of the file name.
	data_root : str, optional
		Directory that holds one sub-directory per simulation. The default
		is the location that was previously hard-coded in this function.

	Returns
	-------
	list
		One open ``netCDF4.Dataset`` per simulation, in the order of
		``list_simu``. The datasets are not closed here; the caller owns them.
	"""
	return [
		nc.Dataset(os.path.join(data_root, simu, 'inidata', simu + '.qrparm.mask.nc'))
		for simu in list_simu
	]
	

def get_var(nomvar, nc_files):
	"""Read one variable, in full, from each of several datasets.

	Parameters
	----------
	nomvar : str
		Key looked up in the ``variables`` mapping of every dataset.
	nc_files : iterable
		Objects exposing a ``variables`` mapping (e.g. the lists returned by
		the ``get_*files`` loaders).

	Returns
	-------
	list
		``dataset.variables[nomvar][:]`` for every dataset, in input order.
	"""
	return [dataset.variables[nomvar][:] for dataset in nc_files]
'----------------------------------------------------------'


'----------------------------------------------------------'
'DOMAIN'
'----------------------------------------------------------'
# Bounds of the analysis window, later turned into grid indices in the CUT
# section (presumably degrees of latitude / longitude -- TODO confirm).
# Note that lat_min holds the LARGER latitude: the CUT section slices
# idx_lat_max : idx_lat_min + 1, so the pair is used consistently even though
# the names read inverted.
lat_min = 50
lat_max = 0
lon_min = 60
lon_max = 130
'----------------------------------------------------------'


'----------------------------------------------------------'
'DATA RECOVERY'
'----------------------------------------------------------'
# Step one directory up from wherever the script was launched. All data paths
# in this section are absolute, so the reads below do not depend on it.
os.chdir(os.pardir)

path_su2025     = '/user/work/xk22684/work/data/su_2025/'

# Each file is opened, read completely into memory ([:]) and closed again
# before the next one is touched.

# --- precipitation ('PRECT'), one file per experiment ---
nc_precALL      = nc.Dataset(path_su2025 + 'ALL.PR.monthly.nc')
precALL         = nc_precALL.variables['PRECT'][:]
# Coordinates are taken from the ALL file only and reused for every other
# field (assumes all files share the same grid and time axis -- TODO confirm).
lat_su          = nc_precALL.variables['lat'  ][:]
lon_su          = nc_precALL.variables['lon'  ][:]
time_su         = nc_precALL.variables['time' ][:]
nc_precALL.close()

nc_precCO2      = nc.Dataset(path_su2025 + 'CO2.PR.monthly.nc')
precCO2         = nc_precCO2.variables['PRECT'][:]
nc_precCO2.close()

nc_precICE      = nc.Dataset(path_su2025 + 'ICE.PR.monthly.nc')
precICE         = nc_precICE.variables['PRECT'][:]
nc_precICE.close()

nc_precORB      = nc.Dataset(path_su2025 + 'ORB.PR.monthly.nc')
precORB         = nc_precORB.variables['PRECT'][:]
nc_precORB.close()


# --- temperature ('TREFHT'), one file per experiment ---
nc_TASALL       = nc.Dataset(path_su2025 + 'ALL.TAS.monthly.nc')
tasALL          = nc_TASALL.variables['TREFHT'][:]
nc_TASALL.close()

nc_TASCO2       = nc.Dataset(path_su2025 + 'CO2.TAS.monthly.nc')
tasCO2          = nc_TASCO2.variables['TREFHT'][:]
nc_TASCO2.close()

nc_TASICE       = nc.Dataset(path_su2025 + 'ICE.TAS.monthly.nc')
tasICE          = nc_TASICE.variables['TREFHT'][:]
nc_TASICE.close()

nc_TASORB       = nc.Dataset(path_su2025 + 'ORB.TAS.monthly.nc')
tasORB          = nc_TASORB.variables['TREFHT'][:]
nc_TASORB.close()


# Monsoon-area mask, stored in the same directory as the model fields.
# The CUT section indexes it with the lat/lon indices computed from the model
# coordinates, so it is assumed to live on the same grid -- TODO confirm.
nc_msk_monsoon  = nc.Dataset(path_su2025 + 'monsoon_msk_CESM.nc')

precip_msk_monsoon = nc_msk_monsoon.variables['precip_monsoon'][:]

# [:] has already copied the values into memory; close the file handle
# instead of leaving it open for the rest of the run.
nc_msk_monsoon.close()


# Number of complete 12-record years on the time axis. The reshapes below
# raise ValueError if the record count is not an exact multiple of 12.
nb_year         = len(time_su)//12

# Horizontal grid size, read from the data instead of being hard-coded
# (the original assumed a 48 x 96 grid).
n_lat, n_lon    = precALL.shape[-2:]

# Every monthly field is folded to (year, month, lat, lon), then reduced to
#   *_MJJS    : records 5..9 of each year (five records), unit-converted
#   *_mjjs_av : mean of those five records          -> (year, lat, lon)
#   *_yr_av   : mean of all twelve records, unit-converted
# NOTE(review): downstream titles call this window 'MJJAS'. If record 0 of
# each year is January, 5:10 is June-October; it is May-September only if
# record 0 is December. Confirm the time-axis convention of the input files.
# Precipitation is scaled by 1000 * 86400 (plotted downstream as mm/day, so
# presumably from m/s -- TODO confirm the unit of PRECT).
precALL_yr      = precALL.reshape(nb_year, 12, n_lat, n_lon)
precAll_MJJS    = precALL_yr[:, 5:10, :, :] * 1000 * 86400
precALL_mjjs_av = np.ma.mean(precAll_MJJS, axis = 1)
precALL_yr_av   = np.ma.mean(precALL_yr,   axis = 1) * 1000 * 86400

precCO2_yr      = precCO2.reshape(nb_year, 12, n_lat, n_lon)
precCO2_MJJS    = precCO2_yr[:, 5:10, :, :] * 1000 * 86400
precCO2_mjjs_av = np.ma.mean(precCO2_MJJS, axis = 1)
precCO2_yr_av   = np.ma.mean(precCO2_yr,   axis = 1) * 1000 * 86400

precICE_yr      = precICE.reshape(nb_year, 12, n_lat, n_lon)
precICE_MJJS    = precICE_yr[:, 5:10, :, :] * 1000 * 86400
precICE_mjjs_av = np.ma.mean(precICE_MJJS, axis = 1)
precICE_yr_av   = np.ma.mean(precICE_yr,   axis = 1) * 1000 * 86400

precORB_yr      = precORB.reshape(nb_year, 12, n_lat, n_lon)
precORB_MJJS    = precORB_yr[:, 5:10, :, :] * 1000 * 86400
precORB_mjjs_av = np.ma.mean(precORB_MJJS, axis = 1)
precORB_yr_av   = np.ma.mean(precORB_yr,   axis = 1) * 1000 * 86400


# Temperature is shifted by -273.15 (plotted downstream in degree C, so
# presumably stored in kelvin -- TODO confirm the unit of TREFHT).
tasALL_yr       = tasALL.reshape(nb_year, 12, n_lat, n_lon)
tasAll_MJJS     = tasALL_yr[:, 5:10, :, :] - 273.15
tasALL_mjjs_av  = np.ma.mean(tasAll_MJJS, axis = 1)
tasALL_yr_av    = np.ma.mean(tasALL_yr,   axis = 1) - 273.15

tasCO2_yr       = tasCO2.reshape(nb_year, 12, n_lat, n_lon)
tasCO2_MJJS     = tasCO2_yr[:, 5:10, :, :] - 273.15
tasCO2_mjjs_av  = np.ma.mean(tasCO2_MJJS, axis = 1)
tasCO2_yr_av    = np.ma.mean(tasCO2_yr,   axis = 1) - 273.15

tasICE_yr       = tasICE.reshape(nb_year, 12, n_lat, n_lon)
tasICE_MJJS     = tasICE_yr[:, 5:10, :, :] - 273.15
tasICE_mjjs_av  = np.ma.mean(tasICE_MJJS, axis = 1)
tasICE_yr_av    = np.ma.mean(tasICE_yr,   axis = 1) - 273.15

tasORB_yr       = tasORB.reshape(nb_year, 12, n_lat, n_lon)
tasORB_MJJS     = tasORB_yr[:, 5:10, :, :] - 273.15
tasORB_mjjs_av  = np.ma.mean(tasORB_MJJS, axis = 1)
tasORB_yr_av    = np.ma.mean(tasORB_yr,   axis = 1) - 273.15


# One time value per year: the mean of its twelve record times.
time_yr         = time_su.reshape(nb_year, 12)
time_yr_av      = np.mean(time_yr, axis = 1)


# 'topo' field of the ocean-mask file; kept as-is in `msk`.
nc_mask         = nc.Dataset(path_su2025 + 'mask_ocean_cesm.nc')
msk             = nc_mask.variables['topo'][:]
# [:] has already copied the values into memory; close the file handle
# instead of leaving it open for the rest of the run.
nc_mask.close()

# Complementary mask: 1 wherever msk != 1, and masked out wherever msk == 1.
# Despite its name, msk_land is the factor applied to build the '*_sea'
# products later on, i.e. it REMOVES the cells where msk == 1.
# (`tutu` is only a scratch name here; it is reassigned further down.)
tutu            = np.where(msk == 1, np.nan, 1)
msk_land        = np.ma.masked_where(np.isnan(tutu), tutu)
'----------------------------------------------------------'
#pdb.set_trace()

'----------------------------------------------------------'
'CUT'
'----------------------------------------------------------'
# Positions of the domain bounds on the model axes. tools.find_nearest is a
# project helper; its result is used directly as an array index below.
idx_lat_min         = tools.find_nearest(lat_su, lat_min)
idx_lat_max         = tools.find_nearest(lat_su, lat_max)

idx_lon_min         = tools.find_nearest(lon_su, lon_min)
idx_lon_max         = tools.find_nearest(lon_su, lon_max)

# Inclusive index windows shared by every field. The latitude window runs
# from idx_lat_max to idx_lat_min because lat_max (0) is the smaller bound.
lat_window          = slice(idx_lat_max, idx_lat_min + 1)
lon_window          = slice(idx_lon_min, idx_lon_max + 1)

# Seasonal-mean fields (year, lat, lon) restricted to the domain.
precALL_mjjs_av_cut = precALL_mjjs_av[:, lat_window, lon_window]
precCO2_mjjs_av_cut = precCO2_mjjs_av[:, lat_window, lon_window]
precICE_mjjs_av_cut = precICE_mjjs_av[:, lat_window, lon_window]
precORB_mjjs_av_cut = precORB_mjjs_av[:, lat_window, lon_window]


tasALL_mjjs_av_cut  = tasALL_mjjs_av[:, lat_window, lon_window]
tasCO2_mjjs_av_cut  = tasCO2_mjjs_av[:, lat_window, lon_window]
tasICE_mjjs_av_cut  = tasICE_mjjs_av[:, lat_window, lon_window]
tasORB_mjjs_av_cut  = tasORB_mjjs_av[:, lat_window, lon_window]
# Disabled alternative (inert string): same cut applied to the annual means.
'''
precALL_mjjs_av_cut = precALL_yr_av[:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
precCO2_mjjs_av_cut = precCO2_yr_av[:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
precICE_mjjs_av_cut = precICE_yr_av[:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
precORB_mjjs_av_cut = precORB_yr_av[:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]


tasALL_mjjs_av_cut  = tasALL_yr_av [:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
tasCO2_mjjs_av_cut  = tasCO2_yr_av [:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
tasICE_mjjs_av_cut  = tasICE_yr_av [:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
tasORB_mjjs_av_cut  = tasORB_yr_av [:, idx_lat_max : idx_lat_min + 1, idx_lon_min : idx_lon_max + 1]
'''

# 2-D masks restricted to the same window.
mask_cut            = msk[lat_window, lon_window]
mask_land_cut       = msk_land[lat_window, lon_window]

precip_msk_monscut  = precip_msk_monsoon[lat_window, lon_window]

'----------------------------------------------------------'


'----------------------------------------------------------'
'MSK APPLICATION'
'----------------------------------------------------------'
# Every product below multiplies a (year, lat, lon) field by a 2-D mask, which
# broadcasts over the year axis.
# NOTE(review): the result is only a true "masked" field if the excluded cells
# of each mask are NaN or masked. If they are 0, they stay in the array as
# zeros and will be included in the spatial means of the MEAN section --
# confirm the encoding of mask_cut and precip_msk_monscut.

# Disabled variant (inert string): land precipitation from the plain land mask.
'''
precALL_mjjs_av_land = precALL_mjjs_av_cut * mask_cut
precCO2_mjjs_av_land = precCO2_mjjs_av_cut * mask_cut
precICE_mjjs_av_land = precICE_mjjs_av_cut * mask_cut
precORB_mjjs_av_land = precORB_mjjs_av_cut * mask_cut
'''
# Active variant: the '*_land' precipitation is restricted with the monsoon
# mask, not with mask_cut.
precALL_mjjs_av_land = precALL_mjjs_av_cut * precip_msk_monscut
precCO2_mjjs_av_land = precCO2_mjjs_av_cut * precip_msk_monscut
precICE_mjjs_av_land = precICE_mjjs_av_cut * precip_msk_monscut
precORB_mjjs_av_land = precORB_mjjs_av_cut * precip_msk_monscut

# '*_land' temperature uses mask_cut (the 'topo' field cut to the domain).
tasALL_mjjs_av_land  = tasALL_mjjs_av_cut  * mask_cut
tasCO2_mjjs_av_land  = tasCO2_mjjs_av_cut  * mask_cut
tasICE_mjjs_av_land  = tasICE_mjjs_av_cut  * mask_cut
tasORB_mjjs_av_land  = tasORB_mjjs_av_cut  * mask_cut

# '*_sea' products use mask_land_cut, which is masked wherever msk == 1 and
# equal to 1 elsewhere (see its construction in the DATA RECOVERY section).
precALL_mjjs_av_sea  = precALL_mjjs_av_cut * mask_land_cut
precCO2_mjjs_av_sea  = precCO2_mjjs_av_cut * mask_land_cut
precICE_mjjs_av_sea  = precICE_mjjs_av_cut * mask_land_cut
precORB_mjjs_av_sea  = precORB_mjjs_av_cut * mask_land_cut

tasALL_mjjs_av_sea   = tasALL_mjjs_av_cut  * mask_land_cut
tasCO2_mjjs_av_sea   = tasCO2_mjjs_av_cut  * mask_land_cut
tasICE_mjjs_av_sea   = tasICE_mjjs_av_cut  * mask_land_cut
tasORB_mjjs_av_sea   = tasORB_mjjs_av_cut  * mask_land_cut
'----------------------------------------------------------'

'----------------------------------------------------------'
'MEAN'
'----------------------------------------------------------'
def _area_means(cube):
	"""Return the list of spatial means of ``cube``, one per leading index.

	``cube`` is iterated along its first axis and ``.mean()`` is taken of each
	2-D slice, so the length of the result always follows ``cube`` itself
	(the original loops borrowed the length of the ALL array for every
	experiment). The mean is unweighted: no latitude/area weighting is applied.
	"""
	return [field.mean() for field in cube]


# One domain-mean value per year for each experiment and surface type.
precALL_mjjs_av_land_m = _area_means(precALL_mjjs_av_land)
precCO2_mjjs_av_land_m = _area_means(precCO2_mjjs_av_land)
precICE_mjjs_av_land_m = _area_means(precICE_mjjs_av_land)
precORB_mjjs_av_land_m = _area_means(precORB_mjjs_av_land)

precALL_mjjs_av_sea_m  = _area_means(precALL_mjjs_av_sea)
precCO2_mjjs_av_sea_m  = _area_means(precCO2_mjjs_av_sea)
precICE_mjjs_av_sea_m  = _area_means(precICE_mjjs_av_sea)
precORB_mjjs_av_sea_m  = _area_means(precORB_mjjs_av_sea)

tasALL_mjjs_av_land_m  = _area_means(tasALL_mjjs_av_land)
tasCO2_mjjs_av_land_m  = _area_means(tasCO2_mjjs_av_land)
tasICE_mjjs_av_land_m  = _area_means(tasICE_mjjs_av_land)
tasORB_mjjs_av_land_m  = _area_means(tasORB_mjjs_av_land)

tasALL_mjjs_av_sea_m   = _area_means(tasALL_mjjs_av_sea)
tasCO2_mjjs_av_sea_m   = _area_means(tasCO2_mjjs_av_sea)
tasICE_mjjs_av_sea_m   = _area_means(tasICE_mjjs_av_sea)
tasORB_mjjs_av_sea_m   = _area_means(tasORB_mjjs_av_sea)

# Series grouped per quantity, always in the order ALL, CO2, ICE, ORB.
prec_sea               = [precALL_mjjs_av_sea_m,  precCO2_mjjs_av_sea_m,  precICE_mjjs_av_sea_m,  precORB_mjjs_av_sea_m ]
prec_land              = [precALL_mjjs_av_land_m, precCO2_mjjs_av_land_m, precICE_mjjs_av_land_m, precORB_mjjs_av_land_m]
tas_sea                = [tasALL_mjjs_av_sea_m,   tasCO2_mjjs_av_sea_m,   tasICE_mjjs_av_sea_m,   tasORB_mjjs_av_sea_m  ]
tas_land               = [tasALL_mjjs_av_land_m,  tasCO2_mjjs_av_land_m,  tasICE_mjjs_av_land_m,  tasORB_mjjs_av_land_m ]
'----------------------------------------------------------'
# Land-minus-sea temperature contrast of the ALL experiment, one value per
# year (`tutu` is reused here as a scratch name).
tutu = np.asarray(tasALL_mjjs_av_land_m) - np.asarray(tasALL_mjjs_av_sea_m)


'----------------------------------------------------------'
'CORR PLOT'
'----------------------------------------------------------'
# Scatter of monsoon-area precipitation (y) against the land-sea temperature
# contrast (x). The label gives plt.legend() an entry to draw; without it the
# call only produced a "no artists with labels" warning.
plt.scatter(tutu, precALL_mjjs_av_land_m, color = 'green', edgecolor = 'grey', alpha = 0.7, label = 'ALL')
plt.grid()
# '\\Delta' yields the same text as the former '\Delta' but without relying on
# an invalid escape sequence (a SyntaxWarning on recent Python versions). A raw
# string is not an option for the title because its '\n' must stay a newline.
plt.xlabel('$\\Delta$T (land-sea) degree C')
plt.ylabel('Precipitation in mm/day')
plt.title('$\\Delta$T (land-sea) as a function of precip (MJJAS) \n precip only monsoon area defined with GPCC data \n - ALL (Su et al., 2025)')
plt.legend(loc = 'lower right')
plt.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/corr_SU_rev.pdf')
plt.show()
plt.close()

'----------------------------------------------------------'


# Interactive breakpoint: execution stops here until the debugger is resumed.
pdb.set_trace()

'----------------------------------------------------------'
'SPECTRAL ANALYSIS - FIG'
'----------------------------------------------------------'
# Column headers of the figure grids, in the same order as list_var below.
columns_titles = [
	'SST_MJJAS',
	'LST_MJJAS',
	'Land Precip MJJAS \n monsoon area from GPCC',
	'Sea Precip MJJAS',
]
# Annual-mean alternative kept by the author:
#columns_titles = ['TASsea_ann', 'TASland_ann', 'Land Precip ann', 'Sea Precip ann']

# Row headers: one per experiment, in the order used inside every list_var entry.
rows_titles = ['All', 'CO2', 'ICE', 'ORB']

# list_var[column][row] is one yearly series.
list_var = [tas_sea, tas_land, prec_land, prec_sea]

# The model time axis is replaced by an evenly spaced 0..800 axis of the same
# length (axis labels downstream call it 'Time (kyrs)').
time_yr_av = np.linspace(0, 800, num = len(time_yr_av))

# Alternative row headers kept by the author:
#rows_titles    = ['All Forcings \n \n Period (kyr)', 'Orb_GhG \n \n Period (kyr)', 'Orb_Only \n \n Period (kyr)']

'''
ts_prec_sea    = []
ts_prec_land   = []
ts_tas_sea     = []
ts_tas_land    = []

#pdb.set_trace()


def wavelet(ts):
	
	ts_std   = ts.detrend().interp().standardize()
	scal     = ts_std.wavelet (method = 'cwt')
	scal_sig = scal.signif_test(method = 'ar1asym')

	return scal_sig
	
	
def psd_sig(ts):

	ts_std   = ts.detrend().interp().standardize()	
	psd_sig  = ts_std.spectral(method = 'mtm').signif_test()
	
	return psd_sig



for i in range(len(prec_sea)):

	ts_ps = pyleo.Series(time = time_yr_av, value = prec_sea [i], time_name = 'yrs', value_name = 'Precip')
	ts_pl = pyleo.Series(time = time_yr_av, value = prec_land[i], time_name = 'yrs', value_name = 'Precip')
	ts_ts = pyleo.Series(time = time_yr_av, value = tas_sea  [i], time_name = 'yrs', value_name = 'Tas'   )
	ts_tl = pyleo.Series(time = time_yr_av, value = tas_land [i], time_name = 'yrs', value_name = 'Tas'   )
	
	ts_prec_sea .append(ts_ps)
	ts_prec_land.append(ts_pl)
	ts_tas_sea  .append(ts_ts)
	ts_tas_land .append(ts_tl)


list_var       = [ts_tas_sea, ts_tas_land, ts_prec_land, ts_prec_sea]

wave_var       = []

#pdb.set_trace()

for i in range(len(list_var)):

	wave   = []
	
	for j in range(len(list_var[i])):
	
		wave.append(wavelet(list_var[i][j]))
	
	wave_var.append(wave)
	
	
	
	
	
#pdb.set_trace()	
	
fig, ax        = plt.subplots(nrows = len(rows_titles), ncols = len(columns_titles), figsize = (20, 20))

for i in range(len(wave_var)):

	for j in range(len(wave_var[0])):
	
		wave_var[i][j].plot(ax = ax[j, i], xlabel = ' ', ylabel = ' ', title = ' ', contourf_style={'cmap':'viridis'})


		if j == len(columns_titles) - 1:

			ax[i, j].set_xlabel('Time (kyrs)')
		

		if i == 0:
			
			ax[j, i].set_ylabel(rows_titles[j], size = 'large')
			
			
for ax, col in zip(ax[0], columns_titles):
	
	ax.set_title(col, size = 'large')
			

fig.tight_layout()


plt.show()
plt.close()
'''

# Two figure grids with one row per experiment and one column per variable:
#   fig  / ax : wavelet power spectrum as a function of time (scalogram)
#   fig2 / bx : time-averaged (global) wavelet spectrum and Fourier spectrum
fig, ax        = plt.subplots(nrows = len(rows_titles), ncols = len(list_var), figsize = (20, 20))

fig2, bx       = plt.subplots(nrows = len(rows_titles), ncols = len(list_var), figsize = (20, 20))

# Evenly spaced time axis and its step; dt is also the sampling interval
# handed to the wavelet routines.
t              = np.linspace(0, 800, len(time_yr_av))
dt             = t[1] - t[0]

# NOTE(review): `wavelet` is not bound by any active import of this file; the
# only candidate, 'import pycwt as wavelet', is commented out at the top.
# As written, the next statement raises NameError until that import is restored.

# Settings shared by every panel (hoisted out of the loops).
mother         = wavelet.Morlet(6)
s0             = 2*dt*10                 #Starting scale
dj             = 1/12                    #Twelve sub-octaves per octave
J              = 7/dj                    #Seven powers of two with dj sub-octaves
# Contour levels of log2(power). The former positional log2(levels) argument
# was dropped: contourf already received this keyword, which takes precedence.
power_levels   = np.arange(0, 7, 0.5)


for i, row_title in enumerate(rows_titles):

	for j, series_per_run in enumerate(list_var):

		# --- series preparation: linear detrend, then scale to unit variance ---
		dat         = np.asarray(series_per_run[i])
		N           = dat.size
		p           = np.polyfit(t, dat, 1)  #Find the trend
		dat_notrend = dat - np.polyval(p, t) #Remove trend from data
		std         = dat_notrend.std()
		var         = std ** 2
		dat_norm    = dat_notrend/std        #Normalize data

		# Lag-1 autocorrelation for red noise; estimated on the raw series
		# `dat`, not on the detrended one.
		alpha, _, _ = wavelet.ar1(dat)

		# --- continuous wavelet transform of the normalised series ---
		wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(dat_norm, dt, dj, s0, J, mother)

		# Inverse transform; not used by the plots below, kept for inspection.
		iwave       = wavelet.icwt(wave, scales, dt, dj, mother) * std
		power       = (np.abs(wave)) ** 2
		fft_power   = np.abs(fft) ** 2
		period      = (1 / freqs)

		# --- significance of the local power at the 0.95 level ---
		signif, fft_theor = wavelet.significance(1.0, dt, scales, 0, alpha, significance_level=0.95, wavelet=mother)

		# One threshold per scale, repeated along time; sig95 > 1 marks power
		# above the threshold.
		sig95       = np.ones([1, N]) * signif[:, None]
		sig95       = power / sig95

		# --- time-averaged spectrum and its significance ---
		glbl_power  = power.mean(axis=1)
		dof         = N - scales              #Correction for padding at edges

		glbl_signif, tmp = wavelet.significance(var, dt, scales, 1, alpha, significance_level=0.95, dof = dof, wavelet=mother)

		# --- scalogram panel ---
		ax[i, j].contourf(t, np.log2(period), np.log2(power), extend='both', cmap=plt.cm.viridis, levels = power_levels)

		extent = [t.min(), t.max(), 0, max(period)]

		# The single effective contour is the sig95 == 1 line.
		ax[i, j].contour(t, np.log2(period), sig95, [-99, 1], colors='k', linewidths=2, extent=extent)

		# Hatched polygon between the cone of influence (clipped to the period
		# range) and the largest period: forward along the COI, back along the
		# bottom edge.
		bottom = np.log2(period.max())
		tutu   = np.concatenate([t, t[::-1]])
		toto   = np.concatenate([np.log2(np.clip(coi, period.min(), period.max())), bottom * np.ones_like(t)])

		ax[i, j].fill(tutu, toto, 'k', alpha = 0.3, hatch = 'x')

		# Ticks at powers of two of the period; the axis itself is in log2 units.
		Yticks = 2 ** np.arange(np.ceil(np.log2(period.min())), np.ceil(np.log2(period.max())))

		ax[i, j].set_yticks(np.log2(Yticks))
		#ax[i, j].invert_yaxis()
		ax[i, j].set_yticklabels(Yticks)

		if j == 0:

			ax[i, j].set_ylabel(row_title, size = 'large')

		if i == len(rows_titles) - 1:

			ax[i, j].set_xlabel('Time (kyrs)')

		# --- global-spectrum panel ---
		bx[i, j].plot(glbl_signif, np.log2(period), 'k--')
		bx[i, j].plot(var * fft_theor, np.log2(period), '--', color='#cccccc')
		bx[i, j].plot(var * fft_power, np.log2(1./fftfreqs), '-', color='#cccccc', linewidth=1.)
		bx[i, j].plot(var * glbl_power, np.log2(period), 'k-', linewidth=1.5)
		#bx[i, j].set_xlabel(r'Power [({})^2]'.format(units))

		# Fixed power range: 0..9 for the first two columns, 0..1 for the others.
		if j > 1:
			bx[i, j].set_xlim([0, 1])

		else:

			bx[i, j].set_xlim([0, 9])

		bx[i, j].set_ylim(np.log2([period.min(), period.max()]))
		bx[i, j].set_yticks(np.log2(Yticks))
		bx[i, j].set_yticklabels(Yticks)
		bx[i, j].grid(alpha = 0.7)

		if j == 0:

			bx[i, j].set_ylabel(row_title, size = 'large')

		if i == len(rows_titles) - 1:

			bx[i, j].set_xlabel('Power')


# Column titles on the top row of each grid. Dedicated loop names are used so
# that `ax` and `bx` keep referring to the full axes arrays afterwards.
for top_ax, col in zip(ax[0], columns_titles):

	top_ax.set_title(col, size = 'large')



for top_bx, col in zip(bx[0], columns_titles):

	top_bx.set_title(col, size = 'large')


# Only the scalogram figure is written to disk.
# NOTE(review): fig2 (global spectra) is displayed by plt.show() but never saved.
fig.savefig('/user/work/xk22684/work/fig/fig_paper_monsoon/spectreCESM.png')
plt.show()
plt.close()


'----------------------------------------------------------'


# Interactive breakpoint at the end of the script, for inspecting the results.
pdb.set_trace()












