Advertisement
makispaiktis

ML - Lab 3 - Gaussian Naive Bayes with FPR, TPR, AUC

Oct 19th, 2022 (edited)
787
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.70 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from sklearn import datasets
  4. from sklearn.naive_bayes import GaussianNB
  5. from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
  6. from sklearn.metrics import roc_curve, auc
  7.  
  8. # Load an iris dataset
  9. iris = datasets.load_iris()
  10. data = iris.data[:, [0, 1]]
  11. target = iris.target
  12. target[100:125] = 1
  13. target[125:150] = 0
  14.  
  15. # Split into training and testing samples
  16. xtrain = np.concatenate((data[0:40], data[50:90], data[100:140]))
  17. ytrain = np.concatenate((target[0:40], target[50:90], target[100:140]))
  18. xtest = np.concatenate((data[40:50], data[90:100], data[140:150]))
  19. ytest = np.concatenate((target[40:50], target[90:100], target[140:150]))
  20.  
  21. # Gaussian Naive Bayes Classifier
  22. clf = GaussianNB()
  23. clf.fit(xtrain, ytrain)
  24. pred = clf.predict(xtest)
  25. pred_proba = clf.predict_proba(xtest)
  26. print("ytest = " + str(ytest))
  27. print("pred  = " + str(pred))
  28. print()
  29. print("Confusion Matrix = ")
  30. print(confusion_matrix(ytest, pred))
  31. print()
  32. print("Accuracy: ", accuracy_score(ytest, pred))
  33. print("Precision: ", precision_score(ytest, pred, pos_label=1))
  34. print("Recall: ", recall_score(ytest, pred, pos_label=1))
  35. print("F1 Score: ", f1_score(ytest, pred, pos_label=1))
  36. print()
  37. print()
  38.  
  39.  
  40. # ROC Curve
  41. fpr, tpr, thresholds = roc_curve(ytest, pred_proba[:, 1])
  42. # Area under curve
  43. AUC = auc(fpr, tpr)
  44. print("AUC = " + str(AUC))
  45. # Plots
  46. plt.title('Receiver Operating Characteristic')
  47. plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % auc(fpr, tpr))
  48. plt.legend(loc = 'lower right')
  49. plt.plot([0, 1], [0, 1],'r--')
  50. plt.xlim([0, 1])
  51. plt.ylim([0, 1])
  52. plt.xlabel('False Positive Rate')
  53. plt.ylabel('True Positive Rate')
  54. plt.show()
  55.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement