Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make evenly spaced ticks when threshold_values are not evenly spaced in CSI plots #302

Merged
merged 10 commits into from
Nov 21, 2024
4 changes: 3 additions & 1 deletion melodies_monet/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,7 @@ def plotting(self):
threshold_list = grp_dict['threshold_list']
score_name = grp_dict['score_name']
model_name_list = grp_dict['model_name_list']
xtick_style = grp_dict.get('xtick_style',None)

# first get the observational obs labels
pair1 = self.paired[list(self.paired.keys())[0]]
Expand Down Expand Up @@ -2392,7 +2393,8 @@ def plotting(self):
text_dict=text_dict,
domain_type=domain_type,
domain_name=domain_name,
model_name_list=model_name_list)
model_name_list=model_name_list,
xtick_style=xtick_style)
#save figure
plt.tight_layout()
savefig(outname +'.'+score_name+'.png', loc=1, logo_height=100)
Expand Down
17 changes: 13 additions & 4 deletions melodies_monet/plots/surfplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ def Calc_Score(score_name_input,threshold_input, model_input, obs_input):

return output_score

def Plot_CSI(score_name_input,threshold_list_input, comb_bx_input,plot_dict,fig_dict,text_dict,domain_type,domain_name,model_name_list):
def Plot_CSI(score_name_input,threshold_list_input, comb_bx_input,plot_dict,fig_dict,text_dict,domain_type,domain_name,model_name_list,xtick_style):

CSI_output = [] #(2, threshold len)
threshold_list = threshold_list_input
Expand Down Expand Up @@ -1566,7 +1566,10 @@ def Plot_CSI(score_name_input,threshold_list_input, comb_bx_input,plot_dict,fig_

#Make Plot
for i in range(len(CSI_output)):
plt.plot(threshold_list,CSI_output[i],'-*',label=model_name_list[i]) #CHANGE THIS ONE, MAIN PROGRAM
if xtick_style == 'equal':
plt.plot(range(len(threshold_list)),CSI_output[i],'-*',label=model_name_list[i])
else:
plt.plot(threshold_list,CSI_output[i],'-*',label=model_name_list[i])
ax.set_xlabel('Threshold',fontsize = text_kwargs['fontsize']*0.8)
ax.set_ylabel(score_name_input,fontsize = text_kwargs['fontsize']*0.8)
ax.tick_params(labelsize=text_kwargs['fontsize']*0.8)
Expand All @@ -1575,8 +1578,14 @@ def Plot_CSI(score_name_input,threshold_list_input, comb_bx_input,plot_dict,fig_
plt.grid()

#add '>' to xticks
labels = ['>'+item.get_text() for item in ax.get_xticklabels()]
ax.set_xticklabels(labels)
if xtick_style == 'equal':
threshold_string_array = [str(x) for x in threshold_list]
labels = ['>'+item for item in threshold_string_array]
ax.set_xticks(range(len(threshold_list)),labels=labels)
else:
labels = ['>'+item.get_text() for item in ax.get_xticklabels()]
ax.set_xticklabels(labels)

if domain_type is not None and domain_name is not None:
if domain_type == 'epa_region':
ax.set_title('EPA Region ' + domain_name,fontweight='bold',**text_kwargs)
Expand Down