Nickpips

ai.py

Sep 30th, 2024 (edited)
import asyncio
import base64
import hashlib
import os
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any, Literal, cast

import anthropic
import cohere
import diskcache as dc  # pyright: ignore[reportMissingTypeStubs]
import httpx
import numpy as np
import openai
import tiktoken
from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
from anthropic.types import MessageParam
from loguru import logger
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from openlimit.rate_limiters import (  # pyright: ignore[reportMissingTypeStubs]
    RateLimiter,
)
from pydantic import BaseModel

from .credentials import credentials

# Size in bytes
AI_CACHE_SIZE_LIMIT = cast(int | None, 2 * 2**30)

# AI Types


class AIModel(BaseModel):
    company: Literal["openai", "anthropic"]
    model: str

    @property
    def ratelimit_tpm(self) -> float:
        match self.company:
            case "openai":
                # Tier 5
                match self.model:
                    case _ if self.model.startswith("gpt-4o-mini"):
                        return 150_000_000
                    case _ if self.model.startswith("gpt-4o"):
                        return 30_000_000
                    case "gpt-4-turbo":
                        return 2_000_000
                    case _:
                        raise NotImplementedError("Unknown OpenAI Model")
            case "anthropic":
                # Tier 4
                return 400_000

    @property
    def ratelimit_rpm(self) -> float:
        match self.company:
            case "openai":
                # Tier 5
                match self.model:
                    case _ if self.model.startswith("gpt-4o-mini"):
                        return 30_000
                    case _:
                        return 10_000
            case "anthropic":
                # Tier 4
                return 4_000


class AIMessage(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class AIEmbeddingModel(BaseModel):
    company: Literal["openai", "cohere"]
    model: str

    @property
    def dimensions(self) -> int:
        match self.company:
            case "openai":
                match self.model:
                    case "text-embedding-3-large":
                        return 3072
                    case "text-embedding-3-small":
                        return 1536
                    case _:
                        pass
            case "cohere":
                pass
        raise NotImplementedError("Unknown Dimensions")

    @property
    def ratelimit_tpm(self) -> float:
        match self.company:
            case "openai":
                return 10_000_000
            case "cohere":
                return float("inf")

    @property
    def ratelimit_rpm(self) -> float:
        match self.company:
            case "openai":
                return 10_000
            case "cohere":
                return 1_000

    @property
    def max_batch_len(self) -> int:
        match self.company:
            case "openai":
                return 2048
            case "cohere":
                return 96


class AIEmbeddingType(Enum):
    DOCUMENT = 1
    QUERY = 2


class AIRerankModel(BaseModel):
    company: Literal["cohere"]
    model: str

    @property
    def ratelimit_rpm(self) -> float:
        match self.company:
            case "cohere":
                return 10_000

    @property
    def ratelimit_tpm(self) -> float:
        match self.company:
            case "cohere":
                return float("inf")


class AIConnection:
    openai_client: AsyncOpenAI
    anthropic_client: AsyncAnthropic
    sync_anthropic_client: Anthropic
    cohere_client: cohere.AsyncClient
    # Mapping from (company, model) to RateLimiter
    rate_limiters: dict[str, RateLimiter]
    semaphores: dict[str, asyncio.Semaphore]

    def __init__(self) -> None:
        self.openai_client = AsyncOpenAI(
            api_key=credentials.ai.openai_api_key.get_secret_value()
        )
        self.anthropic_client = AsyncAnthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.sync_anthropic_client = Anthropic(
            api_key=credentials.ai.anthropic_api_key.get_secret_value()
        )
        self.cohere_client = cohere.AsyncClient(
            api_key=credentials.ai.cohere_api_key.get_secret_value()
        )
        self.rate_limiters = {}
        self.semaphores = {}

    async def ai_wait_ratelimit(
        self,
        model: AIModel | AIEmbeddingModel | AIRerankModel,
        num_tokens: int,
        backoff: float | None = None,
    ) -> None:
        key = f"{model.__class__}|{model.model}|{model.company}"
        if key not in self.rate_limiters:
            self.rate_limiters[key] = RateLimiter(
                request_limit=model.ratelimit_rpm * RATE_LIMIT_RATIO,
                token_limit=model.ratelimit_tpm * RATE_LIMIT_RATIO,
                token_counter=None,
                bucket_size_in_seconds=15,
            )
            self.semaphores[key] = asyncio.Semaphore(1)
        if backoff is not None:
            async with self.semaphores[key]:
                await asyncio.sleep(backoff)
        await self.rate_limiters[key].wait_for_capacity(num_tokens)  # pyright: ignore[reportUnknownMemberType]
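
# Example (sketch): reserving rate-limit capacity by hand before a request.
# The model and token count here are illustrative, not part of this module:
#
#   model = AIModel(company="openai", model="gpt-4o-mini")
#   await get_ai_connection().ai_wait_ratelimit(model, num_tokens=500)
#
# On a retry, pass `backoff=...`; the sleep is held under the per-key
# semaphore, so other waiters on the same key queue up behind it before
# any more capacity is requested.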


class AIError(Exception):
    """A class for AI Task Errors"""

    def __init__(self, message: str) -> None:
        super().__init__(message)


class AIValueError(AIError, ValueError):  # pyright: ignore[reportUnsafeMultipleInheritance]
    """A class for AI Value Errors"""

    def __init__(self, message: str) -> None:
        super().__init__(message)


class AITimeoutError(AIError, TimeoutError):  # pyright: ignore[reportUnsafeMultipleInheritance]
    """A class for AI Task Timeout Errors"""

    def __init__(self, message: str) -> None:
        super().__init__(message)


# NOTE: API clients cannot be called from multiple event loops,
# so every asyncio event loop needs its own API connection.
ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}


def get_ai_connection() -> AIConnection:
    event_loop = asyncio.get_event_loop()
    if event_loop not in ai_connections:
        ai_connections[event_loop] = AIConnection()
    return ai_connections[event_loop]
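
# Example (sketch): because connections are keyed by event loop, each
# `asyncio.run(...)` (which creates a fresh loop) transparently gets its own
# AIConnection, with its own clients and rate-limiter state:
#
#   async def main() -> None:
#       conn = get_ai_connection()  # bound to this loop
#       ...
#
#   asyncio.run(main())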


# Cache (Default=2GiB, LRU)
CACHE_DIR = "./.cache"
cache: dc.Cache | None
if AI_CACHE_SIZE_LIMIT is None:
    cache = None
else:
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache = dc.Cache(f"{CACHE_DIR}/ai_cache.db", size_limit=AI_CACHE_SIZE_LIMIT)

RATE_LIMIT_RATIO = 0.95


def ai_num_tokens(
    model: AIModel | AIEmbeddingModel | AIRerankModel, input_string: str
) -> int:
    """
    Calculate the number of tokens for a given string based on the AI model.

    Parameters
    ----------
    model : AIModel | AIEmbeddingModel | AIRerankModel
        The AI model used to determine the tokenization rules. Can be an instance of `AIModel`,
        `AIEmbeddingModel`, or `AIRerankModel`.

    input_string : str
        The input string to be tokenized.

    Returns
    -------
    int
        The number of tokens in the input string based on the tokenization rules of the provided model.

    Example
    -------
    ```python
    model = AIModel(company="openai", model="gpt-4o")
    num_tokens = ai_num_tokens(model, "Hello, world!")
    print(num_tokens)
    ```
    """
    if isinstance(model, AIModel):
        match model.company:
            case "anthropic":
                # Doesn't actually connect to the network
                return get_ai_connection().sync_anthropic_client.count_tokens(
                    input_string
                )
            case "openai":
                encoding = tiktoken.encoding_for_model(model.model)
                num_tokens = len(encoding.encode(input_string))
                return num_tokens
    elif isinstance(model, AIEmbeddingModel):
        match model.company:
            case "openai":
                encoding = tiktoken.encoding_for_model(model.model)
                num_tokens = len(encoding.encode(input_string))
                return num_tokens
            case "cohere":
                pass
    # Otherwise, estimate (roughly 3.5 characters per token)
    logger.warning("Estimating Tokens!")
    return int(len(input_string) / 3.5)


def ai_call_cache_key(
    model: AIModel,
    messages: list[AIMessage],
    output_type: type,
) -> str:
    # Hash the model, each message, and the output type, interleaving the
    # running digest as a separator so field boundaries stay unambiguous.
    md5_hasher = hashlib.md5()
    md5_hasher.update(model.model_dump_json().encode())
    for message in messages:
        md5_hasher.update(md5_hasher.hexdigest().encode())
        md5_hasher.update(message.model_dump_json().encode())
    md5_hasher.update(md5_hasher.hexdigest().encode())
    md5_hasher.update(f"{output_type}".encode())
    key = md5_hasher.hexdigest()

    return key
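
# Example (sketch): the key is a pure function of (model, messages, output
# type), so identical calls map to the same diskcache entry:
#
#   key_a = ai_call_cache_key(model, messages, str)
#   key_b = ai_call_cache_key(model, messages, str)
#   assert key_a == key_b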


async def ai_call[T: BaseModel | str](
    model: AIModel,
    messages: list[AIMessage],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    anthropic_initial_message: str | None = "<START>",
    anthropic_combine_delimiter: str = "\n",
    num_ratelimit_retries: int = 10,
    backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
    output_type: type[T] = str,
) -> T:
    """
    Makes an asynchronous AI call to a given model with a set of messages, returning the result either as a string or
    structured output based on the output type provided.

    Parameters
    ----------
    model : AIModel
        The AI model used for the call.

    messages : list[AIMessage]
        A list of messages to send to the AI model. This typically includes the conversation context and user input.

    max_tokens : int, optional
        The maximum number of tokens the AI model should generate. Defaults to 4096.

    temperature : float, optional
        The sampling temperature for the model. Higher values will result in more random outputs, while lower values
        will make the output more deterministic. Defaults to 0.0.

    anthropic_initial_message : str | None, optional
        When using Anthropic's API, the first message must be from the user. If the first message isn't a user message,
        this initial message will be prepended to the messages. Defaults to "<START>".

    anthropic_combine_delimiter : str, optional
        The delimiter used to concatenate messages when two consecutive messages of the same role are passed to
        Anthropic's API. Defaults to "\\n".

    num_ratelimit_retries : int, optional
        The number of retry attempts to make if the call is rate-limited. Defaults to 10.

    backoff_algo : Callable[[int], float], optional
        A function that receives the index of the attempt and returns the backoff delay time (in seconds) between retries.
        The default is an exponential backoff algorithm capped at 5 seconds: `lambda i: min(2**i, 5)`.

    output_type : type[T], optional
        The expected output type of the call. Can be `str` (for raw text output) or a Pydantic `BaseModel` subclass (for
        structured output). Defaults to `str`.

    Returns
    -------
    T
        The result of the AI call. The type of the result is determined by `output_type`:
        - If `output_type` is `str`, the function returns a string.
        - If `output_type` is a Pydantic `BaseModel`, the function returns an instance of the provided model.

    Raises
    ------
    AITimeoutError
        If the number of retry attempts exceeds `num_ratelimit_retries` without a successful response.

    AIValueError
        If the messages array is in the incorrect format for the underlying API.

    Example
    -------
    ```python
    response: str = await ai_call(
        model=AIModel(company="openai", model="gpt-4o-mini"),
        messages=[
            AIMessage(role="system", content="Be helpful"),
            AIMessage(role="user", content="What does AI mean?"),
        ],
    )
    ```

    Structured Output
    ```python
    class ResponseModel(BaseModel):
        year: int

    response: ResponseModel = await ai_call(
        model=AIModel(company="openai", model="gpt-4o-mini"),
        messages=[
            AIMessage(role="system", content="Be helpful"),
            AIMessage(role="user", content="When was George Washington born?"),
        ],
        output_type=ResponseModel,
    )
    ```
    """
    cache_key = ai_call_cache_key(model, messages, output_type)
    cached_call = cast(Any, cache.get(cache_key)) if cache is not None else None  # pyright: ignore[reportUnknownMemberType]

    if cached_call is not None:
        assert isinstance(cached_call, output_type)
        return cached_call

    num_tokens_input: int = sum(
        [ai_num_tokens(model, message.content) for message in messages]
    )

    return_value: T | None = None
    match model.company:
        case "openai":
            for i in range(num_ratelimit_retries):
                try:
                    await get_ai_connection().ai_wait_ratelimit(
                        model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
                    )

                    def ai_message_to_openai_message_param(
                        message: AIMessage,
                    ) -> ChatCompletionMessageParam:
                        if message.role == "system":  # noqa: SIM114
                            return {"role": message.role, "content": message.content}
                        elif message.role == "user":  # noqa: SIM114
                            return {"role": message.role, "content": message.content}
                        elif message.role == "assistant":
                            return {"role": message.role, "content": message.content}
                        raise NotImplementedError("Unreachable Code")

                    if i > 0:
                        logger.debug("Trying again after RateLimitError...")
                    if output_type is str:
                        response = await get_ai_connection().openai_client.chat.completions.create(
                            model=model.model,
                            messages=[
                                ai_message_to_openai_message_param(message)
                                for message in messages
                            ],
                            temperature=temperature,
                            max_tokens=max_tokens,
                        )
                        response_content = response.choices[0].message.content
                        assert response_content is not None
                        assert isinstance(response_content, output_type)
                        return_value = cast(T, response_content)
                    else:
                        response = await get_ai_connection().openai_client.beta.chat.completions.parse(
                            model=model.model,
                            messages=[
                                ai_message_to_openai_message_param(message)
                                for message in messages
                            ],
                            temperature=temperature,
                            max_tokens=max_tokens,
                            response_format=output_type,
                        )
                        response_parsed = response.choices[0].message.parsed
                        assert response_parsed is not None
                        assert isinstance(response_parsed, output_type)
                        return_value = cast(T, response_parsed)
                    break
                except (
                    openai.RateLimitError,
                    openai.APITimeoutError,
                    openai.APIConnectionError,
                ) as e:
                    logger.warning(f"OpenAI RateLimitError: {repr(e)}")
            if return_value is None:
                raise AITimeoutError("Cannot overcome OpenAI RateLimitError")

        case "anthropic":
            for i in range(num_ratelimit_retries):
                try:
                    await get_ai_connection().ai_wait_ratelimit(
                        model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
                    )

                    def ai_message_to_anthropic_message_param(
                        message: AIMessage,
                    ) -> MessageParam:
                        if message.role == "user" or message.role == "assistant":
                            return {"role": message.role, "content": message.content}
                        elif message.role == "system":
                            raise AIValueError(
                                "system not allowed in anthropic message param"
                            )
                        raise NotImplementedError("Unreachable Code")

                    if i > 0:
                        logger.debug("Trying again after RateLimitError...")

                    # Extract the system message if it exists
                    system: str | NotGiven = NOT_GIVEN
                    if len(messages) > 0 and messages[0].role == "system":
                        system = messages[0].content
                        messages = messages[1:]
                    # Insert the initial user message if necessary
                    if (
                        anthropic_initial_message is not None
                        and len(messages) > 0
                        and messages[0].role != "user"
                    ):
                        messages = [
                            AIMessage(role="user", content=anthropic_initial_message)
                        ] + messages
                    # Combine consecutive messages of the same role
                    combined_messages: list[AIMessage] = []
                    for message in messages:
                        if (
                            len(combined_messages) == 0
                            or combined_messages[-1].role != message.role
                        ):
                            combined_messages.append(message)
                        else:
                            # Copy before edit
                            combined_messages[-1] = combined_messages[-1].model_copy(
                                deep=True
                            )
                            # Merge consecutive messages with the same role
                            combined_messages[-1].content += (
                                anthropic_combine_delimiter + message.content
                            )
                    # Get the response
                    response_message = (
                        await get_ai_connection().anthropic_client.messages.create(
                            model=model.model,
                            system=system,
                            messages=[
                                ai_message_to_anthropic_message_param(message)
                                for message in combined_messages
                            ],
                            temperature=temperature,
                            max_tokens=max_tokens,
                        )
                    )
                    assert isinstance(
                        response_message.content[0], anthropic.types.TextBlock
                    )
                    assert isinstance(response_message.content[0].text, str)
                    if output_type is not str:
                        raise NotImplementedError(
                            "TODO: Implement Structured Output with Anthropic"
                        )
                    assert isinstance(response_message.content[0].text, output_type)
                    return_value = cast(T, response_message.content[0].text)
                    break
                except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                    logger.warning(f"Anthropic Error: {repr(e)}")
            if return_value is None:
                raise AITimeoutError("Cannot overcome Anthropic RateLimitError")

    if cache is not None:
        cache.set(cache_key, return_value)  # pyright: ignore[reportUnknownMemberType]
    return return_value
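
# Example (sketch): a custom jittered backoff that could be passed as
# `backoff_algo` in place of the default `lambda i: min(2**i, 5)`.
# `jittered_backoff` is illustrative, not part of this module:
#
#   import random
#
#   def jittered_backoff(attempt: int) -> float:
#       return min(2**attempt, 5) * (0.5 + random.random() / 2)
#
#   result = await ai_call(model, messages, backoff_algo=jittered_backoff)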


def get_embeddings_cache_key(
    model: AIEmbeddingModel, text: str, embedding_type: AIEmbeddingType
) -> str:
    md5_hasher = hashlib.md5()
    md5_hasher.update(model.model_dump_json().encode())
    md5_hasher.update(md5_hasher.hexdigest().encode())
    md5_hasher.update(text.encode())
    md5_hasher.update(md5_hasher.hexdigest().encode())
    md5_hasher.update(embedding_type.name.encode())
    hash = md5_hasher.hexdigest()
    return hash


AIEmbedding = np.ndarray[Literal[1], np.dtype[np.float32]]
"""numpy 1D Array of Floats"""


def cosine_similarity(vec1: AIEmbedding, vec2: AIEmbedding) -> float:
    """
    Compute the cosine similarity between two embeddings.

    Parameters
    ----------
    vec1 : AIEmbedding
        The first embedding vector.

    vec2 : AIEmbedding
        The second embedding vector.

    Returns
    -------
    float
        A value between -1 and 1 representing the cosine similarity between the two vectors.

    Example
    -------
    ```python
    vec1: AIEmbedding = np.array([1, 2, 3], dtype=np.float32)
    vec2: AIEmbedding = np.array([1, 2, 3], dtype=np.float32)
    similarity = cosine_similarity(vec1, vec2)
    print(similarity)  # 1.0
    ```
    """
    # Normalize so the result is a true cosine similarity in [-1, 1],
    # even for embeddings that are not unit-length.
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
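
# Example (sketch): vectors pointing in the same direction score 1.0
# regardless of magnitude:
#
#   v1 = np.array([1.0, 2.0, 3.0], dtype=np.float32)
#   v2 = np.array([2.0, 4.0, 6.0], dtype=np.float32)
#   assert abs(cosine_similarity(v1, v2) - 1.0) < 1e-6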


async def ai_embedding(
    model: AIEmbeddingModel,
    texts: list[str],
    embedding_type: AIEmbeddingType,
    *,
    num_ratelimit_retries: int = 10,
    backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
    callback: Callable[[], None] = lambda: None,
) -> list[AIEmbedding]:
    """
    Generate embeddings for a list of texts using a specified AI embedding model.

    Parameters
    ----------
    model : AIEmbeddingModel
        The AI model used to generate embeddings.

    texts : list[str]
        A list of texts for which embeddings will be generated.

    embedding_type : AIEmbeddingType
        Whether the texts are documents or queries. Retrieval-style embedding models (e.g., Cohere's)
        embed the two differently.

    num_ratelimit_retries : int, optional
        The number of retry attempts to make if the AI call is rate-limited. Defaults to 10.

    backoff_algo : Callable[[int], float], optional
        A function that determines the backoff strategy between retries. It receives the retry attempt index
        and returns the delay time in seconds. Defaults to an exponential backoff function.

    callback : Callable[[], None], optional
        A callback function that is executed once per embedded text to track progress. Defaults to a no-op function.

    Returns
    -------
    list[AIEmbedding]
        A list of embeddings generated by the AI model, where each embedding corresponds to an input text.

    Raises
    ------
    AITimeoutError
        If the number of retry attempts exceeds `num_ratelimit_retries` without a successful response.

    Example
    -------
    ```python
    model = AIEmbeddingModel(company="openai", model="text-embedding-3-small")
    texts = ["This is a sentence.", "Another sentence."]
    embeddings = await ai_embedding(model, texts, AIEmbeddingType.DOCUMENT)
    print(embeddings)
    ```
    """
    # Extract cache miss indices
    text_embeddings: list[AIEmbedding | None] = [None] * len(texts)
    if cache is not None:
        with cache.transact():
            for i, text in enumerate(texts):
                cache_key = get_embeddings_cache_key(model, text, embedding_type)
                cache_result = cast(Any, cache.get(cache_key))  # pyright: ignore[reportUnknownMemberType]
                if cache_result is not None:
                    callback()
                    if not isinstance(cache_result, np.ndarray):
                        logger.warning("Invalid cache_result, ignoring...")
                        continue
                    cache_result = cast(AIEmbedding, cache_result)
                    text_embeddings[i] = cache_result
        if not any(embedding is None for embedding in text_embeddings):
            return cast(list[AIEmbedding], text_embeddings)
    required_text_embeddings_indices = [
        i for i in range(len(text_embeddings)) if text_embeddings[i] is None
    ]

    # Recursively batch if necessary
    if len(required_text_embeddings_indices) > model.max_batch_len:
        # Calculate embeddings in batches
        tasks: list[Coroutine[Any, Any, list[AIEmbedding]]] = []
        for i in range(0, len(required_text_embeddings_indices), model.max_batch_len):
            batch_indices = required_text_embeddings_indices[
                i : i + model.max_batch_len
            ]
            tasks.append(
                ai_embedding(
                    model,
                    [texts[i] for i in batch_indices],
                    embedding_type,
                    num_ratelimit_retries=num_ratelimit_retries,
                    backoff_algo=backoff_algo,
                    callback=callback,
                )
            )
        preflattened_results = await asyncio.gather(*tasks)
        results: list[AIEmbedding] = []
        for embeddings_list in preflattened_results:
            results.extend(embeddings_list)
        # Merge with cache hits
        assert len(required_text_embeddings_indices) == len(results)
        for i, embedding in zip(
            required_text_embeddings_indices, results, strict=False
        ):
            text_embeddings[i] = embedding
        assert all(embedding is not None for embedding in text_embeddings)
        return cast(list[AIEmbedding], text_embeddings)

    num_tokens_input: int = sum(
        [
            ai_num_tokens(model, texts[index])
            for index in required_text_embeddings_indices
        ]
    )

    input_texts = [texts[i] for i in required_text_embeddings_indices]
    text_embeddings_response: list[AIEmbedding] | None = None
    match model.company:
        case "openai":
            for i in range(num_ratelimit_retries):
                try:
                    await get_ai_connection().ai_wait_ratelimit(
                        model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
                    )
                    response = (
                        await get_ai_connection().openai_client.embeddings.create(
                            input=input_texts,
                            model=model.model,
                            encoding_format="base64",
                        )
                    )
                    response_embeddings: list[AIEmbedding] = []
                    for embedding_obj in response.data:
                        data = cast(object, embedding_obj.embedding)
                        if not isinstance(data, str):
                            # numpy is not installed / base64 optimisation isn't enabled for this model yet
                            raise RuntimeError("Error with base64/numpy")

                        response_embeddings.append(
                            np.frombuffer(base64.b64decode(data), dtype="float32")
                        )
                    text_embeddings_response = response_embeddings
                    break
                except (
                    openai.RateLimitError,
                    openai.APITimeoutError,
                ):
                    logger.warning("OpenAI RateLimitError")
                except openai.APIError as e:
                    logger.warning(f"OpenAI Unknown Error: {repr(e)}")
            if text_embeddings_response is None:
                raise AITimeoutError("Cannot overcome OpenAI RateLimitError")
        case "cohere":
            for i in range(num_ratelimit_retries):
                try:
                    await get_ai_connection().ai_wait_ratelimit(
                        model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
                    )
                    result = await get_ai_connection().cohere_client.embed(
                        texts=input_texts,
                        model=model.model,
                        input_type=(
                            "search_document"
                            if embedding_type == AIEmbeddingType.DOCUMENT
                            else "search_query"
                        ),
                    )
                    assert isinstance(result.embeddings, list)
                    text_embeddings_response = [
                        np.array(embedding) for embedding in result.embeddings
                    ]
                    break
                except (
                    cohere.errors.TooManyRequestsError,
                    httpx.ConnectError,
                    httpx.RemoteProtocolError,
                ):
                    logger.warning("Cohere RateLimitError")
            if text_embeddings_response is None:
                raise AITimeoutError("Cannot overcome Cohere RateLimitError")
    if cache is not None:
        with cache.transact():
            assert len(text_embeddings_response) == len(
                required_text_embeddings_indices
            )
            for index, embedding in zip(
                required_text_embeddings_indices, text_embeddings_response, strict=False
            ):
                cache_key = get_embeddings_cache_key(
                    model, texts[index], embedding_type
                )
                cache.set(cache_key, embedding)  # pyright: ignore[reportUnknownMemberType]
    for index, embedding in zip(
        required_text_embeddings_indices, text_embeddings_response, strict=False
    ):
        text_embeddings[index] = embedding
        callback()
    assert all(embedding is not None for embedding in text_embeddings)
    return cast(list[AIEmbedding], text_embeddings)
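
# Example (sketch): embedding two texts with one of the OpenAI models this
# module already recognizes, then comparing them:
#
#   async def demo() -> None:
#       model = AIEmbeddingModel(company="openai", model="text-embedding-3-small")
#       a, b = await ai_embedding(model, ["cat", "kitten"], AIEmbeddingType.DOCUMENT)
#       print(cosine_similarity(a, b))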


def get_rerank_cache_key(model: AIRerankModel, query: str, text: str) -> str:
    md5_hasher = hashlib.md5()
    md5_hasher.update(model.model_dump_json().encode())
    md5_hasher.update(md5_hasher.hexdigest().encode())
    md5_hasher.update(query.encode())
    md5_hasher.update(md5_hasher.hexdigest().encode())
    md5_hasher.update(text.encode())
    hash = md5_hasher.hexdigest()
    return hash


async def ai_rerank(
    model: AIRerankModel,
    query: str,
    texts: list[str],
    *,
    num_ratelimit_retries: int = 10,
    backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
) -> list[float]:
    """
    Calculate the similarity score between a query and each text, using a cross-encoder model
    (commonly known as a reranker).

    Parameters
    ----------
    model : AIRerankModel
        The AI model used for calculating cross-encoder similarity scores between the texts and the query.

    query : str
        The query used to evaluate the relevance of the texts.

    texts : list[str]
        A list of texts to calculate similarity scores against.

    num_ratelimit_retries : int, optional
        The number of retry attempts to make if the AI call is rate-limited. Defaults to 10.

    backoff_algo : Callable[[int], float], optional
        A function that defines the backoff strategy between retries. It receives the retry attempt index and
        returns the delay time in seconds. Defaults to an exponential backoff function.

    Returns
    -------
    list[float]
        A list of scores corresponding to the relevance of each text to the query. Higher scores indicate
        higher relevance.

    Raises
    ------
    AITimeoutError
        If the number of retry attempts exceeds `num_ratelimit_retries` without a successful response.

    Example
    -------
    ```python
    model = AIRerankModel(company="cohere", model="rerank-multilingual-v3.0")
    query = "What is AI?"
    texts = ["AI is artificial intelligence.", "AI is used in various fields."]
    scores = await ai_rerank(model, query, texts)
    print(scores)
    ```
    """

    text_scores: list[float | None] = [None] * len(texts)
    if cache is not None:
        with cache.transact():
            for i, text in enumerate(texts):
                cache_key = get_rerank_cache_key(model, query, text)
                cache_result = cast(Any, cache.get(cache_key))  # pyright: ignore[reportUnknownMemberType]
                if cache_result is not None:
                    # cast instead of assert isinstance, because of ints
                    cache_result = float(cache_result)
                    text_scores[i] = cache_result
    if all(score is not None for score in text_scores):
        return cast(list[float], text_scores)

    unprocessed_indices = [i for i, score in enumerate(text_scores) if score is None]
    unprocessed_texts = [texts[i] for i in unprocessed_indices]
    num_tokens_input = sum(ai_num_tokens(model, text) for text in unprocessed_texts)

    relevancy_scores: list[float] | None = None
    match model.company:
        case "cohere":
            for i in range(num_ratelimit_retries):
                try:
                    await get_ai_connection().ai_wait_ratelimit(
                        model, num_tokens_input, backoff_algo(i - 1) if i > 0 else None
                    )
                    response = await get_ai_connection().cohere_client.rerank(
                        model=model.model,
                        query=query,
                        documents=unprocessed_texts,
                    )
                    original_order_results = sorted(
                        response.results, key=lambda x: x.index
                    )
                    relevancy_scores = [
                        result.relevance_score for result in original_order_results
                    ]
                    break
                except (
                    cohere.errors.TooManyRequestsError,
                    httpx.ConnectError,
                    httpx.RemoteProtocolError,
                ):
                    logger.warning("Cohere RateLimitError")
            if relevancy_scores is None:
                raise AITimeoutError("Cannot overcome Cohere RateLimitError")
    assert len(unprocessed_indices) == len(relevancy_scores)
    if cache is not None:
        with cache.transact():
            for index, score in zip(
                unprocessed_indices, relevancy_scores, strict=False
            ):
                cache_key = get_rerank_cache_key(model, query, texts[index])
                cache.set(cache_key, score)  # pyright: ignore[reportUnknownMemberType]
    for index, score in zip(unprocessed_indices, relevancy_scores, strict=False):
        text_scores[index] = score

    assert all(score is not None for score in text_scores)
    return cast(list[float], text_scores)
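
# Example (sketch): scores come back in input order, so they can be paired
# with their texts and sorted by relevance:
#
#   scores = await ai_rerank(model, query, texts)
#   ranked = sorted(zip(texts, scores, strict=True), key=lambda p: p[1], reverse=True)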