- import asyncio
- import base64
- import hashlib
- import os
- from collections.abc import Callable, Coroutine
- from enum import Enum
- from typing import Any, Literal, cast
- import anthropic
- import cohere
- import diskcache as dc # pyright: ignore[reportMissingTypeStubs]
- import httpx
- import numpy as np
- import openai
- import tiktoken
- from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
- from anthropic.types import MessageParam
- from loguru import logger
- from openai import AsyncOpenAI
- from openai.types.chat import ChatCompletionMessageParam
- from openlimit.rate_limiters import ( # pyright: ignore[reportMissingTypeStubs]
- RateLimiter,
- )
- from pydantic import BaseModel
- from .credentials import credentials
- # On-disk cache size limit in bytes (2 GiB). Set to None to disable caching.
- AI_CACHE_SIZE_LIMIT = cast(int | None, 2 * 2**30)
- # AI Types
- class AIModel(BaseModel):
- company: Literal["openai", "anthropic"]
- model: str
- @property
- def ratelimit_tpm(self) -> float:
- match self.company:
- case "openai":
- # Tier 5
- match self.model:
- case _ if self.model.startswith("gpt-4o-mini"):
- return 150_000_000
- case _ if self.model.startswith("gpt-4o"):
- return 30_000_000
- case "gpt-4-turbo":
- return 2_000_000
- case _:
- raise NotImplementedError("Unknown OpenAI Model")
- case "anthropic":
- # Tier 4
- return 400_000
- @property
- def ratelimit_rpm(self) -> float:
- match self.company:
- case "openai":
- # Tier 5
- match self.model:
- case _ if self.model.startswith("gpt-4o-mini"):
- return 30_000
- case _:
- return 10_000
- case "anthropic":
- # Tier 4
- return 4_000
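- # Illustrative sketch (not used elsewhere in this module): how the rate-limit
- # properties above are typically read. The specific model string is an assumption
- # for demonstration only.
- def _example_aimodel_limits() -> None:
- model = AIModel(company="openai", model="gpt-4o-mini")
- # These per-model budgets are what AIConnection.ai_wait_ratelimit consumes below.
- logger.info(f"TPM={model.ratelimit_tpm:,.0f} RPM={model.ratelimit_rpm:,.0f}")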
- class AIMessage(BaseModel):
- role: Literal["system", "user", "assistant"]
- content: str
- class AIEmbeddingModel(BaseModel):
- company: Literal["openai", "cohere"]
- model: str
- @property
- def dimensions(self) -> int:
- match self.company:
- case "openai":
- match self.model:
- case "text-embedding-3-large":
- return 3072
- case "text-embedding-3-small":
- return 1536
- case _:
- pass
- case "cohere":
- pass
- raise NotImplementedError("Unknown Dimensions")
- @property
- def ratelimit_tpm(self) -> float:
- match self.company:
- case "openai":
- return 10_000_000
- case "cohere":
- return float("inf")
- @property
- def ratelimit_rpm(self) -> float:
- match self.company:
- case "openai":
- return 10_000
- case "cohere":
- return 1_000
- @property
- def max_batch_len(self) -> int:
- match self.company:
- case "openai":
- return 2048
- case "cohere":
- return 96
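- # Illustrative sketch: the properties above drive batching and rate limiting in
- # ai_embedding below. The model name is an assumption for demonstration only.
- def _example_embedding_model_props() -> None:
- embed_model = AIEmbeddingModel(company="openai", model="text-embedding-3-small")
- assert embed_model.dimensions == 1536
- # Embedding requests are split into chunks of at most max_batch_len texts.
- logger.info(f"batch={embed_model.max_batch_len} rpm={embed_model.ratelimit_rpm}")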
- class AIEmbeddingType(Enum):
- DOCUMENT = 1
- QUERY = 2
- class AIRerankModel(BaseModel):
- company: Literal["cohere"]
- model: str
- @property
- def ratelimit_rpm(self) -> float:
- match self.company:
- case "cohere":
- return 10_000
- @property
- def ratelimit_tpm(self) -> float:
- match self.company:
- case "cohere":
- return float("inf")
- class AIConnection:
- openai_client: AsyncOpenAI
- anthropic_client: AsyncAnthropic
- sync_anthropic_client: Anthropic
- cohere_client: cohere.AsyncClient
- # Mapping from (company, model) to RateLimiter
- rate_limiters: dict[str, RateLimiter]
- semaphores: dict[str, asyncio.Semaphore]
- def __init__(self) -> None:
- self.openai_client = AsyncOpenAI(
- api_key=credentials.ai.openai_api_key.get_secret_value()
- )
- self.anthropic_client = AsyncAnthropic(
- api_key=credentials.ai.anthropic_api_key.get_secret_value()
- )
- self.sync_anthropic_client = Anthropic(
- api_key=credentials.ai.anthropic_api_key.get_secret_value()
- )
- self.cohere_client = cohere.AsyncClient(
- api_key=credentials.ai.cohere_api_key.get_secret_value()
- )
- self.rate_limiters = {}
- self.semaphores = {}
- async def ai_wait_ratelimit(
- self,
- model: AIModel | AIEmbeddingModel | AIRerankModel,
- num_tokens: int,
- backoff: float | None = None,
- ) -> None:
- key = f"{model.__class__}|{model.model}|{model.company}"
- if key not in self.rate_limiters:
- self.rate_limiters[key] = RateLimiter(
- request_limit=model.ratelimit_rpm * RATE_LIMIT_RATIO,
- token_limit=model.ratelimit_tpm * RATE_LIMIT_RATIO,
- token_counter=None,
- bucket_size_in_seconds=15,
- )
- self.semaphores[key] = asyncio.Semaphore(1)
- if backoff is not None:
- async with self.semaphores[key]:
- await asyncio.sleep(backoff)
- await self.rate_limiters[key].wait_for_capacity(num_tokens) # pyright: ignore[reportUnknownMemberType]
- class AIError(Exception):
- """A class for AI Task Errors"""
- def __init__(self, message: str) -> None:
- super().__init__(message)
- class AIValueError(AIError, ValueError): # pyright: ignore[reportUnsafeMultipleInheritance]
- """A class for AI Value Errors"""
- def __init__(self, message: str) -> None:
- super().__init__(message)
- class AITimeoutError(AIError, TimeoutError): # pyright: ignore[reportUnsafeMultipleInheritance]
- """A class for AI Task Timeout Errors"""
- def __init__(self, message: str) -> None:
- super().__init__(message)
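- # Illustrative sketch: both subclasses above also inherit from AIError, so callers
- # can catch the specific failure or the broad AI task error. The model choice is an
- # assumption for demonstration only.
- async def _example_error_handling() -> None:
- try:
- _ = await ai_call(
- AIModel(company="openai", model="gpt-4o-mini"),
- [AIMessage(role="user", content="ping")],
- )
- except AITimeoutError:
- logger.warning("Gave up after repeated rate limiting")
- except AIError:
- logger.warning("Some other AI task error")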
- # NOTE: API Clients cannot be called from multiple event loops,
- # So every asyncio event loop needs its own API connection
- ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}
- def get_ai_connection() -> AIConnection:
- event_loop = asyncio.get_event_loop()
- if event_loop not in ai_connections:
- ai_connections[event_loop] = AIConnection()
- return ai_connections[event_loop]
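- # Illustrative sketch of the per-event-loop behaviour described in the NOTE above:
- # each asyncio event loop lazily gets its own AIConnection (and thus its own clients).
- async def _example_get_ai_connection() -> None:
- conn = get_ai_connection()
- # Within the same running loop, the same connection instance is returned.
- assert conn is get_ai_connection()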
- # Disk-backed LRU cache, size-limited by AI_CACHE_SIZE_LIMIT above
- CACHE_DIR = "./.cache"
- cache: dc.Cache | None
- if AI_CACHE_SIZE_LIMIT is None:
- cache = None
- else:
- os.makedirs(CACHE_DIR, exist_ok=True)
- cache = dc.Cache(f"{CACHE_DIR}/ai_cache.db", size_limit=AI_CACHE_SIZE_LIMIT)
- RATE_LIMIT_RATIO = 0.95
- def ai_num_tokens(
- model: AIModel | AIEmbeddingModel | AIRerankModel, input_string: str
- ) -> int:
- """
- Calculate the number of tokens for a given string based on the AI model.
- Parameters
- ----------
- model : AIModel | AIEmbeddingModel | AIRerankModel
- The AI model used to determine the tokenization rules. Can be an instance of `AIModel`,
- `AIEmbeddingModel`, or `AIRerankModel`.
- input_string : str
- The input string to be tokenized.
- Returns
- -------
- int
- The number of tokens in the input string based on the tokenization rules of the provided model.
- Example
- -------
- ```python
- model = AIModel(company="openai", model="gpt-4o")
- num_tokens = ai_num_tokens(model, "Hello, world!")
- print(num_tokens)
- ```
- """
- if isinstance(model, AIModel):
- match model.company:
- case "anthropic":
- # Doesn't actually connect to the network
- return get_ai_connection().sync_anthropic_client.count_tokens(
- input_string
- )
- case "openai":
- encoding = tiktoken.encoding_for_model(model.model)
- num_tokens = len(encoding.encode(input_string))
- return num_tokens
- elif isinstance(model, AIEmbeddingModel):
- match model.company:
- case "openai":
- encoding = tiktoken.encoding_for_model(model.model)
- num_tokens = len(encoding.encode(input_string))
- return num_tokens
- case "cohere":
- pass
- # Otherwise, estimate
- logger.warning("Estimating Tokens!")
- return int(len(input_string) / 3.5)
- def ai_call_cache_key(
- model: AIModel,
- messages: list[AIMessage],
- output_type: type,
- ) -> str:
- # Hash the array of texts
- md5_hasher = hashlib.md5()
- md5_hasher.update(model.model_dump_json().encode())
- for message in messages:
- md5_hasher.update(md5_hasher.hexdigest().encode())
- md5_hasher.update(message.model_dump_json().encode())
- md5_hasher.update(md5_hasher.hexdigest().encode())
- md5_hasher.update(f"{output_type}".encode())
- key = md5_hasher.hexdigest()
- return key
- # Bound the type parameter directly; a module-level `type T = ...` alias would be
- # shadowed by the function's own type parameter and have no effect.
- async def ai_call[T: BaseModel | str](
- model: AIModel,
- messages: list[AIMessage],
- *,
- max_tokens: int = 4096,
- temperature: float = 0.0,
- anthropic_initial_message: str | None = "<START>",
- anthropic_combine_delimiter: str = "\n",
- num_ratelimit_retries: int = 10,
- backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
- output_type: type[T] = str,
- ) -> T:
- """
- Makes an asynchronous AI call to a given model with a set of messages, returning the result either as a string or
- structured output based on the output type provided.
- Parameters:
- ----------
- model : AIModel
- The AI model used for the call.
- messages : list[AIMessage]
- A list of messages to send to the AI model. This typically includes the conversation context and user input.
- max_tokens : int, optional
- The maximum number of tokens the AI model should generate. Defaults to 4096.
- temperature : float, optional
- The sampling temperature for the model. Higher values will result in more random outputs, while lower values
- will make the output more deterministic. Defaults to 0.0.
- anthropic_initial_message : str | None, optional
- When using Anthropic's API, the first message must be from the user. If the first message isn't a user message,
- this initial message will be prepended to the messages. Defaults to "<START>".
- anthropic_combine_delimiter : str, optional
- The delimiter used to concatenate messages when two messages of the same role are passed to Anthropic's API.
- Defaults to "\n".
- num_ratelimit_retries : int, optional
- The number of retry attempts to make if the call is rate-limited. Defaults to 10.
- backoff_algo : Callable[[int], float], optional
- A function that receives the index of the attempt and returns the backoff delay time (in seconds) between retries.
- The default is an exponential backoff algorithm capped at 5 seconds: `lambda i: min(2**i, 5)`.
- output_type : type[T], optional
- The expected output type of the call. Can be a `str` (for raw text output) or a Pydantic `BaseModel` (for
- structured output). Defaults to `str`.
- Returns:
- --------
- T
- The result of the AI call. The type of the result is determined by `output_type`:
- - If `output_type` is `str`, the function returns a string.
- - If `output_type` is a Pydantic `BaseModel`, the function returns an instance of the provided model.
- Raises:
- -------
- AITimeoutError
- If the number of retry attempts exceeds `num_ratelimit_retries` without a successful response.
- AIValueError
- If the messages array is in the incorrect format for the underlying API.
- Example:
- --------
- ```python
- response: str = await ai_call(
- model=AIModel(company="openai", model="gpt-4o-mini"),
- messages=[
- AIMessage(role="system", content="Be helpful"),
- AIMessage(role="user", content="What does AI mean?"),
- ],
- )
- ```
- Structured Output
- ```python
- class ResponseModel(BaseModel):
- year: int
- response: ResponseModel = await ai_call(
- model=AIModel(company="openai", model="gpt-4o-mini"),
- messages=[
- AIMessage(role="system", content="Be helpful"),
- AIMessage(role="user", content="When was George Washington born?"),
- ],
- output_type=ResponseModel,
- )
- ```
- """
- cache_key = ai_call_cache_key(model, messages, output_type)
- cached_call = cast(Any, cache.get(cache_key)) if cache is not None else None # pyright: ignore[reportUnknownMemberType]
- if cached_call is not None:
- assert isinstance(cached_call, output_type)
- return cached_call
- num_tokens_input: int = sum(
- [ai_num_tokens(model, message.content) for message in messages]
- )
- return_value: T | None = None
- match model.company:
- case "openai":
- for i in range(num_ratelimit_retries):
- try:
- await get_ai_connection().ai_wait_ratelimit(
- model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
- )
- def ai_message_to_openai_message_param(
- message: AIMessage,
- ) -> ChatCompletionMessageParam:
- if message.role == "system": # noqa: SIM114
- return {"role": message.role, "content": message.content}
- elif message.role == "user": # noqa: SIM114
- return {"role": message.role, "content": message.content}
- elif message.role == "assistant":
- return {"role": message.role, "content": message.content}
- raise NotImplementedError("Unreachable Code")
- if i > 0:
- logger.debug("Trying again after RateLimitError...")
- if output_type is str:
- response = await get_ai_connection().openai_client.chat.completions.create(
- model=model.model,
- messages=[
- ai_message_to_openai_message_param(message)
- for message in messages
- ],
- temperature=temperature,
- max_tokens=max_tokens,
- )
- response_content = response.choices[0].message.content
- assert response_content is not None
- assert isinstance(response_content, output_type)
- return_value = cast(T, response_content)
- else:
- response = await get_ai_connection().openai_client.beta.chat.completions.parse(
- model=model.model,
- messages=[
- ai_message_to_openai_message_param(message)
- for message in messages
- ],
- temperature=temperature,
- max_tokens=max_tokens,
- response_format=output_type,
- )
- response_parsed = response.choices[0].message.parsed
- assert response_parsed is not None
- assert isinstance(response_parsed, output_type)
- return_value = cast(T, response_parsed)
- break
- except (
- openai.RateLimitError,
- openai.APITimeoutError,
- openai.APIConnectionError,
- ) as e:
- logger.warning(f"OpenAI RateLimitError: {repr(e)}")
- if return_value is None:
- raise AITimeoutError("Cannot overcome OpenAI RateLimitError")
- case "anthropic":
- for i in range(num_ratelimit_retries):
- try:
- await get_ai_connection().ai_wait_ratelimit(
- model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
- )
- def ai_message_to_anthropic_message_param(
- message: AIMessage,
- ) -> MessageParam:
- if message.role == "user" or message.role == "assistant":
- return {"role": message.role, "content": message.content}
- elif message.role == "system":
- raise AIValueError(
- "system not allowed in anthropic message param"
- )
- raise NotImplementedError("Unreachable Code")
- if i > 0:
- logger.debug("Trying again after RateLimitError...")
- # Extract system message if it exists
- system: str | NotGiven = NOT_GIVEN
- if len(messages) > 0 and messages[0].role == "system":
- system = messages[0].content
- messages = messages[1:]
- # Insert initial message if necessary
- if (
- anthropic_initial_message is not None
- and len(messages) > 0
- and messages[0].role != "user"
- ):
- messages = [
- AIMessage(role="user", content=anthropic_initial_message)
- ] + messages
- # Combined messages (By combining consecutive messages of the same role)
- combined_messages: list[AIMessage] = []
- for message in messages:
- if (
- len(combined_messages) == 0
- or combined_messages[-1].role != message.role
- ):
- combined_messages.append(message)
- else:
- # Copy before edit
- combined_messages[-1] = combined_messages[-1].model_copy(
- deep=True
- )
- # Merge consecutive messages with the same role
- combined_messages[-1].content += (
- anthropic_combine_delimiter + message.content
- )
- # Get the response
- response_message = (
- await get_ai_connection().anthropic_client.messages.create(
- model=model.model,
- system=system,
- messages=[
- ai_message_to_anthropic_message_param(message)
- for message in combined_messages
- ],
- temperature=temperature,
- max_tokens=max_tokens,
- )
- )
- assert isinstance(
- response_message.content[0], anthropic.types.TextBlock
- )
- assert isinstance(response_message.content[0].text, str)
- if output_type is not str:
- raise NotImplementedError(
- "TODO: Implement Structured Output with Anthropic"
- )
- assert isinstance(response_message.content[0].text, output_type)
- return_value = cast(T, response_message.content[0].text)
- break
- except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
- logger.warning(f"Anthropic Error: {repr(e)}")
- if return_value is None:
- raise AITimeoutError("Cannot overcome Anthropic RateLimitError")
- if cache is not None:
- cache.set(cache_key, return_value) # pyright: ignore[reportUnknownMemberType]
- return return_value
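- # Illustrative sketch: ai_call with a jittered backoff. The jitter helper and the
- # response model are assumptions for demonstration; only backoff_algo's contract
- # (attempt index -> delay in seconds) comes from the function above.
- async def _example_ai_call_with_jitter() -> None:
- import random
- class _Capital(BaseModel):
- city: str
- answer = await ai_call(
- AIModel(company="openai", model="gpt-4o-mini"),
- [AIMessage(role="user", content="What is the capital of France?")],
- backoff_algo=lambda i: min(2**i, 5) + random.random(),
- output_type=_Capital,
- )
- logger.info(answer.city)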
- def get_embeddings_cache_key(
- model: AIEmbeddingModel, text: str, embedding_type: AIEmbeddingType
- ) -> str:
- md5_hasher = hashlib.md5()
- md5_hasher.update(model.model_dump_json().encode())
- md5_hasher.update(md5_hasher.hexdigest().encode())
- md5_hasher.update(text.encode())
- md5_hasher.update(md5_hasher.hexdigest().encode())
- md5_hasher.update(embedding_type.name.encode())
- key = md5_hasher.hexdigest()
- return key
- AIEmbedding = np.ndarray[Literal[1], np.dtype[np.float32]]
- """numpy 1D Array of Floats"""
- def cosine_similarity(vec1: AIEmbedding, vec2: AIEmbedding) -> float:
- """
- Compute the cosine similarity between two embeddings.
- Parameters
- ----------
- vec1 : AIEmbedding
- The first embedding vector.
- vec2 : AIEmbedding
- The second embedding vector.
- Returns
- -------
- float
- A value between -1 and 1 representing the cosine similarity between the two vectors.
- Example
- -------
- ```python
- vec1: AIEmbedding = np.array([1.0, 2.0, 3.0], dtype=np.float32)
- vec2: AIEmbedding = np.array([1.0, 2.0, 3.0], dtype=np.float32)
- similarity = cosine_similarity(vec1, vec2)
- print(similarity)  # 1.0
- ```
- """
- # Normalize so the result is a true cosine similarity even when the
- # embeddings are not already unit-length.
- return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
- async def ai_embedding(
- model: AIEmbeddingModel,
- texts: list[str],
- embedding_type: AIEmbeddingType,
- *,
- num_ratelimit_retries: int = 10,
- backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
- callback: Callable[[], None] = lambda: None,
- ) -> list[AIEmbedding]:
- """
- Generate embeddings for a list of texts using a specified AI embedding model.
- Parameters
- ----------
- model : AIEmbeddingModel
- The AI model used to generate embeddings.
- texts : list[str]
- A list of texts for which embeddings will be generated.
- embedding_type : AIEmbeddingType
- Whether the texts are embedded as documents (`AIEmbeddingType.DOCUMENT`) or as search queries (`AIEmbeddingType.QUERY`).
- num_ratelimit_retries : int, optional
- The number of retry attempts to make if the AI call is rate-limited. Defaults to 10.
- backoff_algo : Callable[[int], float], optional
- A function that determines the backoff strategy between retries. It receives the retry attempt index
- and returns the delay time in seconds. Defaults to an exponential backoff function.
- callback : Callable[[], None], optional
- A callback function that is executed to track progress during embedding generation. Defaults to a no-op function.
- Returns
- -------
- list[AIEmbedding]
- A list of embeddings generated by the AI model, where each embedding corresponds to an input text.
- Raises
- ------
- AITimeoutError
- If the number of retry attempts exceeds `num_ratelimit_retries` without a successful response.
- Example
- -------
- ```python
- model = AIEmbeddingModel(company="openai", model="text-embedding-3-small")
- texts = ["This is a sentence.", "Another sentence."]
- embeddings = await ai_embedding(model, texts, AIEmbeddingType.DOCUMENT)
- print(embeddings)
- ```
- """
- # Extract cache miss indices
- text_embeddings: list[AIEmbedding | None] = [None] * len(texts)
- if cache is not None:
- with cache.transact():
- for i, text in enumerate(texts):
- cache_key = get_embeddings_cache_key(model, text, embedding_type)
- cache_result = cast(Any, cache.get(cache_key)) # pyright: ignore[reportUnknownMemberType]
- if cache_result is not None:
- callback()
- if not isinstance(cache_result, np.ndarray):
- logger.warning("Invalid cache_result, ignoring...")
- continue
- cache_result = cast(AIEmbedding, cache_result)
- text_embeddings[i] = cache_result
- if not any(embedding is None for embedding in text_embeddings):
- return cast(list[AIEmbedding], text_embeddings)
- required_text_embeddings_indices = [
- i for i in range(len(text_embeddings)) if text_embeddings[i] is None
- ]
- # Recursively Batch if necessary
- if len(required_text_embeddings_indices) > model.max_batch_len:
- # Calculate embeddings in batches
- tasks: list[Coroutine[Any, Any, list[AIEmbedding]]] = []
- for i in range(0, len(required_text_embeddings_indices), model.max_batch_len):
- batch_indices = required_text_embeddings_indices[
- i : i + model.max_batch_len
- ]
- tasks.append(
- ai_embedding(
- model,
- [texts[i] for i in batch_indices],
- embedding_type,
- num_ratelimit_retries=num_ratelimit_retries,
- backoff_algo=backoff_algo,
- callback=callback,
- )
- )
- preflattened_results = await asyncio.gather(*tasks)
- results: list[AIEmbedding] = []
- for embeddings_list in preflattened_results:
- results.extend(embeddings_list)
- # Merge with cache hits
- assert len(required_text_embeddings_indices) == len(results)
- for i, embedding in zip(
- required_text_embeddings_indices, results, strict=False
- ):
- text_embeddings[i] = embedding
- assert all(embedding is not None for embedding in text_embeddings)
- return cast(list[AIEmbedding], text_embeddings)
- num_tokens_input: int = sum(
- [
- ai_num_tokens(model, texts[index])
- for index in required_text_embeddings_indices
- ]
- )
- input_texts = [texts[i] for i in required_text_embeddings_indices]
- text_embeddings_response: list[AIEmbedding] | None = None
- match model.company:
- case "openai":
- for i in range(num_ratelimit_retries):
- try:
- await get_ai_connection().ai_wait_ratelimit(
- model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
- )
- response = (
- await get_ai_connection().openai_client.embeddings.create(
- input=input_texts,
- model=model.model,
- encoding_format="base64",
- )
- )
- response_embeddings: list[AIEmbedding] = []
- for embedding_obj in response.data:
- data = cast(object, embedding_obj.embedding)
- if not isinstance(data, str):
- # numpy is not installed / base64 optimisation isn't enabled for this model yet
- raise RuntimeError("Error with base64/numpy")
- response_embeddings.append(
- np.frombuffer(base64.b64decode(data), dtype="float32")
- )
- text_embeddings_response = response_embeddings
- break
- except (
- openai.RateLimitError,
- openai.APITimeoutError,
- ):
- logger.warning("OpenAI RateLimitError")
- except openai.APIError as e:
- logger.warning(f"OpenAI Unknown Error: {repr(e)}")
- if text_embeddings_response is None:
- raise AITimeoutError("Cannot overcome OpenAI RateLimitError")
- case "cohere":
- for i in range(num_ratelimit_retries):
- try:
- await get_ai_connection().ai_wait_ratelimit(
- model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
- )
- result = await get_ai_connection().cohere_client.embed(
- texts=input_texts,
- model=model.model,
- input_type=(
- "search_document"
- if embedding_type == AIEmbeddingType.DOCUMENT
- else "search_query"
- ),
- )
- assert isinstance(result.embeddings, list)
- text_embeddings_response = [
- np.array(embedding) for embedding in result.embeddings
- ]
- break
- except (
- cohere.errors.TooManyRequestsError,
- httpx.ConnectError,
- httpx.RemoteProtocolError,
- ):
- logger.warning("Cohere RateLimitError")
- if text_embeddings_response is None:
- raise AITimeoutError("Cannot overcome Cohere RateLimitError")
- if cache is not None:
- with cache.transact():
- assert len(text_embeddings_response) == len(
- required_text_embeddings_indices
- )
- for index, embedding in zip(
- required_text_embeddings_indices, text_embeddings_response, strict=False
- ):
- cache_key = get_embeddings_cache_key(
- model, texts[index], embedding_type
- )
- cache.set(cache_key, embedding) # pyright: ignore[reportUnknownMemberType]
- for index, embedding in zip(
- required_text_embeddings_indices, text_embeddings_response, strict=False
- ):
- text_embeddings[index] = embedding
- callback()
- assert all(embedding is not None for embedding in text_embeddings)
- return cast(list[AIEmbedding], text_embeddings)
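- # Illustrative sketch: a minimal semantic lookup built from ai_embedding and
- # cosine_similarity above. The model name and sample texts are assumptions.
- async def _example_semantic_lookup() -> None:
- embed_model = AIEmbeddingModel(company="openai", model="text-embedding-3-small")
- docs = ["Paris is the capital of France.", "Mitochondria produce most of a cell's ATP."]
- doc_vecs = await ai_embedding(embed_model, docs, AIEmbeddingType.DOCUMENT)
- (query_vec,) = await ai_embedding(embed_model, ["capital of France"], AIEmbeddingType.QUERY)
- scores = [cosine_similarity(query_vec, doc_vec) for doc_vec in doc_vecs]
- best = docs[int(np.argmax(scores))]
- logger.info(f"Best match: {best}")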
- def get_rerank_cache_key(model: AIRerankModel, query: str, text: str) -> str:
- md5_hasher = hashlib.md5()
- md5_hasher.update(model.model_dump_json().encode())
- md5_hasher.update(md5_hasher.hexdigest().encode())
- md5_hasher.update(query.encode())
- md5_hasher.update(md5_hasher.hexdigest().encode())
- md5_hasher.update(text.encode())
- key = md5_hasher.hexdigest()
- return key
- async def ai_rerank(
- model: AIRerankModel,
- query: str,
- texts: list[str],
- *,
- num_ratelimit_retries: int = 10,
- backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
- ) -> list[float]:
- """
- Calculate the similarity score between a query and a text, using a cross-encoder model (Commonly known as a reranker)
- Parameters
- ----------
- model : AIRerankModel
- The AI model used for calculating cross-encoder similarity scores between the texts and the query.
- query : str
- The query used to evaluate the relevance of the texts.
- texts : list[str]
- A list of texts to calculate similarity scores against.
- num_ratelimit_retries : int, optional
- The number of retry attempts to make if the AI call is rate-limited. Defaults to 10.
- backoff_algo : Callable[[int], float], optional
- A function that defines the backoff strategy between retries. It receives the retry attempt index and
- returns the delay time in seconds. Defaults to an exponential backoff function.
- Returns
- -------
- list[float]
- A list of scores corresponding to the relevance of each text to the query. Higher scores indicate
- higher relevance.
- Raises
- ------
- AITimeoutError
- If the number of retry attempts exceeds `num_ratelimit_retries` without a successful response.
- Example
- -------
- ```python
- model = AIRerankModel(company="cohere", model="rerank-english-v3.0")
- query = "What is AI?"
- texts = ["AI is artificial intelligence.", "AI is used in various fields."]
- scores = await ai_rerank(model, query, texts)
- print(scores)
- ```
- """
- text_scores: list[float | None] = [None] * len(texts)
- if cache is not None:
- with cache.transact():
- for i, text in enumerate(texts):
- cache_key = get_rerank_cache_key(model, query, text)
- cache_result = cast(Any, cache.get(cache_key)) # pyright: ignore[reportUnknownMemberType]
- if cache_result is not None:
- # cast instead of assert isinstance, because of ints
- cache_result = float(cache_result)
- text_scores[i] = cache_result
- if all(score is not None for score in text_scores):
- return cast(list[float], text_scores)
- unprocessed_indices = [i for i, score in enumerate(text_scores) if score is None]
- unprocessed_texts = [texts[i] for i in unprocessed_indices]
- num_tokens_input = sum(ai_num_tokens(model, text) for text in unprocessed_texts)
- relevancy_scores: list[float] | None = None
- match model.company:
- case "cohere":
- for i in range(num_ratelimit_retries):
- try:
- await get_ai_connection().ai_wait_ratelimit(
- model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
- )
- response = await get_ai_connection().cohere_client.rerank(
- model=model.model,
- query=query,
- documents=unprocessed_texts,
- )
- original_order_results = sorted(
- response.results, key=lambda x: x.index
- )
- relevancy_scores = [
- result.relevance_score for result in original_order_results
- ]
- break
- except (
- cohere.errors.TooManyRequestsError,
- httpx.ConnectError,
- httpx.RemoteProtocolError,
- ):
- logger.warning("Cohere RateLimitError")
- if relevancy_scores is None:
- raise AITimeoutError("Cannot overcome Cohere RateLimitError")
- assert len(unprocessed_indices) == len(relevancy_scores)
- if cache is not None:
- with cache.transact():
- for index, score in zip(
- unprocessed_indices, relevancy_scores, strict=False
- ):
- cache_key = get_rerank_cache_key(model, query, texts[index])
- cache.set(cache_key, score) # pyright: ignore[reportUnknownMemberType]
- for index, score in zip(unprocessed_indices, relevancy_scores, strict=False):
- text_scores[index] = score
- assert all(score is not None for score in text_scores)
- return cast(list[float], text_scores)
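- # Illustrative sketch: reranking a small candidate set and keeping the top result.
- # The Cohere model name is an assumption; ai_rerank preserves input order, so the
- # returned scores line up index-for-index with `candidates`.
- async def _example_rerank_top1() -> None:
- rerank_model = AIRerankModel(company="cohere", model="rerank-english-v3.0")
- candidates = ["AI is artificial intelligence.", "Bread is made from flour."]
- scores = await ai_rerank(rerank_model, "What is AI?", candidates)
- top_text = candidates[int(np.argmax(scores))]
- logger.info(f"Top candidate: {top_text}")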