dan-masek

track small -- v4

Oct 9th, 2019
import cv2
import numpy as np

import logging
import logging.handlers
import os
import time

# ============================================================================

def init_logging(log_to_console = True):
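    """Configure root logging with a rotating file handler ('debug.log',
    rolled over on every start) and, optionally, a console handler when not
    running under pythonw.exe."""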
    import sys
    main_logger = logging.getLogger()

    LOG_FILENAME = 'debug.log'

    # Check if log exists and should therefore be rolled
    needRoll = os.path.isfile(LOG_FILENAME)

    formatter = logging.Formatter(
        fmt='%(asctime)s.%(msecs)03d %(levelname)-8s <%(threadName)s> [%(name)s] %(message)s'
        , datefmt='%Y-%m-%d %H:%M:%S')

    handler_file = logging.handlers.RotatingFileHandler(LOG_FILENAME
        , maxBytes = 2**24
        , backupCount = 10)
    handler_file.setFormatter(formatter)
    main_logger.addHandler(handler_file)

    if not sys.executable.endswith("pythonw.exe") and log_to_console:
        handler_stream = logging.StreamHandler(sys.stdout)
        handler_stream.setFormatter(formatter)
        main_logger.addHandler(handler_stream)

    main_logger.setLevel(logging.DEBUG)

    if needRoll:
        # Roll over on application start
        handler_file.doRollover()

# ============================================================================

FILE_NAME = 'testGood.mp4'

MIN_CONTOUR_AREA = 1350
MAX_CONTOUR_AREA = 3500

# Region of interest for processing
ROI_POINTS = [
    (400, 0) # Top Left
    , (400, 800) # Bottom Left
    , (1480, 800) # Bottom Right
    , (1480, 0) # Top Right
    ]

TRACKING_BOX_SCALE = (1.10, 1.10)

# Colors for visualization
LABEL_COLOR = (0, 200, 255)
STALE_COLOR = (127, 255, 255)
TRACKED_COLOR = (127, 255, 127)
UNTRACKED_COLOR = (127, 127, 255)

# ============================================================================

class IDManager(object):
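    """Hands out small integer object IDs, recycling IDs released by retired
    tracks."""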
    def __init__(self):
        self.available_ids = set()
        self.next_id = 0

    def acquire_id(self):
        if self.available_ids:
            return self.available_ids.pop()
        new_id = self.next_id
        self.next_id += 1
        return new_id

    def release_id(self, id):
        self.available_ids.add(id)


# ============================================================================

class TrackingEntry(object):
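    """A single tracked object: a CSRT tracker plus the history of bounding
    boxes (a None entry marks a frame where the tracker lost the object)."""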
    def __init__(self, id, frame_number, frame, bounding_box):
        self.id = id
        self.first_frame_number = frame_number
        self.tracker = cv2.TrackerCSRT_create()
        self.bounding_boxes = [bounding_box]

        # Grow the initial box slightly; use integer dimensions, since newer
        # OpenCV tracker bindings reject floating-point boxes.
        extended_bounding_box = (bounding_box[0]
            , bounding_box[1]
            , int(TRACKING_BOX_SCALE[0] * bounding_box[2])
            , int(TRACKING_BOX_SCALE[1] * bounding_box[3]))
        self.tracker.init(frame, extended_bounding_box)

    def __repr__(self):
        return "TrackingEntry(id=%d,first_frame=%d,bounds=%s)" % (self.id, self.first_frame_number, self.bounding_boxes)

    @property
    def last_bounding_box(self):
        return self.bounding_boxes[-1]

    def update(self, frame):
        return self.tracker.update(frame)

    def add_bounding_box(self, bounding_box):
        self.bounding_boxes.append(bounding_box)

    def add_missed_detection(self):
        self.bounding_boxes.append(None)

    @property
    def detection_count(self):
        return len(self.bounding_boxes)

# ============================================================================

class FrameExtractor(object):
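    """Crops each frame to the ROI bounding box, halves its resolution with
    pyrDown, masks out pixels outside the ROI polygon and applies background
    subtraction."""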
    def __init__(self, frame_shape, roi_points, bg_subtractor):
        self.frame_shape = frame_shape
        self.roi_bounding_box = cv2.boundingRect(roi_points)
        self.bg_subtractor = bg_subtractor

        self._create_roi_mask(roi_points)

    def _create_roi_mask(self, vertices):
        mask = np.zeros(self.frame_shape, np.uint8)
        color = (255, ) * self.frame_shape[2]
        cv2.fillPoly(mask, vertices, color)

        x,y,w,h = self.roi_bounding_box
        mask = cv2.pyrDown(mask[y:y+h,x:x+w])
        if (mask.size - cv2.countNonZero(mask.flatten())) == 0:
            mask = None

        self.roi_mask = mask

    def extract_frame(self, source_image):
        x,y,w,h = self.roi_bounding_box
        source_roi = source_image[y:y+h,x:x+w]
        # drop resolution of working frames
        source_roi = cv2.pyrDown(source_roi)
        # isolate region of interest
        roi = cv2.bitwise_and(source_roi, self.roi_mask) if self.roi_mask is not None else source_roi
        # apply background subtraction
        fgmask = self.bg_subtractor.apply(roi)
        # remove shadow pixels and replace them with black pixels
        _,thresh = cv2.threshold(fgmask, 127, 255, cv2.THRESH_BINARY)

        return roi, thresh

# ============================================================================

def update_tracked_objects(tracked_objects, frame):
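    """Advance every active tracker by one frame, recording either the new
    bounding box or a missed detection."""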
    # update existing trackers
    for tracked in tracked_objects.values():
        success, box = tracked.update(frame)
        if success:
            x,y,w,h = [int(v) for v in box]
            tracked.add_bounding_box([x,y,w,h])
        else:
            tracked.add_missed_detection()

# ----------------------------------------------------------------------------

def visualize_tracked_objects(tracked_objects, frame, color):
    for tracked in tracked_objects.values():
        if tracked.last_bounding_box:
            x,y,w,h = tracked.last_bounding_box
            cv2.rectangle(frame, (x,y), (x+w, y+h), color, 4)

# ============================================================================

def detect_stale_objects(tracked_objects):
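    """Return (id, to_count, label) tuples for tracks to retire: trackers that
    stopped moving after having moved (counted when they ended up lower in the
    frame), trackers that never moved, and trackers that lost their object."""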
    # check for tracking which has stopped or tracking which hasn't moved
    del_list = []

    for tracked in tracked_objects.values():
        n = tracked.detection_count - 1
        if n <= 0:
            continue
        bounds = tracked.bounding_boxes
        if (bounds[n] == bounds[n-1]) and (bounds[0] != bounds[n]):
            if (bounds[n][1] > bounds[0][1]):
                del_list.append((tracked.id, True, 'Counted(1)'))
            else:
                del_list.append((tracked.id, False, 'Discarded(1)'))
        elif (n > 5) and (bounds[n] == bounds[n-1]) and (bounds[0] == bounds[n]):
            del_list.append((tracked.id, False, 'Discarded(2)'))
        elif bounds[-1] is None:
            del_list.append((tracked.id, True, 'Counted(2)'))

    return del_list

# ----------------------------------------------------------------------------

def visualize_stale_objects(tracked_objects, stale_list, frame, color, text_color):
    for id, to_count, label in stale_list:
        tracked = tracked_objects[id]
        pos = (tracked.bounding_boxes[-2][0], tracked.bounding_boxes[-2][1] - 10)
        cv2.putText(frame, label, pos, cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_color, 2)
        if tracked.last_bounding_box:
            x,y,w,h = tracked.last_bounding_box
            cv2.rectangle(frame, (x,y), (x+w, y+h), color, 2)

# ----------------------------------------------------------------------------

def visualize_total_count(count, frame, color):
    msg = ('Count = %d' % count)
    cv2.putText(frame, msg, (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

# ----------------------------------------------------------------------------

def update_count(count, stale_list):
    for _, to_count, _ in stale_list:
        if to_count:
            count += 1

    return count

# ----------------------------------------------------------------------------

def retire_stale_objects(tracked_objects, id_manager, stale_list):
    logger = logging.getLogger('retire_stale_objects')

    for id, to_count, label in stale_list:
        tracked = tracked_objects[id]

        del tracked_objects[id]
        id_manager.release_id(id)

        logger.debug("Retired: '%s' -- %s", label, tracked)

        del tracked

    logger.debug("There are %d tracked objects remaining.", len(tracked_objects))

# ============================================================================

def is_valid_contour(contour):
    contour_area = cv2.contourArea(contour)
    return (MIN_CONTOUR_AREA <= contour_area <= MAX_CONTOUR_AREA)

# ----------------------------------------------------------------------------

def detect_objects(frame):
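    """Find contours in a binary image and return the bounding boxes of those
    whose area lies within [MIN_CONTOUR_AREA, MAX_CONTOUR_AREA]."""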
    contours, _ = cv2.findContours(frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    object_list = []
    for contour in contours:
        if not is_valid_contour(contour):
            continue
        bounding_box = list(cv2.boundingRect(contour))
        object_list.append(bounding_box)

    return object_list

# ----------------------------------------------------------------------------

def find_movement(thresh1, thresh2):
    return detect_objects(cv2.absdiff(thresh1, thresh2))

# ----------------------------------------------------------------------------

def is_already_tracked(tracked_objects, bounding_box):
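    """True if the centre of bounding_box lies inside the last known box of
    any currently tracked object."""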
    (x1,y1,w1,h1) = bounding_box

    for entry in tracked_objects.values():
        (x2,y2,w2,h2) = entry.last_bounding_box
        if (x2 < (x1 + w1 / 2) < (x2 + w2)) and (y2 < (y1 + h1 / 2) < (y2 + h2)):
            return True

    return False

# ----------------------------------------------------------------------------

def detect_untracked(tracked_objects, thresh):
    # Check if movement was being tracked
    object_list = detect_objects(thresh)

    untracked_list = []
    for bounding_rect in object_list:
        if not is_already_tracked(tracked_objects, bounding_rect):
            untracked_list.append(bounding_rect)

    return untracked_list

# ----------------------------------------------------------------------------

def visualize_untracked(untracked_list, frame, color):
    for untracked in untracked_list:
        (x,y,w,h) = untracked
        cv2.rectangle(frame, (x, y), (x + w, y + h), color, 2)

# ----------------------------------------------------------------------------

def register_untracked(tracked_objects, id_manager, untracked_list, frame_number, frame):
    logger = logging.getLogger('process_video_stream')
    # Assign tracking
    for untracked in untracked_list:
        new_id = id_manager.acquire_id()
        new_entry = TrackingEntry(new_id, frame_number, frame, untracked)
        tracked_objects[new_id] = new_entry
        logger.debug('Registered new object: %s', new_entry)

# ============================================================================

def process_video_stream(cap):
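    """Main processing loop: background-subtract each frame, update the CSRT
    trackers, retire and count stale tracks, and start tracking any moving
    object that is not already covered by an existing tracker. Returns the
    total count."""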
    logger = logging.getLogger('process_video_stream')

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_shape = (frame_height, frame_width, 3)

    # Declare the region of interest to eliminate some background issues
    # and use it as the isolation mask
    roi_points_array = np.array([ROI_POINTS], np.int32)
    logger.debug("ROI points: %s", ROI_POINTS)

    # Background subtractor: either KNN or MOG2 works; KNN yielded the best results in testing
    bg_subtractor = cv2.createBackgroundSubtractorKNN()

    frame_extractor = FrameExtractor(frame_shape, roi_points_array, bg_subtractor)

    # Grab initial frame
    check, raw_frame = cap.read()
    if not check:
        raise RuntimeError("No frames in the video stream.")

    frame_number = 0

    # Preprocess the first frame
    curr_frame, curr_thresh = frame_extractor.extract_frame(raw_frame)

    tracked_objects = {}
    id_manager = IDManager()
    count = 0

    # Main loop
    while True:
        prev_frame, prev_thresh = curr_frame, curr_thresh
        # Read new frames until no more are left
        check, raw_frame = cap.read()
        if not check:
            logger.debug("Reached end of stream.")
            break

        frame_number += 1
        logger.debug("Frame #%04d started.", frame_number)

        curr_frame, curr_thresh = frame_extractor.extract_frame(raw_frame)

        if tracked_objects:
            logger.debug("* Updating %d tracked object%s..."
                , len(tracked_objects)
                , ('s' if len(tracked_objects) != 1 else ''))
            update_tracked_objects(tracked_objects, prev_frame)

            visualize_tracked_objects(tracked_objects, prev_frame, TRACKED_COLOR)

            stale_list = detect_stale_objects(tracked_objects)
            if stale_list:
                logger.debug("* Retiring %d stale tracked object%s..."
                    , len(stale_list)
                    , ('s' if len(stale_list) != 1 else ''))

                visualize_stale_objects(tracked_objects, stale_list, prev_frame, STALE_COLOR, LABEL_COLOR)
                count = update_count(count, stale_list)
                retire_stale_objects(tracked_objects, id_manager, stale_list)

        visualize_total_count(count, prev_frame, LABEL_COLOR)

        # Find movement
        movement_bounds = find_movement(prev_thresh, curr_thresh)
        if movement_bounds:
            untracked_list = detect_untracked(tracked_objects, prev_thresh)
            logger.debug("* Detected %d movement%s and %d untracked object%s."
                , len(movement_bounds), ('s' if len(movement_bounds) != 1 else '')
                , len(untracked_list), ('s' if len(untracked_list) != 1 else ''))

            visualize_untracked(untracked_list, prev_frame, UNTRACKED_COLOR)
            register_untracked(tracked_objects, id_manager, untracked_list, frame_number, prev_frame)

        logger.debug("Frame #%04d complete.", frame_number)

        # Visualization
        cv2.imshow("Frame", prev_frame)

        key = cv2.waitKey(1) & 0xFF
        if key == 27:
            logger.debug("Exit requested.")
            break

    # TODO Handle remaining tracked objects
    if tracked_objects:
        logger.debug("There %s %d tracked object%s remaining."
            , ('are' if len(tracked_objects) != 1 else 'is')
            , len(tracked_objects)
            , ('s' if len(tracked_objects) != 1 else ''))

    return count

# ============================================================================

def main():
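    """Open the input video, run the processing loop and report the count."""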
    logger = logging.getLogger('main')

    # Initialize video input
    cap = cv2.VideoCapture(FILE_NAME)
    if not cap.isOpened():
        raise RuntimeError("Unable to open file '%s'." % FILE_NAME)

    count = process_video_stream(cap)

    logger.debug("TOTAL COUNT = %d", count)

    cap.release()
    cv2.destroyAllWindows()

# ============================================================================

if __name__ == "__main__":
    init_logging(True)

    start_time = time.time()

    main()

    print("\nRuntime: %s seconds" % (time.time() - start_time))
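
# ----------------------------------------------------------------------------
# Sketch only: to run against a live camera instead of the FILE_NAME video,
# main() could open the capture like this (device index 0 is an assumption,
# not part of the original script):
#
#     cap = cv2.VideoCapture(0)
#     if not cap.isOpened():
#         raise RuntimeError("Unable to open camera 0.")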