|
@@ -9,7 +9,7 @@ import pandas as pd
|
|
|
import matplotlib.pyplot as plt
|
|
|
import time
|
|
|
import numpy as np
|
|
|
-from sklearn.metrics import roc_curve, auc, confusion_matrix
|
|
|
+from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
|
|
|
import seaborn as sns
|
|
|
|
|
|
class CNN_Net(nn.Module):
|
|
@@ -182,6 +182,10 @@ class CNN_Net(nn.Module):
|
|
|
plt.savefig('./confusion_matrix.png')
|
|
|
plt.show()
|
|
|
|
|
|
+ # Classification Report
|
|
|
+ report = classification_report(true_labels, predictionsLabels)
|
|
|
+ print(report)
|
|
|
+
|
|
|
self.train()
|
|
|
|
|
|
return(loss.item())
|