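# --- Assumed setup (a sketch, not from the original paste) -----------------
# The code below relies on `resnet_model` and `validation_data` already
# existing. This snippet shows one way they might be created with tf.keras;
# the directory path, image size, and model filename are hypothetical
# placeholders. shuffle=False keeps the dataset order fixed so predictions
# and labels align later on.
import tensorflow as tf

validation_data = tf.keras.utils.image_dataset_from_directory(
    "data/validation",      # hypothetical folder with '1L' and '2L' subdirectories
    image_size=(224, 224),  # hypothetical size; must match the model's input
    batch_size=32,
    shuffle=False,
)
resnet_model = tf.keras.models.load_model("resnet_model.keras")  # hypothetical file
# ---------------------------------------------------------------------------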
from sklearn.metrics import (classification_report, precision_score, recall_score,
                             f1_score, confusion_matrix, average_precision_score,
                             precision_recall_curve)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Collect model predictions and true labels as numpy arrays.
# NOTE: validation_data must iterate in a fixed order (e.g. built with
# shuffle=False), otherwise predictions and labels will not line up.
y_pred_probs = resnet_model.predict(validation_data)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.concatenate([y for x, y in validation_data], axis=0)
# The two classes, in label order: 0 -> '1L', 1 -> '2L'
target_names = ['1L', '2L']

# Per-class precision, recall, and F1 (f1_score handles the
# zero-division case that the manual 2PR/(P+R) formula does not)
precision = precision_score(y_true, y_pred, average=None)
recall = recall_score(y_true, y_pred, average=None)
f1_scores = f1_score(y_true, y_pred, average=None)
# Average precision (area under the PR curve) and mean recall over the
# curve's operating points, from the positive-class probabilities.
# Note: these are binary-classification PR summaries, not the box-overlap
# AP50/AR50 used in object detection, despite the variable names.
precision_points, recall_points, _ = precision_recall_curve(y_true, y_pred_probs[:, 1])
ap50 = average_precision_score(y_true, y_pred_probs[:, 1])
ar50 = np.mean(recall_points)
# IoU of the positive class (Jaccard index on the label vectors)
def calculate_iou(y_true, y_pred):
    intersection = np.sum((y_true == y_pred) & (y_true == 1))
    union = np.sum(y_true | y_pred)
    return intersection / union if union > 0 else 0

iou = calculate_iou(y_true, y_pred)
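# Toy sanity check for calculate_iou (added example, not in the original):
# with one shared positive, intersection is 1 and union is 3, so IoU = 1/3.
assert np.isclose(calculate_iou(np.array([1, 1, 0, 0]), np.array([1, 0, 1, 0])), 1 / 3)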
# Build a classification_report-style table, with extra rows for the
# PR-curve and IoU summaries (the computed iou now fills its row,
# which was left as None in the original)
metrics_report = pd.DataFrame({
    "precision": list(precision) + [ap50, ar50, iou],
    "recall": list(recall) + [None, None, None],
    "f1-score": list(f1_scores) + [None, None, None],
    "support": [np.sum(y_true == 0), np.sum(y_true == 1), None, None, None]
}, index=["1L", "2L", "AP50", "AR50", "IoU"])
# Append overall accuracy, macro avg, and weighted avg rows (accuracy sits
# in the f1-score column, as in sklearn's classification_report)
accuracy = (y_pred == y_true).mean()
metrics_report.loc["accuracy"] = [None, None, accuracy, len(y_true)]
metrics_report.loc["macro avg"] = [precision.mean(), recall.mean(), f1_scores.mean(), len(y_true)]
metrics_report.loc["weighted avg"] = [precision_score(y_true, y_pred, average='weighted'),
                                      recall_score(y_true, y_pred, average='weighted'),
                                      f1_score(y_true, y_pred, average='weighted'),
                                      len(y_true)]
# Display the table
print("\nCustom Metrics Report:")
print(metrics_report)
# Plot the confusion matrix as a labeled heatmap
conf_matrix = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
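# Optional cross-check (added, a sketch): sklearn's built-in report should
# match the per-class precision/recall/f1 rows in the custom table above.
print("\nsklearn classification_report:")
print(classification_report(y_true, y_pred, target_names=target_names))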