################################################################################
# Written By Jared Rennie
################################################################################

# Import Packages
import sys, time, datetime, os
import numpy as np
import geopandas as gpd
import pandas as pd

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader

import shapely.geometry as sgeom
import warnings
warnings.filterwarnings("ignore")

# Define directories: every input and output lives under a single project root
main_directory = "/store/sfcnet/datasets/fema_risk"
inShapefile_directory = f"{main_directory}/input_shapefile"    # county boundary shapefile
fema_directory = f"{main_directory}/source"                     # FEMA NRI county CSV
outShapefile_directory = f"{main_directory}/results_shapefile"  # joined shapefile output
plot_directory = f"{main_directory}/results_plots"              # PNG map output

#################################################
# BEGIN PROGRAM
start = time.time()  # wall-clock reference for the runtime report at the end

#################################################
# READ IN SHAPEFILE
input_shapefile = f"{inShapefile_directory}/cb_2018_us_county_500k.shp"

print("READ IN SHAPEFILE: ", input_shapefile)
geo_shapefile = gpd.read_file(input_shapefile)

# Integer join key derived from the string GEOID column (presumably the
# zero-padded state+county FIPS code, so '01001' -> 1001); dtype 'i' = C int.
geoid_text = geo_shapefile['GEOID'].values
geo_shapefile['GEOID2_INT'] = np.array(geoid_text, dtype='i')

# Keep the county layer's CRS so the joined output can be given the same one
projection = geo_shapefile.crs

#################################################
# READ IN CSV File (fema)
input_csv=fema_directory+'/NRI_Table_Counties.csv'
print("READING IN fema DATA: ",input_csv)

# Only the county FIPS code and the two rating columns are used later, so read
# just those columns instead of the whole table.
fema_columns = ['STCOFIPS', 'RISK_RATNG', 'HWAV_RISKR']
data_fema = pd.read_csv(input_csv, sep=',', usecols=fema_columns)

# Clean Up
# Reorder to fema_columns, add the integer join key, and sort. `assign` builds a
# new DataFrame, so there is no chained assignment onto a column slice (which
# would raise a SettingWithCopyWarning, hidden by the global warnings filter).
data_fema = (
    data_fema[fema_columns]
    .assign(GEOID2_INT=lambda df: df['STCOFIPS'].astype(int))
    .sort_values(by=['GEOID2_INT'])  # Sort
)

#################################################
# Join
print("\nJOIN")

# Perform the Join: a left join keeps every county polygon; counties without a
# matching FEMA row get NaN in the FEMA columns.
out_shapefile=geo_shapefile.merge(data_fema, on='GEOID2_INT', how='left')

# Set Projection to same as Shapefile
out_shapefile=out_shapefile.set_crs(projection)

# Save as new shapefile (create the output folder first if it is missing)
os.makedirs(outShapefile_directory, exist_ok=True)
finalShapeFile=outShapefile_directory+'/cdc_fema_fromcsv.shp'
print("OUTPUT TO: "+str(finalShapeFile))
out_shapefile.to_file(finalShapeFile)

#################################################
# PLOTTING (FEMA RISK)
print("PLOTTING (FEMA RISK)")

# Set Bounds: CONUS map extent in degrees (used by conus_ax.set_extent below)
minLat = 22
maxLat = 50
minLon = -120
maxLon = -73

dpi=300
plt.style.use('dark_background')
land_hex='#efefef'  # ESRI Light Gray Canvas
ocean_hex='#cfd3d4' # ESRI Light Gray Canvas

# Grab Data By variable and plot
inCode='RISK_RATNG'   # out_shapefile column that drives each county's colour
inTitle='FEMA National Risk Index'
inName='fema'

# Set Up Figure
fig= plt.figure(num=1, figsize=(8,5), dpi=dpi, facecolor='w', edgecolor='k')

# CONUS AXES: main map filling the whole figure, clipped to the bounds above
conus_ax = fig.add_axes([0, 0, 1, 1], projection=ccrs.LambertConformal())
conus_ax.set_facecolor(ocean_hex)
conus_ax.set_extent([minLon, maxLon, minLat, maxLat], crs=ccrs.Geodetic())

# ALASKA AXES: inset at lower left; the extent spans past -180 to include the Aleutians
ak_ax = fig.add_axes([0.05, 0.01, 0.20, 0.20], projection=ccrs.Orthographic(central_longitude=-133.66666667, central_latitude=57.00000000))
ak_ax.set_facecolor(ocean_hex)
ak_ax.set_extent([-184, -128, 53, 67], crs=ccrs.Geodetic())

# HAWAII AXES: inset next to Alaska
hi_ax = fig.add_axes([0.25, 0.01, 0.15, 0.15], projection=ccrs.Mercator())
hi_ax.set_facecolor(ocean_hex)
hi_ax.set_extent([-162, -154, 18, 23], crs=ccrs.Geodetic())

# PUERTO RICO AXES: inset toward lower right
pr_ax = fig.add_axes([0.60, 0.01, 0.15, 0.15], projection=ccrs.Mercator())
pr_ax.set_facecolor(ocean_hex)
pr_ax.set_extent([-67.5, -65.5, 17.75, 18.75], crs=ccrs.Geodetic())

# Plot Data For Each County
# Fill colour for each FEMA rating value; keep in sync with the colorbar below.
risk_colors = {
    'Insufficient Data': '#9E9E9E',
    'Not Applicable': '#CCCCCC',
    'No Rating': '#FFFFFF',
    'Very Low': '#4D6DBD',
    'Relatively Low': '#509bc7',
    'Relatively Moderate': '#f0d55d',
    'Relatively High': '#e07068',
    'Very High': '#c7445d',
}
default_color = '#9E9E9E'  # unrecognised rating: same grey as 'Insufficient Data'
missing_color = '#CCCCCC'  # county with no FEMA row after the left join (NaN)

# State FIPS -> (inset axes, county edge line width). Every other county
# (CONUS plus any other territory) is drawn once on the main CONUS axes.
inset_axes = {
    '02': (ak_ax, 0.20),  # AK
    '15': (hi_ax, 0.50),  # HI
    '72': (pr_ax, 0.20),  # PR
}

# Iterate the joined GeoDataFrame directly: geometry and attributes come from
# the same rows, so there is no need to re-read the saved shapefile and keep a
# manual row counter aligned with it.
for county, val, stateFips in zip(out_shapefile.geometry, out_shapefile[inCode], out_shapefile['STATEFP']):
    # NaN from the merge is a float, so it must be detected with isna, not == 'nan'
    if pd.isna(val):
        outColor = missing_color
    else:
        outColor = risk_colors.get(val, default_color)

    target_ax, line_width = inset_axes.get(stateFips, (conus_ax, 0.10))
    target_ax.add_geometries([county], ccrs.PlateCarree(), facecolor=outColor, edgecolor='black', linewidth=line_width)
conus_ax.add_feature(cfeature.STATES,linewidth=0.5,zorder=10)

# Add ColorMap
# One (label, colour) pair per rating category, in legend order. Keeping the
# pairs together guarantees every tick label sits on its own colour bin.
legend_entries = [
    ('Insufficient Data', '#9E9E9E'),
    ('Not Applicable', '#CCCCCC'),
    ('No Rating', '#FFFFFF'),
    ('Very Low', '#4D6DBD'),
    ('Relatively Low', '#509bc7'),
    ('Relatively Moderate', '#f0d55d'),
    ('Relatively High', '#e07068'),
    ('Very High', '#c7445d'),
]
labels = [label for label, _ in legend_entries]
cmap = mpl.colors.ListedColormap([color for _, color in legend_entries])
bounds = np.arange(cmap.N+1)  # integer boundaries 0..N: one unit-wide bin per category
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
cax = fig.add_axes([0.1, -0.035, 0.8, 0.03])
cbar = plt.colorbar(mpl.cm.ScalarMappable(cmap=cmap, norm=norm), cax=cax, boundaries=bounds, ticks=bounds, spacing='uniform', orientation='horizontal')

# Place each label at the centre of its bin (0.5, 1.5, ...)
tick_locations = np.arange(len(labels)) + 0.5
cbar.set_ticks(tick_locations)
cbar.set_ticklabels(labels)
cbar.ax.tick_params(labelsize=6)

# Add Titles
plt.suptitle(inTitle,size=15,color='white',y=1.05)
# NOTE: 'axes fraction' is relative to the current pyplot axes (presumably the
# colorbar axes, the last one added), which explains the large negative y offset.
plt.annotate('Source: FEMA\nMade By Jared Rennie (@jjrennie)',xy=(1.045, -3.51), xycoords='axes fraction', fontsize=7,backgroundcolor='black',color='white',horizontalalignment='right', verticalalignment='bottom')

# Save Figure (create the output folder first if it is missing)
os.makedirs(plot_directory, exist_ok=True)
plt.savefig(plot_directory+"/femaRisk_FULL.png",bbox_inches='tight')
plt.clf()
plt.close()

####################
# DONE
####################
print("DONE!")
end = time.time()
elapsed = end - start  # seconds since `start` was recorded at BEGIN PROGRAM
print(f"Runtime: {elapsed:8.1f} seconds")
sys.exit()