import os
import glob
import random
import re
import platform
import subprocess
import time
import logging
import math
from pathlib import Path
from threading import Thread

import requests
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision

img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes

def time_synchronized():
    # pytorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

LOGGER = logging.getLogger("Timer")
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

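# Illustrative sketch (not part of the original utils): a quick round-trip check of the two
# converters above on a dummy tensor, assuming boxes are stored row-wise as nx4.
def _example_box_conversion():
    boxes_xywh = torch.tensor([[50.0, 50.0, 20.0, 10.0]])        # one box: center (50, 50), 20x10
    boxes_xyxy = xywh2xyxy(boxes_xywh)                           # -> [[40., 45., 60., 55.]]
    assert torch.allclose(xyxy2xywh(boxes_xyxy), boxes_xywh)     # converting back recovers the input
    return boxes_xyxy
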
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords

def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2

def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

def apply_classifier(x, model, img, im0):
    # Applies a second-stage classifier to YOLO outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)

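# Illustrative sketch (not part of the original utils): letterbox a dummy BGR frame to the
# network input size, then map a box predicted on the padded image back to the original frame
# with scale_coords(). The shapes and the example box below are assumptions for this sketch only.
def _example_letterbox_roundtrip():
    frame = np.zeros((480, 640, 3), dtype=np.uint8)              # original HxWxC image
    padded, ratio, (dw, dh) = letterbox(frame, new_shape=640)    # resize + pad to a stride multiple
    pred_box = torch.tensor([[100.0, 80.0, 200.0, 160.0]])       # hypothetical xyxy box on `padded`
    scale_coords(padded.shape[:2], pred_box, frame.shape[:2])    # in-place rescale to `frame` coords
    return padded.shape, pred_box
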
def check_imshow():
    # Check if environment supports image displays
    try:
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False

def load_classifier(name='resnet101', n=2):
    # Loads a pretrained model reshaped to n-class output
    model = torchvision.models.__dict__[name](pretrained=True)

    # ResNet model properties
    # input_size = [3, 224, 224]
    # input_space = 'RGB'
    # input_range = [0, 1]
    # mean = [0.485, 0.456, 0.406]
    # std = [0.229, 0.224, 0.225]

    # Reshape output to n classes
    filters = model.fc.weight.shape[1]
    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
    model.fc.out_features = n
    return model

def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor

def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size

def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    s = f'torch {torch.__version__} '  # device summary string
    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)  # emoji-safe
    return torch.device('cuda:0' if cuda else 'cpu')

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path

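# Illustrative sketch (not part of the original utils): how increment_path() picks a fresh run
# directory name. The 'runs/exp' path here is only an example.
def _example_increment_path():
    run_dir = Path(increment_path('runs/exp', exist_ok=False))  # 'runs/exp' if free, else 'runs/exp2', 'runs/exp3', ...
    run_dir.mkdir(parents=True, exist_ok=True)                  # create it so the next call increments again
    return run_dir
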
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)

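# Illustrative sketch (not part of the original utils): box_iou() returns an NxM matrix of
# pairwise IoUs; the values below are easy to verify by hand.
def _example_box_iou():
    a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])                           # N = 1 box
    b = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])   # M = 2 boxes
    iou = box_iou(a, b)  # shape (1, 2): identical box -> 1.0, quarter overlap -> 25 / 175
    return iou
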
class BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input):
        # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
        # is this method, which each subclass overrides to run tensor sanity checks.
        # If you are OK bypassing those checks (e.g. you trust your inference code to
        # provide inputs of the right dimensionality), this no-op version allows an easy
        # conversion back from SyncBatchNorm.
        # (Unfortunately, SyncBatchNorm does not store the original class - if it did,
        # we could return the one that was originally created.)
        return

def revert_sync_batchnorm(module):
    # This is very similar to the function that it is trying to revert:
    # https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        module_output = BatchNormXd(module.num_features,
                                    module.eps, module.momentum,
                                    module.affine,
                                    module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output

def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)

class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. '

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            # print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3xHxW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files

class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            url = eval(s) if s.isnumeric() else s
            cap = cv2.VideoCapture(url)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps = cap.get(cv2.CAP_PROP_FPS) % 100

            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[index] = im if success else self.imgs[index] * 0
                n = 0
            time.sleep(1 / self.fps)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3xHxW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years

class TracedModel(nn.Module):

    def __init__(self, model=None, device=None, img_size=(640, 640)):
        super(TracedModel, self).__init__()

        print(" Convert model to Traced-model... ")
        self.stride = model.stride
        self.names = model.names
        self.model = model

        self.model = revert_sync_batchnorm(self.model)
        self.model.to('cpu')
        self.model.eval()

        self.detect_layer = self.model.model[-1]
        self.model.traced = True

        if isinstance(img_size, (list, tuple)):  # accept an (h, w) pair as well as a single int
            rand_example = torch.rand(1, 3, *img_size)
        else:
            rand_example = torch.rand(1, 3, img_size, img_size)

        traced_script_module = torch.jit.trace(self.model, rand_example, strict=False)
        # traced_script_module = torch.jit.script(self.model)
        traced_script_module.save("traced_model.pt")
        print(" traced_script_module saved! ")
        self.model = traced_script_module
        self.model.to(device)
        self.detect_layer.to(device)
        print(" model is traced! \n")

    def forward(self, x, augment=False, profile=False):
        out = self.model(x)
        out = self.detect_layer(out)
        return out

class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        y = []
        for module in self:
            y.append(module(x, augment)[0])
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        y = torch.cat(y, 1)  # nms ensemble
        return y, None  # inference, train output

class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))

def timeIt(func):
    # Decorator that logs how long the wrapped function took, in seconds
    def wrapper(*args, **kwargs):
        start_time = time.time() * 1000
        data = func(*args, **kwargs)
        end_time = time.time() * 1000
        time_diff = end_time - start_time
        time_in_seconds = time_diff / 1000
        message = func.__name__ + " took " + str(time_in_seconds) + "s"
        LOGGER.warning(message)
        return data
    return wrapper

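# Illustrative sketch (not part of the original utils): applying the timeIt decorator.
# The decorated function below is hypothetical and only demonstrates the pattern.
@timeIt
def _example_timed_work():
    time.sleep(0.1)  # stand-in for real work; the elapsed ~0.1 s is logged via LOGGER.warning
    return True
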
def attempt_download(file, repo='WongKinYiu/yolov7'):
    # Attempt file download if it does not exist
    file = Path(str(file).strip().replace("'", '').lower())

    if not file.exists():
        try:
            response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
            assets = [x['name'] for x in response['assets']]  # release assets
            tag = response['tag_name']  # i.e. 'v1.0'
        except:  # fallback plan
            assets = ['yolov7.pt', 'yolov7-tiny.pt', 'yolov7x.pt', 'yolov7-d6.pt', 'yolov7-e6.pt',
                      'yolov7-e6e.pt', 'yolov7-w6.pt']
            tag = subprocess.check_output('git tag', shell=True).decode().split()[-1]

        name = file.name
        if name in assets:
            msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
            redundant = False  # second download option
            try:  # GitHub
                url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
                print(f'Downloading {url} to {file}...')
                torch.hub.download_url_to_file(url, file)
                assert file.exists() and file.stat().st_size > 1E6  # check
            except Exception as e:  # GCP
                print(f'Download error: {e}')
                assert redundant, 'No secondary mirror'
                url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
                print(f'Downloading {url} to {file}...')
                os.system(f'curl -L {url} -o {file}')  # torch.hub.download_url_to_file(url, weights)
            finally:
                if not file.exists() or file.stat().st_size < 1E6:  # check
                    file.unlink(missing_ok=True)  # remove partial downloads
                    print(f'ERROR: Download failure: {msg}')
                print('')
                return

def attempt_load(weights, map_location=None):
    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
    model = Ensemble()
    for w in weights if isinstance(weights, list) else [weights]:
        attempt_download(w)
        ckpt = torch.load(w, map_location=map_location)  # load
        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

    # Compatibility updates
    for m in model.modules():
        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif type(m) is nn.Upsample:
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model
    else:
        print('Ensemble created with %s\n' % weights)
        for k in ['names', 'stride']:
            setattr(model, k, getattr(model[-1], k))
        return model  # return ensemble

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        if nc == 1:
            x[:, 5:] = x[:, 4:5]  # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
                                  # so there is no need to multiply
        else:
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
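
# Illustrative sketch (not part of the original utils): a minimal end-to-end pass that wires the
# helpers above together, assuming a YOLOv7-style checkpoint loadable with attempt_load().
# The 'yolov7.pt' and 'image.jpg' names are placeholders, not files shipped with this module.
def _example_detect(weights='yolov7.pt', source='image.jpg'):
    device = select_device('')                                    # pick CUDA if available, else CPU
    model = attempt_load(weights, map_location=device).eval()     # FP32 fused model (or ensemble)
    dataset = LoadImages(source, img_size=640, stride=int(model.stride.max()))
    results = []
    for path, img, im0, _ in dataset:
        img = torch.from_numpy(img).to(device).float() / 255.0    # uint8 CHW from LoadImages -> float in [0, 1]
        img = img.unsqueeze(0)                                    # add batch dimension -> 1x3xHxW
        with torch.no_grad():
            pred = model(img)[0]                                  # raw predictions, shape (1, n, 5 + nc)
        for det in non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45):
            if len(det):
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape)  # back to original frame
                for *xyxy, conf, cls in det:
                    plot_one_box(xyxy, im0, label=f'{int(cls)} {conf:.2f}')
        results.append((path, im0))
    return results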