Advertisement
kopyl

Untitled

Jul 6th, 2023
809
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.33 KB | None | 0 0
  1. import time
  2. import cv2
  3. import numpy as np
  4.  
  5. import numpy
  6. from deepsparse import Engine
  7.  
  8. from yolov8.utils import xywh2xyxy, nms, draw_detections
  9.  
  10.  
  11. class YOLOv8_DeepSparse:
  12.  
  13.     def __init__(self, path, conf_thres=0.7, iou_thres=0.5):
  14.         self.conf_threshold = conf_thres
  15.         self.iou_threshold = iou_thres
  16.  
  17.         # Initialize model
  18.         self.initialize_model(path)
  19.  
  20.     def __call__(self, image):
  21.         return self.detect_objects(image)
  22.  
  23.     def initialize_model(self, path):
  24.         self.session = Engine(path)
  25.         # Get model info
  26.         # self.get_input_details()
  27.         # self.get_output_details()
  28.  
  29.  
  30.     def detect_objects(self, image):
  31.         input_tensor = self.prepare_input(image)
  32.  
  33.         # Perform inference on the image
  34.         outputs = self.inference(input_tensor)
  35.  
  36.         self.boxes, self.scores, self.class_ids = self.process_output(outputs)
  37.  
  38.         return self.boxes, self.scores, self.class_ids
  39.  
  40.     def prepare_input(self, image):
  41.         self.img_height, self.img_width = image.shape[:2]
  42.  
  43.         input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  44.  
  45.         # Resize input image
  46.         input_img = cv2.resize(input_img, (800, 800))
  47.  
  48.         # Scale input pixel values to 0 to 1
  49.         input_img = input_img / 255.0
  50.         input_img = input_img.transpose(2, 0, 1)
  51.         input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
  52.  
  53.         return input_tensor
  54.  
  55.  
  56.     def inference(self, input_tensor):
  57.         start = time.perf_counter()
  58.         outputs = numpy.ascontiguousarray(input_tensor)
  59.         outputs = self.session([outputs])
  60.         return outputs
  61.  
  62.     def process_output(self, output):
  63.         predictions = np.squeeze(output[0]).T
  64.  
  65.         # Filter out object confidence scores below threshold
  66.         scores = np.max(predictions[:, 4:], axis=1)
  67.         predictions = predictions[scores > self.conf_threshold, :]
  68.         scores = scores[scores > self.conf_threshold]
  69.  
  70.         if len(scores) == 0:
  71.             return [], [], []
  72.  
  73.         # Get the class with the highest confidence
  74.         class_ids = np.argmax(predictions[:, 4:], axis=1)
  75.  
  76.         # Get bounding boxes for each object
  77.         boxes = self.extract_boxes(predictions)
  78.  
  79.         # Apply non-maxima suppression to suppress weak, overlapping bounding boxes
  80.         indices = nms(boxes, scores, self.iou_threshold)
  81.  
  82.         return boxes[indices], scores[indices], class_ids[indices]
  83.  
  84.     def extract_boxes(self, predictions):
  85.         # Extract boxes from predictions
  86.         boxes = predictions[:, :4]
  87.  
  88.         # Scale boxes to original image dimensions
  89.         boxes = self.rescale_boxes(boxes)
  90.  
  91.         # Convert boxes to xyxy format
  92.         boxes = xywh2xyxy(boxes)
  93.  
  94.         return boxes
  95.  
  96.     def rescale_boxes(self, boxes):
  97.  
  98.         # Rescale boxes to original image dimensions
  99.         input_shape = np.array([800, 800, 800, 800])
  100.         boxes = np.divide(boxes, input_shape, dtype=np.float32)
  101.         boxes *= np.array([self.img_width, self.img_height, self.img_width, self.img_height])
  102.         return boxes
  103.  
  104.     def draw_detections(self, image, draw_scores=True, mask_alpha=0.4):
  105.  
  106.         return draw_detections(image, self.boxes, self.scores,
  107.                                self.class_ids, mask_alpha)
  108.  
  109.     def get_input_details(self):
  110.         model_inputs = self.session.get_inputs()
  111.         self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
  112.  
  113.         self.input_shape = model_inputs[0].shape
  114.         self.input_height = self.input_shape[2]
  115.         self.input_width = self.input_shape[3]
  116.  
  117.     def get_output_details(self):
  118.         model_outputs = self.session.get_outputs()
  119.         self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
  120.  
  121.  
  122. if __name__ == '__main__':
  123.     from imread_from_url import imread_from_url
  124.  
  125.     model_path = "../models/yolov8m.onnx"
  126.  
  127.     # Initialize YOLOv7 object detector
  128.     yolov7_detector = YOLOv8(model_path, conf_thres=0.3, iou_thres=0.5)
  129.  
  130.     img_url = "https://live.staticflickr.com/13/19041780_d6fd803de0_3k.jpg"
  131.     img = imread_from_url(img_url)
  132.  
  133.     # Detect Objects
  134.     yolov7_detector(img)
  135.  
  136.     # Draw detections
  137.     combined_img = yolov7_detector.draw_detections(img)
  138.     cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
  139.     cv2.imshow("Output", combined_img)
  140.     cv2.waitKey(0)
  141.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement