Commit 514d6d50 authored by Leonie Pick's avatar Leonie Pick

Changes to Plots.Curves()

parent b85ef0f6
This diff is collapsed.
......@@ -21,7 +21,7 @@ def LowpassFilter(data, cutoff, fs, order=1):
return y
###
###
def Search_TargetEvents(HMC, HMC11y, HMC5d, dHMC, HTime, DTime, grid, Save):
def Search_TargetEvents(HMC, HMC11y, HMC5d, dHMC, HTime, DTime, grid, Plot, Save):
###### STEP 1 ## Threshold scaled in dependence on solar cycle phase
Index = HMC
......@@ -129,7 +129,8 @@ def Search_TargetEvents(HMC, HMC11y, HMC5d, dHMC, HTime, DTime, grid, Save):
if HMC_diff <= -thresHMC:
IndexMin2.append(i)
pl.Selection(HTime,DTime,HMC,Index_thres1,StormIndices,IndexMin,IndexMin1,IndexMin2,Save)
if Plot == True:
pl.Selection(HTime,DTime,HMC,Index_thres1,StormIndices,IndexMin,IndexMin1,IndexMin2,Save)
return IndexMin2
###
......@@ -168,7 +169,7 @@ def Get_TargetEvents(HMC, HMC11y, HMC5d, dHMC, Kp_all, KpHours_all, Training, Ti
for i in range(len(grid)):
Storms = Search_TargetEvents(HMC[YearsIndex], HMC11y[YearsIndex], HMC5d[YearsIndex], dHMC[YearsIndex],
Time[YearsIndex,:], Date[YearsIndex], grid[i,:], Plot)
Time[YearsIndex,:], Date[YearsIndex], grid[i,:], Plot, Save)
#Found_CIRs = np.where(np.in1d(TrTimeIndex[np.logical_or(TrClass==0,TrClass==2)],Storms+YearsIndex[0]))[0]
#Found_CMEs = np.where(np.in1d(TrTimeIndex[TrClass==1],Storms+YearsIndex[0]))[0]
......@@ -210,7 +211,8 @@ def Get_TargetEvents(HMC, HMC11y, HMC5d, dHMC, Kp_all, KpHours_all, Training, Ti
display(pd.DataFrame(data=TrainingResults,columns=['No. CIRs','No. CMEs', 'Total'], index=['Training set', 'in Target set']))
###
pl.IndexDist(Time,YearsIndex,StormsWin[SelectStr],Kp_all,KpHours_all,HMC,Plot)
if Plot == True:
pl.IndexDist(Time,YearsIndex,StormsWin[SelectStr],Kp_all,KpHours_all,HMC,Save)
return StormsWin, TrFound, FoundStorms
###
......
......@@ -423,8 +423,8 @@ def Curves(N2,K2,Curves,curve_i,Model_Mean,Model_Std,C,Save,SaveName):
ROC_inner = np.nanmean(tpr,axis=1); ROC_outer = np.nanmean(ROC_inner,axis=0)
PR_inner = np.nanmean(precision,axis=1); PR_outer = np.nanmean(PR_inner,axis=0)
axs[0].plot(curve_i,ROC_outer,color='maroon',label=r'AUC = '+str(np.around(Model_Mean[1],3))+'$\pm$'+str(np.around(Model_Std[1],5)),zorder=2)
axs[1].plot(curve_i,PR_outer,color='maroon',label=r'AUC = '+str(np.around(Model_Mean[0],3))+'$\pm$'+str(np.around(Model_Std[0],5)),zorder=2)
axs[0].plot(curve_i,ROC_outer,color='maroon',label='Total mean',zorder=2)
axs[1].plot(curve_i,PR_outer,color='maroon',label='Total mean',zorder=2)
P = C[1,0]; PP = C[1,1]; N = C[0,0]; PN = C[0,1]; POP = sum(C[:,0])
#P,PP,N,PN = C
......@@ -436,18 +436,20 @@ def Curves(N2,K2,Curves,curve_i,Model_Mean,Model_Std,C,Save,SaveName):
axs[0].set_ylabel(r'TPR = TP/P', fontsize=18); axs[1].set_ylabel(r'PPV = TP/PP', fontsize=18)
axs[0].legend(loc=0, frameon=False, fontsize=16); axs[1].legend(loc=0, frameon=False, fontsize=16)
axs[0].set_title('ROC curve',fontsize=18)
axs[0].text(0.275,0.5,r'AUC = '+str(np.around(Model_Mean[1],3))+'$\pm$'+str(np.around(Model_Std[1],5)),bbox=dict(boxstyle='square',ec=(0.,0.,0.),fc=(1.,1.,1.)),fontsize=16,transform=axs[0].transAxes)
axs[0].tick_params(axis ='x',which='both',direction='inout',labelsize=16)
axs[0].tick_params(axis ='y',which='both',direction='inout',labelsize=16)
axs[1].set_title('Precision-Recall curve',fontsize=18)
axs[1].text(0.275,0.5,r'AUC = '+str(np.around(Model_Mean[0],3))+'$\pm$'+str(np.around(Model_Std[0],5)),bbox=dict(boxstyle='square',ec=(0.,0.,0.),fc=(1.,1.,1.)),fontsize=16,transform=axs[1].transAxes)
axs[1].yaxis.set_label_position('right'); axs[1].yaxis.set_ticks_position('right')
axs[1].tick_params(axis ='x',which='both',direction='inout',labelsize=16)
axs[1].tick_params(axis ='y',which='both',direction='inout',labelsize=16)
axs[0].set_xlim([0,1]);axs[0].set_ylim([0,1.03]);axs[1].set_xlim([0,1]);axs[1].set_ylim([0,1.03])
if Save == True:
fig.savefig('./Dump/Fig/development/Curves_'+SaveName+'.pdf',format='pdf',dpi=300,transparent=True)
#fig.savefig('./Dump/Fig/Curves.png',format='png',dpi=300,transparent=True)
#if Save == True:
fig.savefig('./Dump/Fig/development/Curves_'+SaveName+'.pdf',format='pdf',dpi=300,transparent=True)
fig.savefig('./Dump/Fig/development/Curves_'+SaveName+'.png',format='png',dpi=300,transparent=True)
plt.show()
###
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment