Prottoy789

precision, recall, confusion matrix code

Dec 18th, 2024
Python 2.39 KB | Source Code
from sklearn.metrics import (precision_score, recall_score, f1_score, confusion_matrix,
                             average_precision_score, precision_recall_curve)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Get predicted probabilities and hard predictions from the trained model.
# Assumes `validation_data` is a batched, non-shuffled dataset of
# (images, integer labels); otherwise y_true will not line up with y_pred.
y_pred_probs = resnet_model.predict(validation_data)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.concatenate([y for x, y in validation_data], axis=0)

# Class names for the two classes: '1L' (label 0) and '2L' (label 1)
target_names = ['1L', '2L']

# Per-class precision, recall and F1 (guard the F1 formula against zero division)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1_scores = np.divide(2 * precision * recall, precision + recall,
                      out=np.zeros_like(precision), where=(precision + recall) > 0)

# "AP50" here is the average precision of the positive class ('2L') over the
# precision-recall curve; "AR50" is the mean recall over the same curve points
precision_points, recall_points, _ = precision_recall_curve(y_true, y_pred_probs[:, 1])
ap50 = average_precision_score(y_true, y_pred_probs[:, 1])
ar50 = np.mean(recall_points)

# IoU (Jaccard index) for the positive class: TP / (TP + FP + FN)
def calculate_iou(y_true, y_pred):
    intersection = np.sum((y_true == 1) & (y_pred == 1))
    union = np.sum((y_true == 1) | (y_pred == 1))
    return intersection / union if union > 0 else 0.0

iou = calculate_iou(y_true, y_pred)

# Custom table in classification_report style, with AP50/AR50/IoU appended
metrics_report = pd.DataFrame({
    "precision": list(precision) + [ap50, ar50, iou],
    "recall": list(recall) + [None, None, None],
    "f1-score": list(f1_scores) + [None, None, None],
    "support": [int(np.sum(y_true == 0)), int(np.sum(y_true == 1)), None, None, None]
}, index=["1L", "2L", "AP50", "AR50", "IoU"])

# Add overall accuracy, macro avg, and weighted avg rows
# (accuracy goes in the f1-score column with support = n samples, as in sklearn's report)
accuracy = (y_pred == y_true).mean()
metrics_report.loc["accuracy"] = [None, None, accuracy, len(y_true)]
metrics_report.loc["macro avg"] = [precision.mean(), recall.mean(), f1_scores.mean(), len(y_true)]
metrics_report.loc["weighted avg"] = [precision_score(y_true, y_pred, average='weighted'),
                                      recall_score(y_true, y_pred, average='weighted'),
                                      f1_score(y_true, y_pred, average='weighted'),
                                      len(y_true)]

# Display the table
print("\nCustom Metrics Report:")
print(metrics_report)

# Plot the confusion matrix as a heatmap
conf_matrix = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

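Not part of the original paste: a minimal, self-contained sanity check on synthetic labels (assumed encoding 0 = '1L', 1 = '2L'), showing that the hand-rolled per-class metrics and calculate_iou() above agree with scikit-learn's built-in classification_report, f1_score and jaccard_score.

# Standalone sanity check with synthetic labels (hypothetical data, not from the model above)
import numpy as np
from sklearn.metrics import classification_report, f1_score, jaccard_score

rng = np.random.default_rng(0)
y_true_demo = rng.integers(0, 2, size=200)                     # 0 = '1L', 1 = '2L' (assumed encoding)
flip = rng.random(200) < 0.2                                   # flip ~20% of labels to simulate errors
y_pred_demo = np.where(flip, 1 - y_true_demo, y_true_demo)

# Built-in report to compare against the custom metrics_report table
print(classification_report(y_true_demo, y_pred_demo, target_names=['1L', '2L']))

# jaccard_score with pos_label=1 computes TP / (TP + FP + FN),
# i.e. the same quantity as calculate_iou() above
print("IoU ('2L'):", jaccard_score(y_true_demo, y_pred_demo, pos_label=1))
print("Per-class F1:", f1_score(y_true_demo, y_pred_demo, average=None))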