ai.py by Nickpips, Jul 31st, 2024 (Python, 12.27 KB)
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from io import BytesIO
from typing import Literal, Optional, cast, get_args
from uuid import UUID, uuid4

import anthropic
import openai
import tiktoken
from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
from anthropic.types import MessageParam
from deepgram import DeepgramClient  # type: ignore
from elevenlabs.client import AsyncElevenLabs
from openai import AsyncOpenAI, RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, Field
from pyht import AsyncClient as AsyncPlayHtClient  # type: ignore
from pyht import TTSOptions

from config import credentials

logger = logging.getLogger("uvicorn")


class KVStore(ABC):
    @abstractmethod
    async def get(self, key: str) -> Optional[str]: ...

    @abstractmethod
    async def set(self, key: str, value: str) -> None: ...
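

# A minimal in-memory implementation sketch of the KVStore interface above.
# (Not part of the original paste; the backend actually used is not shown.)
class InMemoryKVStore(KVStore):
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def get(self, key: str) -> Optional[str]:
        return self._data.get(key)

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value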


class AIConnection:
    openai_client: AsyncOpenAI
    anthropic_client: AsyncAnthropic
    sync_anthropic_client: Anthropic
    eleven_labs_client: AsyncElevenLabs
    deepgram_client: DeepgramClient
    play_ht_client: AsyncPlayHtClient
    # One backoff semaphore per connection (i.e. per event loop):
    # asyncio primitives bind to the event loop they are first awaited on,
    # so a single semaphore cannot be shared across loops.
    openai_ratelimit_semaphore: asyncio.Semaphore
    anthropic_ratelimit_semaphore: asyncio.Semaphore

    def __init__(self) -> None:
        self.openai_client = AsyncOpenAI(
            api_key=credentials.ai.openai_api_key.get_secret_value()
        )
        self.anthropic_client = AsyncAnthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.sync_anthropic_client = Anthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.eleven_labs_client = AsyncElevenLabs(
            api_key=credentials.ai.elevenlabs_api_key.get_secret_value()
        )
        self.deepgram_client = DeepgramClient(
            credentials.ai.deepgram_api_key.get_secret_value()
        )
        self.play_ht_client = AsyncPlayHtClient(
            credentials.ai.playht_user_id.get_secret_value(),
            credentials.ai.playht_api_key.get_secret_value(),
        )
        self.openai_ratelimit_semaphore = asyncio.Semaphore(1)
        self.anthropic_ratelimit_semaphore = asyncio.Semaphore(1)


# NOTE: API clients cannot be shared across multiple event loops,
# so every asyncio event loop needs its own AIConnection.
ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}


def get_ai_connection() -> AIConnection:
    event_loop = asyncio.get_event_loop()
    if event_loop not in ai_connections:
        ai_connections[event_loop] = AIConnection()
    return ai_connections[event_loop]
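
# Usage sketch (hypothetical; not part of the original paste): each event loop
# lazily creates and then reuses its own AIConnection.
#
#   async def handler() -> None:
#       conn = get_ai_connection()          # created on first use in this loop
#       assert conn is get_ai_connection()  # cached for later calls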


class TaskOutput(BaseModel):
    id: UUID = Field(default_factory=uuid4)


class AIModel(BaseModel):
    company: Literal["openai", "anthropic"]
    model: str


class AIMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class AIError(Exception):
    """Base class for AI task errors."""


class AIModerationError(AIError):
    pass


def ai_num_tokens(model: AIModel, s: str) -> int:
    if model.company == "anthropic":
        # count_tokens runs locally; it doesn't connect to the network
        return get_ai_connection().sync_anthropic_client.count_tokens(s)
    elif model.company == "openai":
        encoding = tiktoken.encoding_for_model(model.model)
        num_tokens = len(encoding.encode(s))
        return num_tokens
    raise ValueError(f"Unknown company: {model.company}")


async def ai_call(
    model: AIModel,
    messages: list[AIMessage],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    num_ratelimit_retries: int = 10,
    # When using anthropic, the first message must be from the user.
    # If the first message is not from the user, this one is prepended to the messages.
    anthropic_initial_message: str | None = "<START>",
    # If two messages of the same role are given to anthropic, they must be concatenated.
    # This is the delimiter placed between the concatenated messages.
    anthropic_combine_delimiter: str = "\n",
) -> str:
    if model.company == "openai":
        for i in range(num_ratelimit_retries):
            try:

                def ai_message_to_openai_message_param(
                    message: AIMessage,
                ) -> ChatCompletionMessageParam:
                    if message.role == "system":  # noqa: SIM114
                        return {"role": message.role, "content": message.content}
                    elif message.role == "user":  # noqa: SIM114
                        return {"role": message.role, "content": message.content}
                    elif message.role == "assistant":
                        return {"role": message.role, "content": message.content}

                if i > 0:
                    logger.debug("Trying again after RateLimitError...")
                response = (
                    await get_ai_connection().openai_client.chat.completions.create(
                        model=model.model,
                        messages=[
                            ai_message_to_openai_message_param(message)
                            for message in messages
                        ],
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                )
                if response.choices[0].message.content is None:
                    raise RuntimeError("OpenAI returned nothing")
                return response.choices[0].message.content
            except RateLimitError:
                logger.warning("OpenAI RateLimitError")
                async with get_ai_connection().openai_ratelimit_semaphore:
                    await asyncio.sleep(1)
        raise TimeoutError("Cannot overcome OpenAI RateLimitError")

    elif model.company == "anthropic":
        for i in range(num_ratelimit_retries):
            try:

                def ai_message_to_anthropic_message_param(
                    message: AIMessage,
                ) -> MessageParam:
                    if message.role == "user" or message.role == "assistant":
                        return {"role": message.role, "content": message.content}
                    elif message.role == "system":
                        raise RuntimeError(
                            "system not allowed in anthropic message param"
                        )

                if i > 0:
                    logger.debug("Trying again after RateLimitError...")

                # Extract the system message if it exists
                system: str | NotGiven = NOT_GIVEN
                if len(messages) > 0 and messages[0].role == "system":
                    system = messages[0].content
                    messages = messages[1:]
                # Insert the initial user message if necessary
                if (
                    anthropic_initial_message is not None
                    and len(messages) > 0
                    and messages[0].role != "user"
                ):
                    messages = [
                        AIMessage(role="user", content=anthropic_initial_message)
                    ] + messages
                # Combine consecutive messages of the same role,
                # e.g. [user, user, assistant] -> [user, assistant]
                combined_messages: list[AIMessage] = []
                for message in messages:
                    if (
                        len(combined_messages) == 0
                        or combined_messages[-1].role != message.role
                    ):
                        combined_messages.append(message)
                    else:
                        # Copy before editing, to avoid mutating the caller's message
                        combined_messages[-1] = combined_messages[-1].model_copy(
                            deep=True
                        )
                        # Merge consecutive messages with the same role
                        combined_messages[-1].content += (
                            anthropic_combine_delimiter + message.content
                        )
                # Get the response
                response_message = (
                    await get_ai_connection().anthropic_client.messages.create(
                        model=model.model,
                        system=system,
                        messages=[
                            ai_message_to_anthropic_message_param(message)
                            for message in combined_messages
                        ],
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                )
                return response_message.content[0].text
            except anthropic.RateLimitError as e:
                logger.warning(f"Anthropic Error: {repr(e)}")
                async with get_ai_connection().anthropic_ratelimit_semaphore:
                    await asyncio.sleep(1)
        raise TimeoutError("Cannot overcome Anthropic RateLimitError")
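
# Usage sketch (hypothetical model name; not part of the original paste):
#
#   async def demo() -> None:
#       reply = await ai_call(
#           AIModel(company="openai", model="gpt-4o"),
#           [
#               AIMessage(role="system", content="You are terse."),
#               AIMessage(role="user", content="Say hi."),
#           ],
#       )
#       print(reply)
#
#   asyncio.run(demo())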


async def ai_stt(buffer: BytesIO) -> str:
    try:
        response = await get_ai_connection().openai_client.audio.transcriptions.create(
            model="whisper-1",
            file=buffer,
        )
        return response.text
    except openai.BadRequestError as e:
        # Return empty string for audio that's too short
        if e.code == "audio_too_short":
            return ""
        else:
            raise
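
# Usage sketch (assumes `audio_bytes` holds a complete audio file in a
# Whisper-supported format; not part of the original paste):
#
#   buffer = BytesIO(audio_bytes)
#   buffer.name = "audio.wav"  # the SDK infers the audio format from the name
#   text = await ai_stt(buffer)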


class AIVoiceModel(BaseModel):
    company: Literal["openai", "elevenlabs", "playht"]
    voice: str
    speed: float = 1.0

async def ai_tts(
    transcript: str,
    *,
    voice_model: Optional[AIVoiceModel] = None,
    low_latency: bool = False,
) -> AsyncIterator[bytes]:
    if voice_model is None:
        voice_model = AIVoiceModel(company="openai", voice="nova")

    async def log_bytes_iterator(
        bytes_generator: AsyncIterator[bytes],
    ) -> AsyncIterator[bytes]:
        t1: float = time.time()  # Stream requested
        t2: Optional[float] = None  # First chunk received
        async for chunk in bytes_generator:
            if t2 is None:
                t2 = time.time()
            yield chunk
        if t2 is None:
            t2 = time.time()
        t3: float = time.time()  # Stream finished
        # t2 - t1 is the time-to-first-chunk; t3 - t2 is the remaining stream time
        logger.debug(
            f"TTS Latency ({t2-t1:.3f}): {repr(transcript)} (low_latency={low_latency})"
        )
        logger.debug(f"TTS ({t3-t2:.3f}): {repr(transcript)}")

    bytes_generator: AsyncIterator[bytes]
    match voice_model.company:
        case "openai":
            # Typing
            openai_voice_type = Literal[
                "alloy", "echo", "fable", "onyx", "nova", "shimmer"
            ]

            def voice_to_openai_voice(voice: str) -> openai_voice_type:
                if voice not in get_args(openai_voice_type):
                    raise ValueError(
                        f"voice must be one of {get_args(openai_voice_type)}, received {voice}"
                    )
                return cast(openai_voice_type, voice)

            # Run it
            response = await get_ai_connection().openai_client.audio.speech.with_raw_response.create(
                model="tts-1",
                voice=voice_to_openai_voice(voice_model.voice),
                input=transcript,
                response_format="aac",
                speed=voice_model.speed,
            )
            bytes_generator = response.http_response.aiter_bytes(chunk_size=2048)
        case "playht":
            options = TTSOptions(voice=voice_model.voice, speed=voice_model.speed)
            try:
                bytes_generator = get_ai_connection().play_ht_client.tts(
                    transcript.strip().replace("  ", " "), options
                )
            except Exception as e:
                logger.error(f"Error occurred in Play.ht TTS: {e}")
                raise
        case "elevenlabs":
            if voice_model.speed != 1:
                logger.warning(
                    f"elevenlabs does not support speed change to {voice_model.speed}. Ignoring."
                )
            bytes_generator = await get_ai_connection().eleven_labs_client.generate(
                text=transcript.strip().replace("  ", " "),
                voice=voice_model.voice,
                optimize_streaming_latency=3 if low_latency else None,
                model="eleven_turbo_v2",
                stream=True,
            )
    return log_bytes_iterator(bytes_generator=bytes_generator)
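
# Usage sketch (streams the default OpenAI "nova" voice to a local file;
# output is AAC per the response_format above; not part of the original paste):
#
#   async def demo_tts() -> None:
#       audio = await ai_tts("Hello there!")
#       with open("out.aac", "wb") as f:
#           async for chunk in audio:
#               f.write(chunk)
#
#   asyncio.run(demo_tts())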