Advertisement
slik1977

MLops_ONNX

Nov 21st, 2024
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 15.40 KB | None | 0 0
  1. import os
  2. import glob
  3. import cv2
  4. import time
  5. import numpy as np
  6. from ultralytics import YOLO
  7.  
  8. # Установка устройства для инференса ('cpu' или 'cuda')
  9. DEVICE = 'cuda'  # Измените на 'cpu', если хотите использовать CPU
  10.  
  11. DATA_YAML_PATH = 'C:/Users/edimv/Desktop/stenosis/data.yaml'
  12. model_path = 'best.onnx'
  13.  
  14. # Загрузка модели ONNX
  15. model = YOLO(model_path, task="detect")
  16. # Не используем model.to(DEVICE), так как для ONNX-моделей это не поддерживается
  17.  
  18. def load_annotations(annotation_path, img_width, img_height):
  19.     """
  20.    Загрузка аннотаций из TXT-файла в формате YOLO и преобразование в абсолютные пиксельные координаты.
  21.    """
  22.     annotations = []
  23.     with open(annotation_path, 'r') as f:
  24.         for line in f:
  25.             parts = line.strip().split()
  26.             label = int(parts[0])
  27.             x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts[1:])
  28.             # Преобразование нормализованных координат в абсолютные пиксельные координаты
  29.             x_center = x_center_norm * img_width
  30.             y_center = y_center_norm * img_height
  31.             width = width_norm * img_width
  32.             height = height_norm * img_height
  33.             annotations.append([label, x_center, y_center, width, height])
  34.     return annotations
  35.  
  36. def calculate_iou(box1, box2):
  37.     """
  38.    Вычисление IoU между двумя боксами в формате (x_center, y_center, width, height).
  39.    """
  40.     # Преобразование боксов в (x1, y1, x2, y2)
  41.     box1_x1 = box1[0] - box1[2] / 2
  42.     box1_y1 = box1[1] - box1[3] / 2
  43.     box1_x2 = box1[0] + box1[2] / 2
  44.     box1_y2 = box1[1] + box1[3] / 2
  45.  
  46.     box2_x1 = box2[0] - box2[2] / 2
  47.     box2_y1 = box2[1] - box2[3] / 2
  48.     box2_x2 = box2[0] + box2[2] / 2
  49.     box2_y2 = box2[1] + box2[3] / 2
  50.  
  51.     # Вычисление координат пересечения
  52.     inter_x1 = max(box1_x1, box2_x1)
  53.     inter_y1 = max(box1_y1, box2_y1)
  54.     inter_x2 = min(box1_x2, box2_x2)
  55.     inter_y2 = min(box1_y2, box2_y2)
  56.  
  57.     # Вычисление площади пересечения
  58.     inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)
  59.  
  60.     # Вычисление площадей боксов
  61.     box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1)
  62.     box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1)
  63.  
  64.     # Вычисление IoU
  65.     union_area = box1_area + box2_area - inter_area
  66.     if union_area == 0:
  67.         return 0  # Избегаем деления на ноль
  68.     iou = inter_area / union_area
  69.     return iou
  70.  
  71. def draw_boxes(img, pred_boxes, gt_boxes):
  72.     """
  73.    Отрисовка предсказанных и истинных боксов на изображении.
  74.    Предсказанные боксы — зелёные, истинные боксы — красные.
  75.    """
  76.     # Отрисовка предсказанных боксов
  77.     for pred in pred_boxes:
  78.         box = pred['box']
  79.         label = pred['label']
  80.         confidence = pred['conf']
  81.         # Преобразование из [x_center, y_center, width, height] в [x1, y1, x2, y2]
  82.         x_center, y_center, width, height = box
  83.         x1 = int(x_center - width / 2)
  84.         y1 = int(y_center - height / 2)
  85.         x2 = int(x_center + width / 2)
  86.         y2 = int(y_center + height / 2)
  87.         # Отрисовка прямоугольника
  88.         cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)  # Зелёный цвет для предсказаний
  89.         # Добавление метки и уверенности
  90.         cv2.putText(img, f'Pred {label}:{confidence:.2f}', (x1, y1 - 10),
  91.                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
  92.  
  93.     # Отрисовка истинных боксов
  94.     for gt in gt_boxes:
  95.         label = gt[0]
  96.         box = gt[1:]  # [x_center, y_center, width, height]
  97.         x_center, y_center, width, height = box
  98.         x1 = int(x_center - width / 2)
  99.         y1 = int(y_center - height / 2)
  100.         x2 = int(x_center + width / 2)
  101.         y2 = int(y_center + height / 2)
  102.         # Отрисовка прямоугольника
  103.         cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)  # Красный цвет для истинных боксов
  104.         # Добавление метки
  105.         cv2.putText(img, f'GT {label}', (x1, y1 - 25),
  106.                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
  107.     return img
  108.  
  109. def test_iou(test_folder):
  110.     """
  111.    Тестирование IoU на наборе изображений и аннотаций с использованием предсказаний модели YOLO.
  112.    Также вычисляет средние времена, метрики и сохраняет примеры.
  113.    """
  114.     iou_results = []
  115.     preprocessing_times = []
  116.     processing_times = []
  117.     postprocessing_times = []
  118.     per_image_data = []  # Для хранения данных по каждому изображению
  119.     low_iou_count = 0
  120.     total_TP = 0
  121.     total_FP = 0
  122.     total_FN = 0
  123.  
  124.     image_paths = glob.glob(os.path.join(test_folder, "*.bmp"))
  125.     if not os.path.exists('samples1'):
  126.         os.makedirs('samples1')
  127.  
  128.     for img_path in image_paths:
  129.         # Предобработка
  130.         t1 = time.time()
  131.         annotation_path = img_path.replace(".bmp", ".txt")
  132.         img_filename = os.path.basename(img_path)
  133.        
  134.         if not os.path.exists(annotation_path):
  135.             print(f"Annotation missing for {img_path}")
  136.             continue
  137.  
  138.         img = cv2.imread(img_path)
  139.         if img is None:
  140.             print(f"Failed to read image {img_path}")
  141.             continue
  142.  
  143.         # Сохранение оригинальных размеров изображения (если потребуется)
  144.         original_img_height, original_img_width = img.shape[:2]
  145.  
  146.         # Изменение размера изображения до 800x800
  147.         img_resized = img
  148.         img_height, img_width = img_resized.shape[:2]
  149.  
  150.         # Преобразование изображения в RGB
  151.         img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
  152.  
  153.         # Загрузка аннотаций с использованием новых размеров изображения
  154.         gt_boxes = load_annotations(annotation_path, img_width, img_height)
  155.         t2 = time.time()
  156.         preprocessing_times.append(t2 - t1)
  157.  
  158.         # Инференс
  159.         t3 = time.time()
  160.         results = model.predict(img_rgb, imgsz = 800, device=DEVICE)
  161.         t4 = time.time()
  162.         processing_times.append(t4 - t3)
  163.  
  164.         # Постобработка
  165.         t5 = time.time()
  166.         if results[0].boxes is None or len(results[0].boxes) == 0:
  167.             print(f"No predictions for {img_path}")
  168.             iou_results.append(0)
  169.             per_image_data.append({
  170.                 'img_path': img_path,
  171.                 'max_iou': 0,
  172.                 'ious': [0],
  173.                 'TP': 0,
  174.                 'FP': 0,
  175.                 'FN': len(gt_boxes),
  176.                 'pred_boxes': [],
  177.                 'gt_boxes': gt_boxes
  178.             })
  179.             total_FN += len(gt_boxes)
  180.             if 0 < 0.3:
  181.                 low_iou_count += 1
  182.             t6 = time.time()
  183.             postprocessing_times.append(t6 - t5)
  184.             continue
  185.  
  186.         # Получение предсказанных боксов и вероятностей
  187.         predictions = results[0].boxes.xywh.cpu().numpy()  # [x_center, y_center, width, height]
  188.         confidences = results[0].boxes.conf.cpu().numpy()  # Уверенности
  189.         labels = results[0].boxes.cls.cpu().numpy()  # Классы
  190.  
  191.         pred_boxes = []
  192.         for i in range(len(predictions)):
  193.             pred_box = predictions[i]
  194.             confidence = confidences[i]
  195.             label = int(labels[i])
  196.             pred_boxes.append({'box': pred_box, 'conf': confidence, 'label': label})
  197.  
  198.         # Сортировка предсказаний по уверенности
  199.         pred_boxes.sort(key=lambda x: x['conf'], reverse=True)
  200.  
  201.         matched_gt = []
  202.         ious = []
  203.         TP = 0
  204.         FP = 0
  205.         FN = 0
  206.         for pred in pred_boxes:
  207.             pred_box = pred['box']
  208.             pred_conf = pred['conf']
  209.             pred_label = pred['label']
  210.             best_iou = 0
  211.             best_gt_idx = -1
  212.             for idx, gt_box in enumerate(gt_boxes):
  213.                 if idx in matched_gt:
  214.                     continue  # Уже сопоставлено
  215.                 gt_label = gt_box[0]
  216.                 gt_box_coords = gt_box[1:]
  217.                 iou = calculate_iou(pred_box, gt_box_coords)
  218.                 if iou > best_iou:
  219.                     best_iou = iou
  220.                     best_gt_idx = idx
  221.             if best_iou >= 0.5:
  222.                 TP += 1
  223.                 total_TP += 1
  224.                 matched_gt.append(best_gt_idx)
  225.             else:
  226.                 FP += 1
  227.                 total_FP += 1
  228.             ious.append(best_iou)
  229.  
  230.         FN = len(gt_boxes) - len(matched_gt)
  231.         total_FN += FN
  232.  
  233.         max_iou = max(ious) if ious else 0
  234.         if max_iou < 0.3:
  235.             low_iou_count += 1
  236.  
  237.         iou_results.append(max_iou)
  238.         per_image_data.append({
  239.             'img_path': img_path,
  240.             'max_iou': max_iou,
  241.             'ious': ious,
  242.             'TP': TP,
  243.             'FP': FP,
  244.             'FN': FN,
  245.             'pred_boxes': pred_boxes,
  246.             'gt_boxes': gt_boxes
  247.         })
  248.         t6 = time.time()
  249.         postprocessing_times.append(t6 - t5)
  250.  
  251.     # Вычисление средних времен
  252.     avg_preprocessing_time = sum(preprocessing_times) / len(preprocessing_times) if preprocessing_times else 0
  253.     avg_processing_time = sum(processing_times) / len(processing_times) if processing_times else 0
  254.     avg_postprocessing_time = sum(postprocessing_times) / len(postprocessing_times) if postprocessing_times else 0
  255.  
  256.     # Вычисление метрик
  257.     precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0
  258.     recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0
  259.  
  260.     # Вычисление mAP
  261.     iou_thresholds = np.arange(0.5, 1.0, 0.05)
  262.     APs = []
  263.     for iou_thresh in iou_thresholds:
  264.         ap = calculate_ap(per_image_data, iou_thresh)
  265.         APs.append(ap)
  266.     mAP50 = APs[0]
  267.     mAP50_95 = np.mean(APs)
  268.  
  269.     # Сохранение лучших и худших примеров с отрисованными боксами
  270.     per_image_data.sort(key=lambda x: x['max_iou'], reverse=True)
  271.     best_samples = per_image_data[:3]
  272.     worst_samples = per_image_data[-3:]
  273.  
  274.     for sample in best_samples:
  275.         img = cv2.imread(sample['img_path'])
  276.         # Изменение размера изображения до 800x800
  277.         img_resized = img
  278.         pred_boxes = sample['pred_boxes']
  279.         gt_boxes = sample['gt_boxes']
  280.         img_with_boxes = draw_boxes(img_resized, pred_boxes, gt_boxes)
  281.         img_name = os.path.basename(sample['img_path'])
  282.         save_path = os.path.join('samples1', f'good_{img_name}')
  283.         cv2.imwrite(save_path, img_with_boxes)
  284.  
  285.     for sample in worst_samples:
  286.         img = cv2.imread(sample['img_path'])
  287.         # Изменение размера изображения до 800x800
  288.         img_resized = img
  289.         pred_boxes = sample['pred_boxes']
  290.         gt_boxes = sample['gt_boxes']
  291.         img_with_boxes = draw_boxes(img_resized, pred_boxes, gt_boxes)
  292.         img_name = os.path.basename(sample['img_path'])
  293.         save_path = os.path.join('samples1', f'bad_{img_name}')
  294.         cv2.imwrite(save_path, img_with_boxes)
  295.  
  296.     # Возврат результатов
  297.     results = {
  298.         'average_iou': sum(iou_results) / len(iou_results) if iou_results else 0,
  299.         'avg_preprocessing_time': avg_preprocessing_time,
  300.         'avg_processing_time': avg_processing_time,
  301.         'avg_postprocessing_time': avg_postprocessing_time,
  302.         'precision': precision,
  303.         'recall': recall,
  304.         'mAP50': mAP50,
  305.         'mAP50_95': mAP50_95,
  306.         'low_iou_count': low_iou_count
  307.     }
  308.     return results
  309.  
  310. def calculate_ap(per_image_data, iou_threshold=0.5):
  311.     """
  312.    Вычисление Average Precision (AP) при заданном пороге IoU.
  313.    """
  314.     # Собираем все предсказания и аннотации
  315.     detections = []
  316.     annotations = []
  317.  
  318.     for idx, data in enumerate(per_image_data):
  319.         preds = data['pred_boxes']
  320.         gts = data['gt_boxes']
  321.         # Добавляем предсказания
  322.         for pred in preds:
  323.             detections.append([idx, pred['label'], pred['conf'], *pred['box']])
  324.         # Добавляем аннотации
  325.         for gt in gts:
  326.             annotations.append([idx, gt[0], *gt[1:]])
  327.  
  328.     if len(annotations) == 0:
  329.         return 0
  330.  
  331.     # Сортировка предсказаний по уверенности
  332.     detections.sort(key=lambda x: x[2], reverse=True)
  333.  
  334.     TP = np.zeros(len(detections))
  335.     FP = np.zeros(len(detections))
  336.  
  337.     detected_annotations = []
  338.  
  339.     for d_idx, detection in enumerate(detections):
  340.         image_idx = detection[0]
  341.         detection_label = detection[1]
  342.         detection_conf = detection[2]
  343.         detection_box = detection[3:]
  344.  
  345.         gt_annotations = [ann for ann in annotations if ann[0] == image_idx and ann[1] == detection_label]
  346.         max_iou = 0
  347.         matched_ann_idx = -1
  348.         for ann_idx, ann in enumerate(gt_annotations):
  349.             ann_box = ann[2:]
  350.             iou = calculate_iou(detection_box, ann_box)
  351.             if iou > max_iou:
  352.                 max_iou = iou
  353.                 matched_ann_idx = ann_idx
  354.         if max_iou >= iou_threshold and (image_idx, matched_ann_idx) not in detected_annotations:
  355.             TP[d_idx] = 1
  356.             detected_annotations.append((image_idx, matched_ann_idx))
  357.         else:
  358.             FP[d_idx] = 1
  359.  
  360.     cumulative_TP = np.cumsum(TP)
  361.     cumulative_FP = np.cumsum(FP)
  362.     recalls = cumulative_TP / len(annotations)
  363.     precisions = cumulative_TP / (cumulative_TP + cumulative_FP)
  364.  
  365.     # Избегаем деления на ноль
  366.     recalls = np.concatenate(([0], recalls))
  367.     precisions = np.concatenate(([1], precisions))
  368.  
  369.     # Вычисление AP
  370.     AP = 0
  371.     for i in range(1, len(recalls)):
  372.         AP += (recalls[i] - recalls[i - 1]) * precisions[i]
  373.     return AP
  374.  
  375. if __name__ == '__main__':
  376.     TEST_FOLDER = "test"
  377.     results = test_iou(TEST_FOLDER)
  378.     print("Average IoU on test set:", results['average_iou'])
  379.     print("Average preprocessing time per image:", results['avg_preprocessing_time'])
  380.     print("Average processing time per image:", results['avg_processing_time'])
  381.     print("Average post-processing time per image:", results['avg_postprocessing_time'])
  382.     print("Precision:", results['precision'])
  383.     print("Recall:", results['recall'])
  384.     print("mAP@0.5:", results['mAP50'])
  385.     print("mAP@0.5:0.95:", results['mAP50_95'])
  386.     print("Number of examples with IoU below 0.3:", results['low_iou_count'])
  387.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement