Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import glob
- import random
- import re
- import platform
- import subprocess
- import time
- from pathlib import Path
- import time
- import logging
- import requests
- import torch
- import torch.nn as nn
- import numpy as np
- import torchvision
- import math
- import cv2
- from threading import Thread
# File suffixes (lower-case, without the leading dot) treated as images.
img_formats = [
    'bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo',
]
# File suffixes (lower-case, without the leading dot) treated as videos.
vid_formats = [
    'mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv',
]
def time_synchronized():
    """Return a timestamp taken after pending CUDA work has completed.

    When CUDA is available, ``torch.cuda.synchronize()`` is called first so the
    timestamp is not taken while GPU kernels are still queued.

    Returns:
        float: the value of ``time.time()`` (seconds).
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()
    now = time.time()
    return now
# Configure the root logger at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)
# Logger named "Timer"; the timeIt decorator reports durations through it.
LOGGER = logging.getLogger("Timer")
# Module-level logger for the remaining helpers in this file.
logger = logging.getLogger(__name__)
def xywh2xyxy(x):
    """Convert boxes from ``[x, y, w, h]`` to ``[x1, y1, x2, y2]``.

    ``(x, y)`` is the box centre; ``(x1, y1)`` is the top-left corner and
    ``(x2, y2)`` the bottom-right corner.

    Args:
        x: ``torch.Tensor`` or array-like whose last axis holds at least the
            four box values. Any number of leading axes is accepted (a single
            box of shape ``(4,)``, ``(n, 4)``, ``(batch, n, 4)``, ...).

    Returns:
        A copy of ``x`` (same type family, the input is not modified) with
        the first four entries of the last axis replaced by corner coordinates.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing addresses the last axis regardless of leading axes.
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y
def xyxy2xywh(x):
    """Convert boxes from ``[x1, y1, x2, y2]`` to ``[x, y, w, h]``.

    ``(x1, y1)`` is the top-left corner, ``(x2, y2)`` the bottom-right corner,
    and ``(x, y)`` the resulting box centre.

    Args:
        x: ``torch.Tensor`` or array-like whose last axis holds at least the
            four box values. Any number of leading axes is accepted.

    Returns:
        A copy of ``x`` (the input is not modified) with the first four
        entries of the last axis replaced by centre/size values.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    # Ellipsis indexing addresses the last axis regardless of leading axes.
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale xyxy boxes from ``img1_shape`` back to ``img0_shape``.

    Args:
        img1_shape: (height, width) of the image the boxes currently refer to.
        coords: box array indexed as ``coords[:, 0:4] = x1, y1, x2, y2``;
            modified in place.
        img0_shape: shape of the target image; index 0 is height, index 1 width.
        ratio_pad: optional ``((gain, ...), (pad_x, pad_y))``. When omitted the
            gain and padding are derived from the two shapes.

    Returns:
        The same ``coords`` object, shifted, scaled and clipped to ``img0_shape``.
    """
    if ratio_pad is not None:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    else:
        # gain = old / new, using the tighter of the two axes
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
        pad = pad_x, pad_y

    coords[:, [0, 2]] -= pad[0]  # remove horizontal padding
    coords[:, [1, 3]] -= pad[1]  # remove vertical padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
def clip_coords(boxes, img_shape):
    """Clip xyxy bounding boxes in place to the image bounds.

    Args:
        boxes: ``torch.Tensor`` or ``numpy.ndarray`` indexed as
            ``boxes[:, 0:4] = x1, y1, x2, y2``. Modified in place.
        img_shape: image shape; index 0 is the height, index 1 the width.

    Returns:
        None. ``boxes`` is updated in place.
    """
    height, width = img_shape[0], img_shape[1]
    if isinstance(boxes, torch.Tensor):
        boxes[:, 0].clamp_(0, width)  # x1
        boxes[:, 1].clamp_(0, height)  # y1
        boxes[:, 2].clamp_(0, width)  # x2
        boxes[:, 3].clamp_(0, height)  # y2
    else:
        # numpy arrays have no clamp_(); clip column groups and write them back
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)  # y1, y2
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    """Draw one bounding box, and optionally a text label, onto ``img``.

    Args:
        x: sequence whose first four items are x1, y1, x2, y2 (cast to int).
        img: image array passed to the OpenCV drawing calls.
        color: colour passed to OpenCV; a random 3-value colour when falsy.
        label: optional text drawn on a filled rectangle above the top-left corner.
        line_thickness: line thickness; when falsy it is derived from the image size.

    Returns:
        None.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)

    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_size = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # Opposite corner of the filled label background, sitting above the box.
    label_corner = top_left[0] + text_size[0], top_left[1] - text_size[1] - 3
    cv2.rectangle(img, top_left, label_corner, color, -1, cv2.LINE_AA)  # filled
    text_origin = (top_left[0], top_left[1] - 2)
    cv2.putText(img, label, text_origin, 0, font_scale, [225, 255, 255], thickness=font_thickness,
                lineType=cv2.LINE_AA)
def strip_optimizer(f='best.pt', s=''):
    """Strip training-only state from checkpoint ``f`` and save it.

    The checkpoint's ``'ema'`` entry (when truthy) replaces ``'model'``; the
    keys ``optimizer``, ``training_results``, ``wandb_id``, ``ema`` and
    ``updates`` are set to ``None``; ``epoch`` is set to -1; the model is
    converted with ``.half()`` and its parameters stop requiring gradients.

    Args:
        f: path of the checkpoint to load.
        s: optional output path; when empty, ``f`` is overwritten.
    """
    ckpt = torch.load(f, map_location=torch.device('cpu'))
    if ckpt.get('ema'):
        ckpt['model'] = ckpt['ema']  # replace model with ema
    for key in ('optimizer', 'training_results', 'wandb_id', 'ema', 'updates'):
        ckpt[key] = None
    ckpt['epoch'] = -1

    model = ckpt['model']
    model.half()  # to FP16
    for param in model.parameters():
        param.requires_grad = False

    target = s or f
    torch.save(ckpt, target)
    mb = os.path.getsize(target) / 1E6  # file size in MB
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def apply_classifier(x, model, img, im0):
    """Filter detections with a second-stage classifier.

    For every image, each detection box is squared, padded, rescaled to the
    original image, cut out, resized to 224x224 and classified by ``model``.
    Only detections whose class (column 5) equals the classifier's argmax
    are kept.

    Args:
        x: list of per-image detection tensors (entries may be ``None``);
            updated in place.
        model: classifier called on the stacked cut-outs.
        img: network input; ``img.shape[2:]`` is used as the source shape
            for rescaling.
        im0: original image, or a list of original images.

    Returns:
        ``x`` with non-matching detections removed.
    """
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, det in enumerate(x):  # per image
        if det is None or not len(det):
            continue
        det = det.clone()

        # Reshape and pad cutouts
        boxes = xyxy2xywh(det[:, :4])
        boxes[:, 2:] = boxes[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
        boxes[:, 2:] = boxes[:, 2:] * 1.3 + 30  # pad
        det[:, :4] = xywh2xyxy(boxes).long()

        # Rescale boxes from img_size to im0 size
        scale_coords(img.shape[2:], det[:, :4], im0[i].shape)

        detector_cls = det[:, 5].long()
        crops = []
        for row in det:  # per detection
            cutout = im0[i][int(row[1]):int(row[3]), int(row[0]):int(row[2])]
            crop = cv2.resize(cutout, (224, 224))  # BGR
            crop = crop[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to 3x224x224
            crop = np.ascontiguousarray(crop, dtype=np.float32)  # uint8 to float32
            crop /= 255.0  # 0 - 255 to 0.0 - 1.0
            crops.append(crop)

        classifier_cls = model(torch.Tensor(crops).to(det.device)).argmax(1)
        x[i] = x[i][detector_cls == classifier_cls]  # retain matching class detections
    return x
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resize ``img`` to fit ``new_shape`` and pad the remainder with ``color``.

    Args:
        img: image array; ``img.shape[:2]`` is read as (height, width).
        new_shape: target (height, width), or a single int for a square.
        color: border value passed to ``cv2.copyMakeBorder``.
        auto: pad only up to the next multiple of ``stride`` (minimum rectangle).
        scaleFill: when ``auto`` is false, stretch to ``new_shape`` without padding.
        scaleup: allow enlarging; when false the scale ratio is capped at 1.0.
        stride: modulus used for the padding when ``auto`` is true.

    Returns:
        tuple: ``(image, (ratio_w, ratio_h), (pad_w, pad_h))`` where the padding
        values are per side (half of the total).
    """
    height, width = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    scale = min(new_shape[0] / height, new_shape[1] / width)
    if not scaleup:  # only scale down, never up
        scale = min(scale, 1.0)

    ratio = scale, scale  # width, height ratios
    new_unpad = int(round(width * scale)), int(round(height * scale))
    dw = new_shape[1] - new_unpad[0]
    dh = new_shape[0] - new_unpad[1]
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / width, new_shape[0] / height

    # Split the padding evenly between the two sides
    dw /= 2
    dh /= 2

    if (width, height) != new_unpad:
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    # The -0.1 / +0.1 offsets push an odd pixel of padding to the bottom/right side.
    top = int(round(dh - 0.1))
    bottom = int(round(dh + 0.1))
    left = int(round(dw - 0.1))
    right = int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (dw, dh)
def check_imshow():
    """Report whether ``cv2.imshow`` works in the current environment.

    A 1x1 test window is opened and closed again; any exception raised while
    doing so is printed as a warning.

    Returns:
        bool: ``True`` when the window calls succeed, ``False`` otherwise.
    """
    try:
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
    else:
        return True
def load_classifier(name='resnet101', n=2):
    """Load a pretrained torchvision model and reshape its ``fc`` layer to ``n`` outputs.

    The replacement ``fc`` weight and bias are zero-initialised trainable
    parameters.

    Args:
        name: key into ``torchvision.models.__dict__``.
        n: number of output classes.

    Returns:
        The model with ``fc.weight`` of shape ``(n, in_features)``.
    """
    constructor = torchvision.models.__dict__[name]
    model = constructor(pretrained=True)

    # ResNet model properties
    # input_size = [3, 224, 224]
    # input_space = 'RGB'
    # input_range = [0, 1]
    # mean = [0.485, 0.456, 0.406]
    # std = [0.229, 0.224, 0.225]

    # Reshape output to n classes
    in_features = model.fc.weight.shape[1]
    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    model.fc.weight = nn.Parameter(torch.zeros(n, in_features), requires_grad=True)
    model.fc.out_features = n
    return model
def make_divisible(x, divisor):
    """Round ``x`` up to the nearest multiple of ``divisor``.

    Returns:
        ``ceil(x / divisor) * divisor``.
    """
    multiples = math.ceil(x / divisor)
    return multiples * divisor
def check_img_size(img_size, s=32):
    """Return ``img_size`` rounded up to a multiple of stride ``s``.

    A warning is printed when the value had to be changed.

    Args:
        img_size: requested image size.
        s: stride; converted with ``int()`` before use.

    Returns:
        The (possibly enlarged) size.
    """
    adjusted = make_divisible(img_size, int(s))  # ceil to a stride multiple
    was_changed = adjusted != img_size
    if was_changed:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, adjusted))
    return adjusted
def select_device(device='', batch_size=None):
    """Select the torch device and log a one-line-per-device summary.

    Args:
        device: ``'cpu'``, a CUDA index such as ``'0'``, a comma-separated
            list such as ``'0,1,2,3'``, or ``''`` to use CUDA when available.
        batch_size: optional batch size; with more than one GPU it must be a
            multiple of the GPU count.

    Returns:
        torch.device: ``cuda:0`` when CUDA is selected, otherwise ``cpu``.

    Raises:
        AssertionError: if a CUDA device is requested but CUDA is unavailable,
            or if ``batch_size`` is not a multiple of the GPU count.

    Side effects:
        Sets ``os.environ['CUDA_VISIBLE_DEVICES']`` when ``device`` is non-empty.
    """
    # Banner prefix. The summary string was previously appended to without
    # ever being initialised, which raised NameError on every call.
    s = f'torch {torch.__version__} '

    cpu = device.lower() == 'cpu'
    if cpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        n = torch.cuda.device_count()
        if n > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * len(s)  # indent continuation lines under the banner
        for i, d in enumerate(device.split(',') if device else range(n)):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
    else:
        s += 'CPU\n'

    # On Windows, drop non-ASCII characters before logging.
    logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
    return torch.device('cuda:0' if cuda else 'cpu')
def autopad(k, p=None):  # kernel, padding
    """Return padding for a kernel: ``p`` if given, else half the kernel size.

    Args:
        k: kernel size, an int or an iterable of ints.
        p: explicit padding; returned unchanged when not ``None``.

    Returns:
        ``p``, or ``k // 2`` for an int kernel, or a list of ``x // 2`` per
        kernel dimension.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [dim // 2 for dim in k]
def increment_path(path, exist_ok=True, sep=''):
    """Return ``path``, or a numbered variant when it exists and must not be reused.

    ``runs/exp`` becomes ``runs/exp{sep}2``, ``runs/exp{sep}3`` and so on: the
    result is one more than the highest existing index, or 2 when no indexed
    sibling exists yet.

    Args:
        path: desired path (str or ``Path``).
        exist_ok: when true, an existing ``path`` is returned unchanged.
        sep: separator placed between the path and the index.

    Returns:
        str: the path to use.
    """
    path = Path(path)  # os-agnostic
    if exist_ok or not path.exists():
        return str(path)

    # Escape glob/regex metacharacters so names such as 'run+x' or 'exp[1]'
    # are matched literally.
    candidates = glob.glob(f"{glob.escape(str(path))}{glob.escape(sep)}*")
    pattern = re.compile(rf"{re.escape(path.stem)}{re.escape(sep)}(\d+)")
    # Search the final path component only, so digits that happen to appear in
    # a parent directory name are never mistaken for an index.
    matches = (pattern.search(os.path.basename(c)) for c in candidates)
    indices = [int(m.group(1)) for m in matches if m]
    n = max(indices) + 1 if indices else 2  # increment number
    return f"{path}{sep}{n}"
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two box sets.
    Boxes are given as (x1, y1, x2, y2).
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): entry (i, j) is the IoU of box1[i] and box2[j]
    """

    def _area(boxes):
        # boxes = nx4 -> n areas
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        return widths * heights

    area1 = _area(box1)
    area2 = _area(box2)

    # Broadcast box1 over box2 to obtain the (N, M, 2) overlap corners.
    overlap_top_left = torch.max(box1[:, None, :2], box2[:, :2])
    overlap_bottom_right = torch.min(box1[:, None, 2:], box2[:, 2:])
    inter = (overlap_bottom_right - overlap_top_left).clamp(0).prod(2)

    union = area1[:, None] + area2 - inter
    return inter / union
class BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """Batch-norm layer that skips the input-dimension check.

    BatchNorm1d/2d/3d differ from one another only in ``_check_input_dim``.
    SyncBatchNorm does not record which of them it replaced, so when
    converting back this class is used instead: it performs no dimensionality
    check and relies on the caller to supply correctly shaped input.
    """

    def _check_input_dim(self, input):
        # Intentionally a no-op: every input dimensionality is accepted.
        return None
def revert_sync_batchnorm(module):
    """Recursively replace every ``SyncBatchNorm`` in ``module`` with ``BatchNormXd``.

    Mirrors ``torch.nn.SyncBatchNorm.convert_sync_batchnorm`` in reverse:
    https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679

    Args:
        module: root ``nn.Module`` to convert.

    Returns:
        The converted module. A non-SyncBatchNorm root is returned as the same
        object, with its children re-registered after conversion.
    """
    converted = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        converted = BatchNormXd(module.num_features,
                                module.eps, module.momentum,
                                module.affine,
                                module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                converted.weight = module.weight
                converted.bias = module.bias
        # Carry the running statistics over unchanged.
        for attr in ("running_mean", "running_var", "num_batches_tracked"):
            setattr(converted, attr, getattr(module, attr))
        if hasattr(module, "qconfig"):
            converted.qconfig = module.qconfig

    for child_name, child in module.named_children():
        converted.add_module(child_name, revert_sync_batchnorm(child))
    del module
    return converted
def clean_str(s):
    """Return ``s`` with each special character replaced by an underscore.

    The replaced characters are those listed in the character class below.
    """
    special_chars = "[|@#!¡·$€%&()=?¿^*;:,¨´><+]"
    return re.sub(pattern=special_chars, repl="_", string=s)
class LoadImages:  # for inference
    """Iterate over image and video files, yielding network-ready frames.

    ``path`` may be a glob pattern (contains ``*``), a directory or a single
    file. Files are classified by suffix using ``img_formats`` and
    ``vid_formats``; images are served first, then videos frame by frame.

    Each iteration yields ``(path, img, img0, cap)``: ``img0`` is the frame as
    read by OpenCV, ``img`` is its letterboxed copy with the channel axis
    reversed and moved to the front, and ``cap`` is the current
    ``cv2.VideoCapture`` (``None`` when there are no videos).
    """

    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # single file
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images, videos = [], []
        for candidate in files:
            suffix = candidate.split('.')[-1].lower()
            if suffix in img_formats:
                images.append(candidate)
            if suffix in vid_formats:
                videos.append(candidate)
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # open the first video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. '

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if not self.video_flag[self.count]:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
        else:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                # Current video is exhausted: move on to the next file.
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # that was the last video
                    raise StopIteration
                path = self.files[self.count]
                self.new_video(path)
                ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert: BGR to RGB, HWC to CHW
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        # Open ``path`` and reset the per-video frame counter.
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
class LoadStreams:  # multiple IP or RTSP cameras
    """Read frames from one or more live streams in background threads.

    ``sources`` is either the path of a text file with one stream per line or
    a single stream string. Purely numeric sources are converted to ``int``
    before being passed to ``cv2.VideoCapture``.

    Each iteration yields ``(sources, img, img0, None)``: ``img0`` is a list
    with the latest frame of every stream and ``img`` the letterboxed batch
    with the channel axis reversed and moved in front of the spatial axes.
    Iteration stops when ``q`` is pressed in an OpenCV window.
    """

    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            # int() instead of eval(): the source string comes from a file or
            # the command line and must never be executed as code.
            url = int(s) if s.isnumeric() else s
            cap = cv2.VideoCapture(url)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # A reported FPS of 0 (or NaN) would make update() divide by zero
            # in its sleep; fall back to 30 FPS in that case.
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            self.fps = fps if fps > 0 else 30.0

            ok, self.imgs[i] = cap.read()  # guarantee first frame
            assert ok and self.imgs[i] is not None, f'Failed to read a frame from {s}'
            thread = Thread(target=self.update, args=(i, cap), daemon=True)
            print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            cap.grab()
            if n == 4:  # decode only every 4th grabbed frame
                success, im = cap.retrieve()
                # On failure, keep the shape but blank the previous frame.
                self.imgs[index] = im if success else self.imgs[index] * 0
                n = 0
            time.sleep(1 / self.fps)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert: BGR to RGB, BHWC to BCHW
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
class TracedModel(nn.Module):
    """Wrap a detection model as a TorchScript trace plus its final detect layer.

    On construction the model is moved to CPU, put in eval mode, has its
    SyncBatchNorm layers reverted, is traced with a random example input and
    the trace is saved to ``traced_model.pt`` in the working directory. The
    traced module and the original last layer (``model.model[-1]``) are then
    moved to ``device``.

    Args:
        model: model exposing ``stride``, ``names`` and an indexable ``model``
            attribute.
        device: target device for the traced module and the detect layer.
        img_size: example input size, either a single int (square) or a
            ``(height, width)`` pair.
    """

    def __init__(self, model=None, device=None, img_size=(640,640)):
        super(TracedModel, self).__init__()

        print(" Convert model to Traced-model... ")
        self.stride = model.stride
        self.names = model.names
        self.model = model

        self.model = revert_sync_batchnorm(self.model)
        self.model.to('cpu')
        self.model.eval()

        self.detect_layer = self.model.model[-1]
        self.model.traced = True

        # torch.rand needs integer sizes. The default is a tuple, which used to
        # be passed twice as-is and raised TypeError; accept both forms.
        if isinstance(img_size, int):
            height = width = img_size
        else:
            height, width = img_size
        rand_example = torch.rand(1, 3, height, width)

        traced_script_module = torch.jit.trace(self.model, rand_example, strict=False)
        traced_script_module.save("traced_model.pt")
        print(" traced_script_module saved! ")
        self.model = traced_script_module
        self.model.to(device)
        self.detect_layer.to(device)
        print(" model is traced! \n")

    def forward(self, x, augment=False, profile=False):
        # Run the traced network, then the detect layer on its output.
        # ``augment`` and ``profile`` are accepted but not used here.
        out = self.model(x)
        out = self.detect_layer(out)
        return out
class Ensemble(nn.ModuleList):
    """Container that runs several models and concatenates their outputs."""

    def __init__(self):
        super(Ensemble, self).__init__()

    def forward(self, x, augment=False):
        # Take the first output of every member model ...
        member_outputs = [member(x, augment)[0] for member in self]
        # y = torch.stack(y).max(0)[0]  # max ensemble
        # y = torch.stack(y).mean(0)  # mean ensemble
        # ... and join them along dim 1 (nms ensemble).
        combined = torch.cat(member_outputs, 1)
        return combined, None  # inference, train output
class Conv(nn.Module):
    """Standard convolution block: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; derived from ``k`` by ``autopad`` when ``None``.
        g: convolution groups.
        act: ``True`` for SiLU, an ``nn.Module`` to use that module, anything
            else for identity.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = nn.SiLU()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        # conv -> batch norm -> activation
        normalized = self.bn(self.conv(x))
        return self.act(normalized)

    def fuseforward(self, x):
        # Forward pass that skips the batch-norm layer: conv -> activation.
        return self.act(self.conv(x))
def timeIt(func):
    """Decorator that logs how long each call to ``func`` takes.

    The duration is logged at WARNING level through the ``LOGGER`` logger as
    ``"<name> took <seconds>s"``. The wrapped function's return value is
    passed through unchanged, and its name, docstring and other metadata are
    preserved on the wrapper.

    Args:
        func: callable to time.

    Returns:
        The wrapping callable.
    """
    import functools  # local import: not among this file's top-level imports

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement cannot go negative
        # when the system clock is adjusted mid-call.
        start = time.perf_counter()
        data = func(*args, **kwargs)
        elapsed_seconds = time.perf_counter() - start
        message = func.__name__ + " took " + str(elapsed_seconds) + "s"
        LOGGER.warning(message)
        return data

    return wrapper
def attempt_download(file, repo='WongKinYiu/yolov7'):
    """Download ``file`` from the latest GitHub release of ``repo`` if it is missing.

    The file name is normalised first (whitespace stripped, single quotes
    removed, lower-cased). Nothing happens when the file already exists or
    when its name is not one of the release assets. Download failures are
    reported on stdout and any partial file is removed; they do not raise.

    Args:
        file: path of the weights file.
        repo: GitHub ``owner/name`` to download from.

    Returns:
        None.
    """
    file = Path(str(file).strip().replace("'", '').lower())
    if file.exists():
        return

    try:
        # A timeout keeps an unreachable API from blocking forever.
        response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest', timeout=10).json()
        assets = [x['name'] for x in response['assets']]  # release assets
        tag = response['tag_name']  # i.e. 'v1.0'
    except Exception:  # fallback plan (no longer a bare except)
        assets = ['yolov7.pt', 'yolov7-tiny.pt', 'yolov7x.pt', 'yolov7-d6.pt', 'yolov7-e6.pt',
                  'yolov7-e6e.pt', 'yolov7-w6.pt']
        tag = subprocess.check_output('git tag', shell=True).decode().split()[-1]

    name = file.name
    if name not in assets:
        return

    msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
    redundant = False  # second download option
    try:  # GitHub
        url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        if not (file.exists() and file.stat().st_size > 1E6):  # check
            raise RuntimeError('downloaded file is missing or smaller than 1 MB')
    except Exception as e:  # GCP
        print(f'Download error: {e}')
        if redundant:
            url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
            print(f'Downloading {url} to {file}...')
            # Argument list (no shell), so characters in the URL or path are
            # never interpreted by a shell.
            subprocess.run(['curl', '-L', url, '-o', str(file)], check=False)
    finally:
        # Runs on every exit path; unlike a ``return`` placed here, it lets
        # KeyboardInterrupt/SystemExit propagate after the cleanup.
        if not file.exists() or file.stat().st_size < 1E6:  # check
            file.unlink(missing_ok=True)  # remove partial downloads
            print(f'ERROR: Download failure: {msg}')
        print('')
def attempt_load(weights, map_location=None):
    """Load one model, or an ensemble when several weight files are given.

    Each checkpoint is downloaded if missing, loaded with ``torch.load`` and
    its ``'ema'`` entry (when truthy) or ``'model'`` entry is converted with
    ``.float().fuse().eval()``.

    Args:
        weights: a weights path or a list of paths.
        map_location: forwarded to ``torch.load``.

    Returns:
        The single model when exactly one was loaded, otherwise the
        ``Ensemble`` with ``names`` and ``stride`` copied from its last member.
    """
    model = Ensemble()
    weight_files = weights if isinstance(weights, list) else [weights]
    for w in weight_files:
        attempt_download(w)
        ckpt = torch.load(w, map_location=map_location)  # load
        entry = 'ema' if ckpt.get('ema') else 'model'
        model.append(ckpt[entry].float().fuse().eval())  # FP32 model

    # Compatibility updates
    for m in model.modules():
        kind = type(m)
        if kind in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True  # pytorch 1.7.0 compatibility
        elif kind is nn.Upsample:
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
        elif kind is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    if len(model) == 1:
        return model[-1]  # return model

    print('Ensemble created with %s\n' % weights)
    for attr in ['names', 'stride']:
        setattr(model, attr, getattr(model[-1], attr))
    return model  # return ensemble
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Run Non-Maximum Suppression (NMS) on raw inference results.

    Args:
        prediction: tensor indexed as ``prediction[image, box, :]`` whose last
            axis is ``[x, y, w, h, obj_conf, cls_0, cls_1, ...]``.
        conf_thres: confidence threshold applied to the objectness score and
            again to the final class confidence.
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        classes: optional class indices to keep.
        agnostic: when true, boxes of different classes suppress each other.
        multi_label: allow several labels per box (only if more than one class).
        labels: optional per-image a-priori labels appended before NMS
            (autolabelling); rows are read as ``[cls, x, y, w, h]``.

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    candidates = prediction[..., 4] > conf_thres

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    started = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[candidates[xi]]  # keep rows above the objectness threshold

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            extra = torch.zeros((len(lb), nc + 5), device=x.device)
            extra[:, :4] = lb[:, 1:5]  # box
            extra[:, 4] = 1.0  # conf
            extra[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, extra), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        if nc == 1:
            # Single-class models: cls_conf carries no information, so the
            # objectness score is used directly instead of being multiplied.
            x[:, 5:] = x[:, 4:5]
        else:
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            rows, cols = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[rows], x[rows, cols + 5, None], cols[:, None].float()), 1)
        else:  # best class only
            conf, best_cls = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, best_cls.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: offsetting boxes by class keeps classes from overlapping
        class_offset = x[:, 5:6] * (0 if agnostic else max_wh)
        boxes, scores = x[:, :4] + class_offset, x[:, 4]
        keep = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if keep.shape[0] > max_det:  # limit detections
            keep = keep[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[keep], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[keep, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                keep = keep[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[keep]
        if (time.time() - started) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement