import tools
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt 
from scipy.stats.stats import pearsonr
import pdb
import xarray as xr
import new_tools 



# ----------------------------------------------------------------------------
# DATA RECOVERY
# ----------------------------------------------------------------------------
# Read 'precip_mm_srf' for every simulation in `list_simu` and every calendar
# month, and collect the twelve monthly fields (January first) in `precip_ann`.
# Each monthly field is whatever new_tools.get_var2d returns — presumably an
# array shaped (n_simu, lat, lon); TODO confirm against new_tools.
list_simu = tools.list_simu_08

MONTHS = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
          'jul', 'aug', 'sep', 'oct', 'nov', 'dec']

precip_by_month = {}
for month in MONTHS:
    nc_files = new_tools.get_clfiles(list_simu, month)
    precip_by_month[month] = new_tools.get_var2d('precip_mm_srf', nc_files)

    if month == 'jun':
        # Grid coordinates are read (into memory) from the first June file.
        lats = nc_files[0].variables['latitude'][:]
        lons = nc_files[0].variables['longitude'][:]

    if month == 'dec':
        # The December handles stay open: the first one is used later as the
        # template (dimensions, latitude/longitude) for the netCDF output.
        nc_files_dec = nc_files
    else:
        new_tools.close_files(nc_files)

precip_ann = [precip_by_month[month] for month in MONTHS]
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# MEAN SEASON
# ----------------------------------------------------------------------------
# Seasons are picked by month position in `precip_ann` (0 = jan ... 11 = dec):
#   winter = jan, feb, mar, nov, dec  (NDJFM)
#   summer = may, jun, jul, aug, sep  (MJJAS)
WINTER_MONTH_IDX = (0, 1, 2, 10, 11)
SUMMER_MONTH_IDX = (4, 5, 6, 7, 8)

precip_winter = [precip_ann[idx] for idx in WINTER_MONTH_IDX]
precip_summer = [precip_ann[idx] for idx in SUMMER_MONTH_IDX]

# Stack the monthly fields so that axis 0 is the month.
# NOTE(review): np.asarray on a list of masked arrays yields a plain ndarray,
# so any mask coming from get_var2d is dropped before the np.ma reductions
# below — confirm this is intended.
summer_stack = np.asarray(precip_summer)
winter_stack = np.asarray(precip_winter)
annual_stack = np.asarray(precip_ann)

# Seasonal / annual totals, used for the summer-to-annual ratio.
precip_summer_sum = np.ma.sum(summer_stack, axis=0)
precip_ann_sum = np.ma.sum(annual_stack, axis=0)

# Seasonal means, used for the summer-minus-winter difference.
precip_summer_mean = np.ma.mean(summer_stack, axis=0)
precip_winter_mean = np.ma.mean(winter_stack, axis=0)
# ----------------------------------------------------------------------------


# ----------------------------------------------------------------------------
# MONSOON REGION - summer minus winter precip > 2mm/day - summer/ann_sum > 0.55
# ----------------------------------------------------------------------------
# A grid cell is flagged as monsoonal (value 2, otherwise 0) when both hold:
#   * mean summer (MJJAS) minus mean winter (NDJFM) precipitation > 2 mm/day
#   * summer (MJJAS) total / annual total > 0.55
# NOTE(review): "summer" is MJJAS everywhere. In the usual global-monsoon
# definition the Southern Hemisphere uses NDJFM as local summer — confirm
# this is intended.
SECONDS_PER_DAY = 86400        # model rate -> mm/day (presumably from mm/s)
DIFF_THRESHOLD_MM_DAY = 2
RATIO_THRESHOLD = 0.55

precip_diff = np.asarray(precip_summer_mean)*SECONDS_PER_DAY - np.asarray(precip_winter_mean)*SECONDS_PER_DAY

# Cells with zero annual precipitation give 0/0 = nan. nan compares False
# below, so such cells are simply not monsoonal; silence the expected warning.
with np.errstate(divide='ignore', invalid='ignore'):
    precip_ratio = np.asarray(precip_summer_sum)/np.asarray(precip_ann_sum)

n_simu = len(list_simu)

# Single-criterion maps, kept for diagnostics / interactive inspection.
precip_diff_2mm = [np.ma.where(precip_diff[i] > DIFF_THRESHOLD_MM_DAY, 2, 0) for i in range(n_simu)]
precip_ratio_55p = [np.ma.where(precip_ratio[i] > RATIO_THRESHOLD, 2, 0) for i in range(n_simu)]

# Final mask, one 2-D map per simulation (axis 0 of the fields indexed by simulation).
precip_monsoon = [
    np.ma.where((precip_diff[i] > DIFF_THRESHOLD_MM_DAY) & (precip_ratio[i] > RATIO_THRESHOLD), 2, 0)
    for i in range(n_simu)
]
# ----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# CREATE NETCDF
# ----------------------------------------------------------------------------
# Write one netCDF file per simulation holding its monsoon mask. The first
# open December file is the template: all its dimensions and its
# latitude/longitude variables (data and attributes) are copied to every
# output file. The December files are closed once all outputs are written.
output_dir = '/user/work/xk22684/work/data/msk_monsoon_bbc8'
template_nc = nc_files_dec[0]

try:
    for simu in range(len(list_simu)):

        output_file_path = f'{output_dir}/{list_simu[simu]}_monsoon_msk.nc'

        with nc.Dataset(output_file_path, 'w') as new_nc:

            # Copy dimensions (an unlimited dimension stays unlimited)
            for dim_name, dim in template_nc.dimensions.items():
                new_nc.createDimension(dim_name, (len(dim) if not dim.isunlimited() else None))

            # Copy latitude and longitude variables together with their attributes
            for var_name, var in template_nc.variables.items():
                if var_name in ['latitude', 'longitude']:
                    new_var = new_nc.createVariable(var_name, var.datatype, var.dimensions)
                    # _FillValue can only be set when a variable is created, so it is
                    # skipped. Attributes go in before the data so any packing
                    # attributes (scale_factor/add_offset) apply to the write.
                    new_var.setncatts({att: var.getncattr(att)
                                       for att in var.ncattrs() if att != '_FillValue'})
                    new_var[:] = var[:]

            # Create precip_monsoon variable (2 = monsoon, 0 = not monsoon)
            new_var = new_nc.createVariable('precip_monsoon', 'f4', ('latitude', 'longitude'))
            new_var[:] = precip_monsoon[simu]

            print(f'Created file {output_file_path} with precip_monsoon.')
finally:
    # The December handles were left open only to serve as the template above.
    new_tools.close_files(nc_files_dec)

# ----------------------------------------------------------------------------



# Quick-look map of one simulation's monsoon mask (combined criteria).
plot_simu = 100   # position of the simulation to display in list_simu / precip_monsoon

new_tools.plot_Robinson(lons,      lats,      precip_monsoon[plot_simu],     palette = 'PuBu', units = ' ',
                        title = f'monsoon mask {list_simu[plot_simu]}: summer-winter > 2mm/day & summer/ann > 0.55 \n bbc08')

# Deliberate end-of-script breakpoint: all results are available for inspection.
pdb.set_trace()
