Advertisement
miketwo

End Conditions in Multiprocessing Chain

Nov 16th, 2023
96
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.74 KB | Software | 0 0
  1. import multiprocessing
  2. import os
  3. import time
  4. import random
  5. import functools
  6.  
  7. # This program is a toy example of a fanout chain where the output of one
  8. # function is the input of the next.
  9. #
  10. # FuncA --> (multiple) FuncB's --> (multiple) FuncC's --> (multiple) FuncD's
  11. #
  12. # It's particularly interesting because the number of tasks is not known in
  13. # advance. Each function can generate a random number of results. Functions B,
  14. # C, and D are all parallelized and can run at the same time. A,B,and C create
  15. # results, and D is a terminal function that prints the results.
  16. #
  17. # The program should run until all tasks are complete, but I'm not sure how to
  18. # determine that efficiently.
  19.  
  20.  
# Global settings
WORKERS = 4                     # number of pool processes consuming the task queue
MIN_RESULTS_PER_FUNCTION = 2    # each of A/B/C fans out at least this many new tasks
MAX_RESULTS_PER_FUNCTION = 8    # ...and at most this many
MINIMUM_WORK_TIME = 0.05        # lower bound (seconds) for the simulated-work sleep
MAXIMUM_WORK_TIME = 1.00        # upper bound (seconds) for the simulated-work sleep
  27.  
  28.  
  29. class TaskCounter():
  30.     ''' Helper class to track the tasks running in parallel '''
  31.     MANAGER = multiprocessing.Manager()
  32.     TASKS = MANAGER.dict()
  33.     LOCK = MANAGER.Lock()
  34.  
  35.     @staticmethod
  36.     def track_task(func):
  37.         @functools.wraps(func)
  38.         def wrapped(*args, **kwargs):
  39.             with TaskCounter.LOCK:
  40.                 TaskCounter.TASKS[func.__name__] = TaskCounter.TASKS.get(func.__name__, 0) + 1
  41.             retval = func(*args, **kwargs)
  42.             with TaskCounter.LOCK:
  43.                 TaskCounter.TASKS[func.__name__] = TaskCounter.TASKS.get(func.__name__, 0) - 1
  44.             return retval
  45.         return wrapped
  46.  
  47.     @staticmethod
  48.     def tasks_running_string():
  49.         with TaskCounter.LOCK:
  50.             return [f"{key}: {value}" for key, value in TaskCounter.TASKS.items()]
  51.  
  52.     @staticmethod
  53.     def all_tasks_complete() -> bool:
  54.         with TaskCounter.LOCK:
  55.             return all([v == 0 for v in TaskCounter.TASKS.values()])
  56.  
  57.  
  58. def timer(func):
  59.     """Print the runtime of the decorated function"""
  60.     @functools.wraps(func)
  61.     def wrapper_timer(*args, **kwargs):
  62.         start_time = time.perf_counter()
  63.         value = func(*args, **kwargs)
  64.         end_time = time.perf_counter()
  65.         run_time = end_time - start_time
  66.         print(f"Finished {func.__name__!r} in {run_time:.4f} secs")
  67.         return value
  68.     return wrapper_timer
  69.  
  70.  
  71. def do_work():
  72.     # simulate a long operation and return the last digits of the process ID
  73.     time.sleep(random.uniform(MINIMUM_WORK_TIME, MAXIMUM_WORK_TIME))
  74.     return str(os.getpid())[-2:]
  75.  
  76.  
  77. @TaskCounter.track_task
  78. def funcA(queue):
  79.     print(f"funcA: {TaskCounter.tasks_running_string()}")
  80.     number_of_results_to_generate = random.randint(MIN_RESULTS_PER_FUNCTION, MAX_RESULTS_PER_FUNCTION)
  81.     for _ in range(number_of_results_to_generate):
  82.       result = do_work()
  83.       queue.put(("B", result))
  84.  
  85.  
  86. @TaskCounter.track_task
  87. def funcB(history, queue):
  88.     print(f"funcB: {TaskCounter.tasks_running_string()}")
  89.     number_of_results_to_generate = random.randint(MIN_RESULTS_PER_FUNCTION, MAX_RESULTS_PER_FUNCTION)
  90.     for _ in range(number_of_results_to_generate):
  91.       result = do_work()
  92.       queue.put(("C", f"{history} -> {result}"))
  93.  
  94.  
  95. @TaskCounter.track_task
  96. def funcC(history, queue):
  97.     print(f"funcC: {TaskCounter.tasks_running_string()}")
  98.     number_of_results_to_generate = random.randint(MIN_RESULTS_PER_FUNCTION, MAX_RESULTS_PER_FUNCTION)
  99.     for _ in range(number_of_results_to_generate):
  100.       result = do_work()
  101.       queue.put(("D", f"{history} -> {result}"))
  102.  
  103.  
  104. def funcD(history):
  105.     print(f"funcD: {TaskCounter.tasks_running_string()} -- Final path: {history}")
  106.  
  107.  
  108. # ---- This is the part where I'm stuck. Is there a better way? ----
  109. def end_condition(q):
  110.     ''' Try to determine if all tasks are complete. '''
  111.     if q.empty() and TaskCounter.all_tasks_complete():
  112.         print(f"Consumer #{os.getpid()}: End condition met once...")
  113.         # unecessary sleep? give the processes enough time to finish. There could be a race condition
  114.         # between the queue being populated and the task counter being updated.
  115.         time.sleep(MAXIMUM_WORK_TIME)
  116.         return q.empty() and TaskCounter.all_tasks_complete()
  117.     else:
  118.         return False
  119. # -----------------------------------------
  120.  
  121.  
  122. def shutdown(q):
  123.     ''' Send a poison pill to all workers '''
  124.     for _ in range(WORKERS):
  125.         q.put(None)
  126.  
  127.  
  128. def worker(q):
  129.   print(f"Consumer #{os.getpid()}: Alive")
  130.   while True:
  131.     item = q.get(block=True)
  132.     if item is None:
  133.         break
  134.     try:
  135.         if item[0] == "A":
  136.             funcA(q)
  137.         elif item[0] == "B":
  138.             funcB(item[1], q)
  139.         elif item[0] == "C":
  140.             funcC(item[1], q)
  141.         elif item[0] == "D":
  142.             funcD(item[1])
  143.         else:
  144.             raise Exception("Unknown item type")
  145.     except Exception as e:
  146.         print(e)
  147.     finally:
  148.         # ---- This is the part where I'm stuck ----
  149.         # Each worker needs to check if all tasks are complete
  150.         # But the end conditions are not clear to me
  151.         if end_condition(q):
  152.             print(f"Consumer #{os.getpid()}: All tasks complete")
  153.             shutdown(q)
  154.   print(f"Consumer #{os.getpid()}: Exiting")
  155.  
  156.  
  157. def dump_queue(q):
  158.     final = []
  159.     while not q.empty():
  160.         final.append(q.get())
  161.     print(f"Final queue: {final}")
  162.  
  163.  
@timer
def main():
    """Drive the pipeline: seed one "A" task, start the worker pool, and
    wait for the workers to shut themselves down via poison pills."""
    queue = multiprocessing.Queue()
    queue.put(("A", None))  # load the first call into the queue
    # NOTE: worker is passed as the Pool *initializer*, not as a submitted
    # task -- each of the WORKERS processes runs the consume loop directly,
    # and nothing is ever apply()'d/map()'d onto the pool itself.
    pool = multiprocessing.Pool(processes=WORKERS,
                                initializer=worker,
                                initargs=(queue,))

    # prevent adding anything more to the worker pool
    # and wait for all workers to finish
    pool.close()
    pool.join()

    # See if there's anything left in the queue
    dump_queue(queue)
  179.  
  180.  
# Entry-point guard: required so child processes importing this module
# do not re-run main().
if __name__ == '__main__':
    main()
  183.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement