Advertisement
drinfernoo

Untitled

Jun 20th, 2023
849
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.64 KB | None | 0 0
  1. class BaseModelView(views.MethodView):
  2.     """
  3.    Base class for all ModelView classes.
  4.  
  5.    Attributes:
  6.        item_type (Type[Item]): The type of the item that this view handles.
  7.    """
  8.  
  9.     def __init__(self, item_type: Type[Item]):
  10.         self.item_type = item_type
  11.         self.api_key = ""
  12.         self.chat_api = None
  13.         self.data = None
  14.  
  15.     @classmethod
  16.     def as_view(cls, name, *class_args, **class_kwargs):
  17.         view = super().as_view(name, *class_args, **class_kwargs)
  18.         view.model = cls(*class_args, **class_kwargs)
  19.         return view
  20.  
  21.     async def dispatch_request(self, **kwargs: Any) -> ResponseReturnValue:
  22.         openai_api_key = request.headers.pop(
  23.             "Authorization", "Invalid Authorization"
  24.         ).split(" ")[1]
  25.  
  26.         story_generator = ChatAPI(ai_prefix="Betty", openai_api_key=openai_api_key)
  27.  
  28.         self.api_key = openai_api_key
  29.         self.chat_api = story_generator
  30.         self.data = await request.get_json()
  31.  
  32.         return await super().dispatch_request(**kwargs)
  33.  
  34.  
  35. class StreamItemsView(BaseModelView):
  36.     async def handle_stream_request(self):
  37.         def stream(*args, **kwargs):
  38.             queue = Queue()
  39.             job_done = object()
  40.             queue.put(job_done)
  41.  
  42.             chat_thread = Thread(
  43.                 target=asyncio.run,
  44.                 args=(self.chat_api.stream(*args, queue=queue, **kwargs),),
  45.             )
  46.             chat_thread.start()
  47.  
  48.             while True:
  49.                 try:
  50.                     next_obj = queue.get(True, timeout=1)
  51.                     if next_obj is job_done:
  52.                         if not chat_thread.is_alive():
  53.                             break
  54.                         continue
  55.                     json_obj = json.dumps(asdict(next_obj))
  56.                     byte_array = bytearray(f"{json_obj}\n", encoding="utf-8")
  57.                     print(
  58.                         f"{time.time()} - yield from {type(byte_array)} from stream()"
  59.                     )
  60.                     yield from byte_array
  61.                 except Empty:
  62.                     if not chat_thread.is_alive():
  63.                         break
  64.                     continue
  65.  
  66.         item_request = (
  67.             self.item_type.get_completion_request_model().parse_obj(self.data).dict()
  68.         )
  69.  
  70.         async_gen = stream(self.item_type, **item_request)
  71.         print(f"{time.time()} - yield {type(async_gen)} from handle_stream_request()")
  72.         yield async_gen
  73.  
  74.     async def post(self):
  75.         async_gen = self.handle_stream_request()
  76.         print(f"{time.time()} - return {type(async_gen)} from post()")
  77.         return async_gen
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement