import new_tools as tools
import numpy as np 
import netCDF4 as nc 
import os
import matplotlib.pyplot as plt 
import pdb


'----------------------------------------------------------------------------'
'DATA RECOVERY'
'----------------------------------------------------------------------------'
# GPCC 1951-2000 precipitation normals. The dataset is deliberately left open:
# the CREATE NETCDF section below copies dimensions and lat/lon variables from it.
input_path = '/user/work/xk22684/work/data/normals_1951_2000_v2022_25.nc'
nc_file    = nc.Dataset(input_path)
nc_vars    = nc_file.variables

# Load the fields used by the rest of the script fully into memory.
precip = nc_vars['precip'][:]
time   = nc_vars['time'][:]
lat    = nc_vars['lat'][:]
lon    = nc_vars['lon'][:]
'----------------------------------------------------------------------------'


'----------------------------------------------------------------------------'
'SEASON MEAN'
'----------------------------------------------------------------------------'
# precip holds one field per calendar month along axis 0, index 0 = January
# (see the "may to september" slice below). Values are mm/month and are turned
# into approximate mm/day by dividing by a nominal 30-day month.
precip_MJJAS      = precip[4:9] #may to september
precip_MJJAS_mean = np.ma.mean(precip_MJJAS, axis = 0) #mm/month
precip_MJJAS_mean = precip_MJJAS_mean/30               #mm/day
# Seasonal total, scaled by the same 1/30 as precip_ann_sum below. It is only
# used as the numerator of the MJJAS/annual ratio, where the factor cancels.
precip_MJJAS_sum  = np.ma.sum(precip_MJJAS, axis = 0)
precip_MJJAS_sum  = precip_MJJAS_sum/30


# November-December followed by January-March. np.ma.concatenate must be used:
# plain np.concatenate drops the input masks, so missing cells would enter the
# mean with their raw fill values instead of being excluded.
precip_NDJFM      = np.ma.concatenate([precip[-2:], precip[:3]])
precip_NDJFM_mean = np.ma.mean(precip_NDJFM, axis=0)
precip_NDJFM_mean = precip_NDJFM_mean/30

precip_ann_sum    = np.ma.sum(precip, axis = 0)
precip_ann_sum    = precip_ann_sum/30

# Note: where no month is masked, the 5-month NDJFM mean above equals
# (3*JFM_mean + 2*ND_mean)/5 (weights summing to 5, not 2).
'----------------------------------------------------------------------------'


'----------------------------------------------------------------------------'
'MONSOON REGION - summer minus winter precip > 2mm/day - summer/ann_sum > 0.55'
'----------------------------------------------------------------------------'
# Criterion 1: MJJAS mean rate exceeds the NDJFM mean rate by more than 2 mm/day.
precip_diff        = precip_MJJAS_mean - precip_NDJFM_mean
# Criterion 2: MJJAS accounts for more than 55 % of the annual total.
precip_ratio       = precip_MJJAS_sum/precip_ann_sum

is_wet_summer      = precip_diff > 2
is_summer_dominant = precip_ratio > 0.55

# Flag value 2 where the criterion (or both, for the monsoon mask) holds, 0 elsewhere.
precip_diff_2mm    = np.ma.where(is_wet_summer, 2, 0)
precip_ratio_55p   = np.ma.where(is_summer_dominant, 2, 0)
precip_monsoon     = np.ma.where(is_wet_summer & is_summer_dominant, 2, 0)
'----------------------------------------------------------------------------'


'----------------------------------------------------------------------------'
'CUT'
'----------------------------------------------------------------------------'
# Asian sub-domain: 0-50 degrees latitude, 60-130 degrees longitude.
lat_min, lat_max = 0, 50
lon_min, lon_max = 60, 130

# tools.find_nearest returns an index into the coordinate array (it is used as
# a slice bound below). Each index pair is sorted so the slices are non-empty
# whether the axis is stored ascending or descending. The bounds were
# previously written as 50 -> 0, which only worked for a north-to-south axis.
idx_latmin, idx_latmax = sorted((tools.find_nearest(lat, lat_min), tools.find_nearest(lat, lat_max)))
idx_lonmin, idx_lonmax = sorted((tools.find_nearest(lon, lon_min), tools.find_nearest(lon, lon_max)))

lat_asia               = lat [idx_latmin : idx_latmax + 1]
lon_asia               = lon [idx_lonmin : idx_lonmax + 1]

# Fields are indexed (lat, lon), matching the ('lat', 'lon') layout written to netCDF below.
precip_diff_2mm_asia   = precip_diff_2mm [idx_latmin : idx_latmax + 1, idx_lonmin : idx_lonmax + 1]
precip_monsoon_asia    = precip_monsoon  [idx_latmin : idx_latmax + 1, idx_lonmin : idx_lonmax + 1]
'----------------------------------------------------------------------------'



'----------------------------------------------------------------------------'
'CREATE NETCDF'
'----------------------------------------------------------------------------'
output_file_path   = '/user/work/xk22684/work/data/monsoon_msk.nc'


# Save the monsoon mask, with the lat/lon coordinates of the input grid, to a
# single netCDF file.
with nc.Dataset(output_file_path, 'w') as new_nc:

    # Replicate every dimension of the input file with the same length.
    for dim_name, dim in nc_file.dimensions.items():
        new_nc.createDimension(dim_name, len(dim))

    # Copy the lat/lon coordinate variables together with their attributes
    # (units, long_name, ...) so the output grid is self-describing.
    # _FillValue can only be set at creation time, so it is passed separately.
    for var_name, var in nc_file.variables.items():
        if var_name not in ['lat', 'lon']:
            continue
        attr_names = var.ncattrs()
        fill_value = var.getncattr('_FillValue') if '_FillValue' in attr_names else None
        new_var = new_nc.createVariable(var_name, var.datatype, var.dimensions,
                                        fill_value=fill_value)
        new_var.setncatts({name: var.getncattr(name)
                           for name in attr_names if name != '_FillValue'})
        new_var[:] = var[:]

    # Monsoon mask: 2 where both monsoon criteria hold, 0 elsewhere.
    monsoon_var = new_nc.createVariable('precip_monsoon', 'f4', ('lat', 'lon'))
    monsoon_var.long_name = 'monsoon region mask'
    monsoon_var.description = ('2 where MJJAS - NDJFM precip > 2 mm/day and '
                               'MJJAS/annual precip > 0.55, 0 elsewhere')
    monsoon_var[:] = precip_monsoon

    print(f'Created file {output_file_path} with precip_monsoon.')

'----------------------------------------------------------------------------'

# Interactive inspection breakpoint; remove for non-interactive (batch) runs.
pdb.set_trace()


'----------------------------------------------------------------------------'
'PLOT'
'----------------------------------------------------------------------------'
# Same map drawn twice: first the global field, then the Asian cut-out.
plot_title = 'Where MJJAS precip - NDJFM precip > 2mm/day \n GPCC data 1951_2000'
plot_domains = (
    (lon,      lat,      precip_diff_2mm),
    (lon_asia, lat_asia, precip_diff_2mm_asia),
)
for plot_lon, plot_lat, plot_field in plot_domains:
    tools.plot_Robinson(plot_lon, plot_lat, plot_field,
                        palette = 'PuBu', units = ' ', title = plot_title)

'----------------------------------------------------------------------------'


# Interactive inspection breakpoint; remove for non-interactive (batch) runs.
pdb.set_trace()
