import asyncio
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from io import BytesIO
from typing import Literal, Optional, cast, get_args
from uuid import UUID, uuid4

import anthropic
import openai
import tiktoken
from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
from anthropic.types import MessageParam
from deepgram import DeepgramClient  # type: ignore
from elevenlabs.client import AsyncElevenLabs
from openai import AsyncOpenAI, RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, Field
from pyht import AsyncClient as AsyncPlayHtClient  # type: ignore
from pyht import TTSOptions

from config import credentials

logger = logging.getLogger("uvicorn")


class KVStore(ABC):
    @abstractmethod
    async def get(self, key: str) -> Optional[str]: ...

    @abstractmethod
    async def set(self, key: str, value: str) -> None: ...
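

# Example (illustrative, not part of the original module): a minimal in-memory
# KVStore, useful as a stand-in during local testing.
class InMemoryKVStore(KVStore):
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def get(self, key: str) -> Optional[str]:
        return self._data.get(key)

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value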


class AIConnection:
    openai_client: AsyncOpenAI
    anthropic_client: AsyncAnthropic
    sync_anthropic_client: Anthropic
    eleven_labs_client: AsyncElevenLabs
    deepgram_client: DeepgramClient
    play_ht_client: AsyncPlayHtClient
    # Share one global Semaphore across all threads
    openai_ratelimit_semaphore = asyncio.Semaphore(1)
    anthropic_ratelimit_semaphore = asyncio.Semaphore(1)

    def __init__(self) -> None:
        self.openai_client = AsyncOpenAI(
            api_key=credentials.ai.openai_api_key.get_secret_value()
        )
        self.anthropic_client = AsyncAnthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.sync_anthropic_client = Anthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.eleven_labs_client = AsyncElevenLabs(
            api_key=credentials.ai.elevenlabs_api_key.get_secret_value()
        )
        self.deepgram_client = DeepgramClient(
            credentials.ai.deepgram_api_key.get_secret_value()
        )
        self.play_ht_client = AsyncPlayHtClient(
            credentials.ai.playht_user_id.get_secret_value(),
            credentials.ai.playht_api_key.get_secret_value(),
        )


# NOTE: API clients cannot be called from multiple event loops,
# so every asyncio event loop needs its own API connection.
ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}


def get_ai_connection() -> AIConnection:
    # get_running_loop() requires an active loop, which is exactly the
    # invariant we want: one AIConnection per running event loop.
    event_loop = asyncio.get_running_loop()
    if event_loop not in ai_connections:
        ai_connections[event_loop] = AIConnection()
    return ai_connections[event_loop]
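

# Example (illustrative only): within a single event loop, get_ai_connection()
# always returns the same cached AIConnection; a separate asyncio.run() creates
# a new loop and therefore gets a fresh connection.
async def _example_connection_reuse() -> None:
    assert get_ai_connection() is get_ai_connection()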


class TaskOutput(BaseModel):
    id: UUID = Field(default_factory=uuid4)


class AIModel(BaseModel):
    company: Literal["openai", "anthropic"]
    model: str


class AIMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class AIError(Exception):
    """Base class for AI task errors."""


class AIModerationError(AIError):
    pass


def ai_num_tokens(model: AIModel, s: str) -> int:
    if model.company == "anthropic":
        # Counts tokens locally; doesn't actually connect to the network
        return get_ai_connection().sync_anthropic_client.count_tokens(s)
    elif model.company == "openai":
        encoding = tiktoken.encoding_for_model(model.model)
        return len(encoding.encode(s))
    raise ValueError(f"Unknown company: {model.company}")
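

# Example usage (the model name is an assumption, not pinned by this module).
# The OpenAI path runs entirely locally via tiktoken:
def _example_token_count() -> None:
    model = AIModel(company="openai", model="gpt-4")
    print(ai_num_tokens(model, "Hello, world!"))  # e.g. 4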


async def ai_call(
    model: AIModel,
    messages: list[AIMessage],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    num_ratelimit_retries: int = 10,
    # When using anthropic, the first message must be from the user.
    # If the first message is not from the user, this message will be prepended.
    anthropic_initial_message: str | None = "<START>",
    # If two consecutive messages share a role, anthropic requires them to be
    # concatenated. This is the delimiter placed between the concatenated contents.
    anthropic_combine_delimiter: str = "\n",
) -> str:
    if model.company == "openai":

        def ai_message_to_openai_message_param(
            message: AIMessage,
        ) -> ChatCompletionMessageParam:
            # The branches are identical at runtime; they exist so that each
            # return type-narrows to the matching TypedDict member.
            if message.role == "system":  # noqa: SIM114
                return {"role": message.role, "content": message.content}
            elif message.role == "user":  # noqa: SIM114
                return {"role": message.role, "content": message.content}
            elif message.role == "assistant":
                return {"role": message.role, "content": message.content}

        for i in range(num_ratelimit_retries):
            try:
                if i > 0:
                    logger.debug("Trying again after RateLimitError...")
                response = (
                    await get_ai_connection().openai_client.chat.completions.create(
                        model=model.model,
                        messages=[
                            ai_message_to_openai_message_param(message)
                            for message in messages
                        ],
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                )
                if response.choices[0].message.content is None:
                    raise RuntimeError("OpenAI returned nothing")
                return response.choices[0].message.content
            except RateLimitError:
                logger.warning("OpenAI RateLimitError")
                async with get_ai_connection().openai_ratelimit_semaphore:
                    await asyncio.sleep(1)
        raise TimeoutError("Cannot overcome OpenAI RateLimitError")
    elif model.company == "anthropic":

        def ai_message_to_anthropic_message_param(
            message: AIMessage,
        ) -> MessageParam:
            if message.role == "user" or message.role == "assistant":
                return {"role": message.role, "content": message.content}
            elif message.role == "system":
                raise RuntimeError(
                    "system not allowed in anthropic message param"
                )

        for i in range(num_ratelimit_retries):
            try:
                if i > 0:
                    logger.debug("Trying again after RateLimitError...")
                # Extract the system message if it exists
                system: str | NotGiven = NOT_GIVEN
                if len(messages) > 0 and messages[0].role == "system":
                    system = messages[0].content
                    messages = messages[1:]
                # Insert the initial user message if necessary
                if (
                    anthropic_initial_message is not None
                    and len(messages) > 0
                    and messages[0].role != "user"
                ):
                    messages = [
                        AIMessage(role="user", content=anthropic_initial_message)
                    ] + messages
                # Combine consecutive messages of the same role
                combined_messages: list[AIMessage] = []
                for message in messages:
                    if (
                        len(combined_messages) == 0
                        or combined_messages[-1].role != message.role
                    ):
                        combined_messages.append(message)
                    else:
                        # Copy before editing
                        combined_messages[-1] = combined_messages[-1].model_copy(
                            deep=True
                        )
                        # Merge consecutive messages with the same role
                        combined_messages[-1].content += (
                            anthropic_combine_delimiter + message.content
                        )
                # Get the response
                response_message = (
                    await get_ai_connection().anthropic_client.messages.create(
                        model=model.model,
                        system=system,
                        messages=[
                            ai_message_to_anthropic_message_param(message)
                            for message in combined_messages
                        ],
                        temperature=temperature,
                        max_tokens=max_tokens,
                    )
                )
                return response_message.content[0].text
            except anthropic.RateLimitError as e:
                logger.warning(f"Anthropic Error: {repr(e)}")
                async with get_ai_connection().anthropic_ratelimit_semaphore:
                    await asyncio.sleep(1)
        raise TimeoutError("Cannot overcome Anthropic RateLimitError")
    raise ValueError(f"Unknown company: {model.company}")


async def ai_stt(buffer: BytesIO) -> str:
    try:
        response = await get_ai_connection().openai_client.audio.transcriptions.create(
            model="whisper-1",
            file=buffer,
        )
        return response.text
    except openai.BadRequestError as e:
        # Return an empty string for audio that's too short
        if e.code == "audio_too_short":
            return ""
        else:
            raise
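

# Example usage (illustrative): the OpenAI SDK infers the audio format from the
# buffer's filename, so set a name on the BytesIO before transcribing.
async def _example_ai_stt(audio_bytes: bytes) -> None:
    buffer = BytesIO(audio_bytes)
    buffer.name = "audio.wav"  # assumed container format
    print(await ai_stt(buffer))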


class AIVoiceModel(BaseModel):
    company: Literal["openai", "elevenlabs", "playht"]
    voice: str
    speed: float = 1


async def ai_tts(
    transcript: str,
    *,
    voice_model: Optional[AIVoiceModel] = None,
    low_latency: bool = False,
) -> AsyncIterator[bytes]:
    if voice_model is None:
        voice_model = AIVoiceModel(company="openai", voice="nova")

    async def log_bytes_iterator(
        bytes_generator: AsyncIterator[bytes],
    ) -> AsyncIterator[bytes]:
        t1: float = time.time()
        t2: Optional[float] = None
        async for chunk in bytes_generator:
            if t2 is None:
                t2 = time.time()
            yield chunk
        if t2 is None:
            t2 = time.time()
        t3: float = time.time()
        logger.debug(
            f"TTS Latency ({t2-t1:.3f}): {repr(transcript)} (low_latency={low_latency})"
        )
        logger.debug(f"TTS ({t3-t2:.3f}): {repr(transcript)}")

    bytes_generator: AsyncIterator[bytes]
    match voice_model.company:
        case "openai":
            # Typing
            openai_voice_type = Literal[
                "alloy", "echo", "fable", "onyx", "nova", "shimmer"
            ]

            def voice_to_openai_voice(voice: str) -> openai_voice_type:
                if voice not in get_args(openai_voice_type):
                    raise ValueError(
                        f"voice must be one of {get_args(openai_voice_type)}, "
                        f"received {voice}"
                    )
                return cast(openai_voice_type, voice)

            # Run it
            response = await get_ai_connection().openai_client.audio.speech.with_raw_response.create(
                model="tts-1",
                voice=voice_to_openai_voice(voice_model.voice),
                input=transcript,
                response_format="aac",
                speed=voice_model.speed,
            )
            bytes_generator = response.http_response.aiter_bytes(chunk_size=2048)
- case "playht":
- options = TTSOptions(voice=voice_model.voice, speed=voice_model.speed)
- try:
- bytes_generator = get_ai_connection().play_ht_client.tts(
- transcript.strip().replace(" ", " "), options
- )
- except Exception as e:
- logger.error(f"Error occurred in Play.ht TTS: {e}")
- raise
- case "elevenlabs":
- if voice_model.speed != 1:
- logger.warning(
- f"elevenlabs does not support speed change to {voice_model.speed}. Ignoring."
- )
- bytes_generator = await get_ai_connection().eleven_labs_client.generate(
- text=transcript.strip().replace(" ", " "),
- voice=voice_model.voice,
- optimize_streaming_latency=3 if low_latency else None,
- model="eleven_turbo_v2",
- stream=True,
- )
- return log_bytes_iterator(bytes_generator=bytes_generator)
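

# Example usage (illustrative): stream the default OpenAI "nova" voice and
# write the AAC audio to a local file.
async def _example_ai_tts() -> None:
    audio = await ai_tts("Hello there!", low_latency=True)
    with open("hello.aac", "wb") as f:
        async for chunk in audio:
            f.write(chunk)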