Advertisement
Nickpips

AI Python File

Aug 13th, 2024 (edited)
281
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 24.83 KB | None | 0 0
  1. import asyncio
  2. import hashlib
  3. import logging
  4. import os
  5. from collections.abc import Callable, Coroutine
  6. from enum import Enum
  7. from typing import Any, Literal, cast
  8.  
  9. import anthropic
  10. import cohere
  11. import diskcache as dc  # type: ignore
  12. import httpx
  13. import openai
  14. import tiktoken
  15. import voyageai  # type: ignore
  16. import voyageai.error  # type: ignore
  17. from anthropic import NOT_GIVEN, Anthropic, AsyncAnthropic, NotGiven
  18. from anthropic.types import MessageParam
  19. from openai import AsyncOpenAI, RateLimitError
  20. from openai.types.chat import ChatCompletionMessageParam
  21. from pydantic import BaseModel, computed_field
  22.  
  23. from utils.credentials import credentials
  24.  
  25. logger = logging.getLogger("uvicorn")
  26.  
  27. # AI Types
  28.  
  29.  
  30. class AIModel(BaseModel):
  31.     company: Literal["openai", "anthropic"]
  32.     model: str
  33.  
  34.     @computed_field  # type: ignore[misc]
  35.     @property
  36.     def ratelimit_tpm(self) -> float:
  37.         match self.company:
  38.             case "openai":
  39.                 # Tier 5
  40.                 match self.model:
  41.                     case "gpt-4o-mini":
  42.                         return 150000000
  43.                     case "gpt-4o":
  44.                         return 30000000
  45.                     case m if m.startswith("gpt-4-turbo"):
  46.                         return 2000000
  47.                     case _:
  48.                         return 1000000
  49.             case "anthropic":
  50.                 # Tier 4
  51.                 return 400000
  52.  
  53.  
  54. class AIMessage(BaseModel):
  55.     role: Literal["system", "user", "assistant"]
  56.     content: str
  57.  
  58.  
  59. class AIEmbeddingModel(BaseModel):
  60.     company: Literal["openai", "cohere", "voyageai"]
  61.     model: str
  62.  
  63.     @computed_field  # type: ignore[misc]
  64.     @property
  65.     def ratelimit_tpm(self) -> float:
  66.         match self.company:
  67.             case "openai":
  68.                 return 1000000
  69.             case "cohere":
  70.                 # 96 texts per embed
  71.                 return 10000 * 96
  72.             case "voyageai":
  73.                 # It says 300RPM but I can only get 30 out of it
  74.                 return 1000000
  75.  
  76.     @computed_field  # type: ignore[misc]
  77.     @property
  78.     def ratelimit_rpm(self) -> float:
  79.         match self.company:
  80.             case "openai":
  81.                 return 5000
  82.             case "cohere":
  83.                 return 10000
  84.             case "voyageai":
  85.                 # It says 300RPM but I can only get 30 out of it
  86.                 return 30
  87.  
  88.     @computed_field  # type: ignore[misc]
  89.     @property
  90.     def max_batch_len(self) -> int:
  91.         match self.company:
  92.             case "openai":
  93.                 return 2048
  94.             case "cohere":
  95.                 return 96
  96.             case "voyageai":
  97.                 return 128
  98.  
  99.  
  100. class AIEmbeddingType(Enum):
  101.     DOCUMENT = 1
  102.     QUERY = 2
  103.  
  104.  
  105. class AIRerankModel(BaseModel):
  106.     company: Literal["cohere", "voyageai"]
  107.     model: str
  108.  
  109.     @computed_field  # type: ignore[misc]
  110.     @property
  111.     def ratelimit_rpm(self) -> float:
  112.         match self.company:
  113.             case "cohere":
  114.                 return 10000
  115.             case "voyageai":
  116.                 # It says 100RPM but I can only get 60 out of it
  117.                 return 60
  118.  
  119.  
  120. # Cache
  121. os.makedirs("./data/cache", exist_ok=True)
  122. cache = dc.Cache("./data/cache/ai_cache.db")
  123.  
  124. RATE_LIMIT_RATIO = 0.95
  125.  
  126.  
  127. class AIConnection:
  128.     openai_client: AsyncOpenAI
  129.     voyageai_client: voyageai.AsyncClient
  130.     cohere_client: cohere.AsyncClient
  131.     anthropic_client: AsyncAnthropic
  132.     sync_anthropic_client: Anthropic
  133.     # Share one global Semaphore across all threads
  134.     cohere_ratelimit_semaphore = asyncio.Semaphore(1)
  135.     voyageai_ratelimit_semaphore = asyncio.Semaphore(1)
  136.     openai_ratelimit_semaphore = asyncio.Semaphore(1)
  137.     anthropic_ratelimit_semaphore = asyncio.Semaphore(1)
  138.  
  139.     def __init__(self) -> None:
  140.         self.openai_client = AsyncOpenAI(
  141.             api_key=credentials.ai.openai_api_key.get_secret_value()
  142.         )
  143.         self.anthropic_client = AsyncAnthropic(
  144.             api_key=credentials.ai.anthropic_api_key.get_secret_value()
  145.         )
  146.         self.sync_anthropic_client = Anthropic(
  147.             api_key=credentials.ai.anthropic_api_key.get_secret_value()
  148.         )
  149.         self.voyageai_client = voyageai.AsyncClient(
  150.             api_key=credentials.ai.voyageai_api_key.get_secret_value()
  151.         )
  152.         self.cohere_client = cohere.AsyncClient(
  153.             api_key=credentials.ai.cohere_api_key.get_secret_value()
  154.         )
  155.  
  156.  
  157. # NOTE: API Clients cannot be called from multiple event loops,
  158. # So every asyncio event loop needs its own API connection
  159. ai_connections: dict[asyncio.AbstractEventLoop, AIConnection] = {}
  160.  
  161.  
  162. def get_ai_connection() -> AIConnection:
  163.     event_loop = asyncio.get_event_loop()
  164.     if event_loop not in ai_connections:
  165.         ai_connections[event_loop] = AIConnection()
  166.     return ai_connections[event_loop]
  167.  
  168.  
  169. class AIError(Exception):
  170.     """A class for AI Task Errors"""
  171.  
  172.  
  173. class AIValueError(AIError, ValueError):
  174.     """A class for AI Value Errors"""
  175.  
  176.  
  177. class AITimeoutError(AIError, TimeoutError):
  178.     """A class for AI Task Timeout Errors"""
  179.  
  180.  
  181. def ai_num_tokens(model: AIModel | AIEmbeddingModel | AIRerankModel, s: str) -> int:
  182.     if isinstance(model, AIModel):
  183.         if model.company == "anthropic":
  184.             # Doesn't actually connect to the network
  185.             return get_ai_connection().sync_anthropic_client.count_tokens(s)
  186.         elif model.company == "openai":
  187.             encoding = tiktoken.encoding_for_model(model.model)
  188.             num_tokens = len(encoding.encode(s))
  189.             return num_tokens
  190.     if isinstance(model, AIEmbeddingModel):
  191.         if model.company == "openai":
  192.             encoding = tiktoken.encoding_for_model(model.model)
  193.             num_tokens = len(encoding.encode(s))
  194.             return num_tokens
  195.         elif model.company == "voyageai":
  196.             return get_ai_connection().voyageai_client.count_tokens([s], model.model)
  197.     # Otherwise, estimate
  198.     logger.warning("Estimating Tokens!")
  199.     return int(len(s) / 4)
  200.  
  201.  
  202. def get_call_cache_key(
  203.     model: AIModel,
  204.     messages: list[AIMessage],
  205. ) -> str:
  206.     # Hash the array of texts
  207.     md5_hasher = hashlib.md5()
  208.     md5_hasher.update(model.model_dump_json().encode())
  209.     for message in messages:
  210.         md5_hasher.update(md5_hasher.hexdigest().encode())
  211.         md5_hasher.update(message.model_dump_json().encode())
  212.     key = md5_hasher.hexdigest()
  213.  
  214.     return key
  215.  
  216.  
  217. async def ai_call(
  218.     model: AIModel,
  219.     messages: list[AIMessage],
  220.     *,
  221.     max_tokens: int = 4096,
  222.     temperature: float = 0.0,
  223.     # When using anthropic, the first message must be from the user.
  224.     # If the first message is not a User, this message will be prepended to the messages.
  225.     anthropic_initial_message: str | None = "<START>",
  226.     # If two messages of the same role are given to anthropic, they must be concatenated.
  227.     # This is the delimiter between concatenated.
  228.     anthropic_combine_delimiter: str = "\n",
  229.     # Throw an AITimeoutError after this many retries fail
  230.     num_ratelimit_retries: int = 10,
  231.     # Backoff function (Receives index of attempt)
  232.     backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
  233. ) -> str:
  234.     cache_key = get_call_cache_key(model, messages)
  235.     cached_call = cache.get(cache_key)
  236.  
  237.     if cached_call is not None:
  238.         return cached_call
  239.  
  240.     num_tokens_input: int = sum(
  241.         [ai_num_tokens(model, message.content) for message in messages]
  242.     )
  243.  
  244.     return_value: str | None = None
  245.     match model.company:
  246.         case "openai":
  247.             for i in range(num_ratelimit_retries):
  248.                 try:
  249.                     # Guard with ratelimit
  250.                     async with get_ai_connection().openai_ratelimit_semaphore:
  251.                         tpm = model.ratelimit_tpm * RATE_LIMIT_RATIO
  252.                         expected_wait = num_tokens_input / (tpm / 60)
  253.                         await asyncio.sleep(expected_wait)
  254.  
  255.                     def ai_message_to_openai_message_param(
  256.                         message: AIMessage,
  257.                     ) -> ChatCompletionMessageParam:
  258.                         if message.role == "system":  # noqa: SIM114
  259.                             return {"role": message.role, "content": message.content}
  260.                         elif message.role == "user":  # noqa: SIM114
  261.                             return {"role": message.role, "content": message.content}
  262.                         elif message.role == "assistant":
  263.                             return {"role": message.role, "content": message.content}
  264.  
  265.                     if i > 0:
  266.                         logger.debug("Trying again after RateLimitError...")
  267.                     response = (
  268.                         await get_ai_connection().openai_client.chat.completions.create(
  269.                             model=model.model,
  270.                             messages=[
  271.                                 ai_message_to_openai_message_param(message)
  272.                                 for message in messages
  273.                             ],
  274.                             temperature=temperature,
  275.                             max_tokens=max_tokens,
  276.                         )
  277.                     )
  278.                     assert response.choices[0].message.content is not None
  279.                     return_value = response.choices[0].message.content
  280.                     break
  281.                 except RateLimitError:
  282.                     logger.warning("OpenAI RateLimitError")
  283.                     async with get_ai_connection().openai_ratelimit_semaphore:
  284.                         await asyncio.sleep(backoff_algo(i))
  285.             if return_value is None:
  286.                 raise AITimeoutError("Cannot overcome OpenAI RateLimitError")
  287.  
  288.         case "anthropic":
  289.             for i in range(num_ratelimit_retries):
  290.                 try:
  291.                     # Guard with ratelimit
  292.                     async with get_ai_connection().anthropic_ratelimit_semaphore:
  293.                         tpm = model.ratelimit_tpm * RATE_LIMIT_RATIO
  294.                         expected_wait = num_tokens_input / (tpm / 60)
  295.                         await asyncio.sleep(expected_wait)
  296.  
  297.                     def ai_message_to_anthropic_message_param(
  298.                         message: AIMessage,
  299.                     ) -> MessageParam:
  300.                         if message.role == "user" or message.role == "assistant":
  301.                             return {"role": message.role, "content": message.content}
  302.                         elif message.role == "system":
  303.                             raise AIValueError(
  304.                                 "system not allowed in anthropic message param"
  305.                             )
  306.  
  307.                     if i > 0:
  308.                         logger.debug("Trying again after RateLimitError...")
  309.  
  310.                     # Extract system message if it exists
  311.                     system: str | NotGiven = NOT_GIVEN
  312.                     if len(messages) > 0 and messages[0].role == "system":
  313.                         system = messages[0].content
  314.                         messages = messages[1:]
  315.                     # Insert initial message if necessary
  316.                     if (
  317.                         anthropic_initial_message is not None
  318.                         and len(messages) > 0
  319.                         and messages[0].role != "user"
  320.                     ):
  321.                         messages = [
  322.                             AIMessage(role="user", content=anthropic_initial_message)
  323.                         ] + messages
  324.                     # Combined messages (By combining consecutive messages of the same role)
  325.                     combined_messages: list[AIMessage] = []
  326.                     for message in messages:
  327.                         if (
  328.                             len(combined_messages) == 0
  329.                             or combined_messages[-1].role != message.role
  330.                         ):
  331.                             combined_messages.append(message)
  332.                         else:
  333.                             # Copy before edit
  334.                             combined_messages[-1] = combined_messages[-1].model_copy(
  335.                                 deep=True
  336.                             )
  337.                             # Merge consecutive messages with the same role
  338.                             combined_messages[-1].content += (
  339.                                 anthropic_combine_delimiter + message.content
  340.                             )
  341.                     # Get the response
  342.                     response_message = (
  343.                         await get_ai_connection().anthropic_client.messages.create(
  344.                             model=model.model,
  345.                             system=system,
  346.                             messages=[
  347.                                 ai_message_to_anthropic_message_param(message)
  348.                                 for message in combined_messages
  349.                             ],
  350.                             temperature=0.0,
  351.                             max_tokens=max_tokens,
  352.                         )
  353.                     )
  354.                     assert isinstance(
  355.                         response_message.content[0], anthropic.types.TextBlock
  356.                     )
  357.                     assert isinstance(response_message.content[0].text, str)
  358.                     return_value = response_message.content[0].text
  359.                     break
  360.                 except anthropic.RateLimitError as e:
  361.                     logger.warning(f"Anthropic Error: {repr(e)}")
  362.                     async with get_ai_connection().anthropic_ratelimit_semaphore:
  363.                         await asyncio.sleep(backoff_algo(i))
  364.             if return_value is None:
  365.                 raise AITimeoutError("Cannot overcome Anthropic RateLimitError")
  366.  
  367.     cache.set(cache_key, return_value)
  368.     return return_value
  369.  
  370.  
  371. def get_embeddings_cache_key(
  372.     model: AIEmbeddingModel, text: str, embedding_type: AIEmbeddingType
  373. ) -> str:
  374.     key = f"{model.company}||||{model.model}||||{embedding_type.name}||||{hashlib.md5(text.encode()).hexdigest()}"
  375.     return key
  376.  
  377.  
  378. async def ai_embedding(
  379.     model: AIEmbeddingModel,
  380.     texts: list[str],
  381.     embedding_type: AIEmbeddingType,
  382.     *,
  383.     # Throw an AITimeoutError after this many retries fail
  384.     num_ratelimit_retries: int = 10,
  385.     # Backoff function (Receives index of attempt)
  386.     backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
  387.     # Callback (For tracking progress)
  388.     callback: Callable[[], None] = lambda: None,
  389. ) -> list[list[float]]:
  390.     # Extract cache miss indices
  391.     text_embeddings: list[list[float] | None] = [None] * len(texts)
  392.     for i, text in enumerate(texts):
  393.         cache_key = get_embeddings_cache_key(model, text, embedding_type)
  394.         text_embeddings[i] = cache.get(cache_key)
  395.         if text_embeddings[i] is not None:
  396.             callback()
  397.     if not any(embedding is None for embedding in text_embeddings):
  398.         return cast(list[list[float]], text_embeddings)
  399.     required_text_embeddings_indices = [
  400.         i for i in range(len(text_embeddings)) if text_embeddings[i] is None
  401.     ]
  402.  
  403.     # Recursively Batch if necessary
  404.     if len(required_text_embeddings_indices) > model.max_batch_len:
  405.         # Calculate embeddings in batches
  406.         tasks: list[Coroutine[Any, Any, list[list[float]]]] = []
  407.         for i in range(0, len(required_text_embeddings_indices), model.max_batch_len):
  408.             batch_indices = required_text_embeddings_indices[
  409.                 i : i + model.max_batch_len
  410.             ]
  411.             tasks.append(
  412.                 ai_embedding(
  413.                     model,
  414.                     [texts[i] for i in batch_indices],
  415.                     embedding_type,
  416.                     num_ratelimit_retries=num_ratelimit_retries,
  417.                     backoff_algo=backoff_algo,
  418.                     callback=callback,
  419.                 )
  420.             )
  421.         preflattened_results = await asyncio.gather(*tasks)
  422.         results: list[list[float]] = []
  423.         for embeddings_list in preflattened_results:
  424.             results.extend(embeddings_list)
  425.         # Merge with cache hits
  426.         assert len(required_text_embeddings_indices) == len(results)
  427.         for i, embedding in zip(required_text_embeddings_indices, results):
  428.             text_embeddings[i] = embedding
  429.         assert all(embedding is not None for embedding in text_embeddings)
  430.         return cast(list[list[float]], text_embeddings)
  431.  
  432.     num_tokens_input: int = sum(
  433.         [
  434.             ai_num_tokens(model, texts[index])
  435.             for index in required_text_embeddings_indices
  436.         ]
  437.     )
  438.  
  439.     input_texts = [texts[i] for i in required_text_embeddings_indices]
  440.     text_embeddings_response: list[list[float]] | None = None
  441.     match model.company:
  442.         case "openai":
  443.             for i in range(num_ratelimit_retries):
  444.                 try:
  445.                     async with get_ai_connection().openai_ratelimit_semaphore:
  446.                         rpm = model.ratelimit_rpm * RATE_LIMIT_RATIO
  447.                         tpm = model.ratelimit_tpm * RATE_LIMIT_RATIO
  448.                         expected_wait = max(60.0 / rpm, num_tokens_input / (tpm / 60))
  449.                         await asyncio.sleep(expected_wait)
  450.                     response = (
  451.                         await get_ai_connection().openai_client.embeddings.create(
  452.                             input=input_texts,
  453.                             model=model.model,
  454.                         )
  455.                     )
  456.                     text_embeddings_response = [
  457.                         embedding.embedding for embedding in response.data
  458.                     ]
  459.                     break
  460.                 except (
  461.                     openai.RateLimitError,
  462.                     openai.APIConnectionError,
  463.                     openai.APITimeoutError,
  464.                 ):
  465.                     logger.warning("OpenAI RateLimitError")
  466.                     async with get_ai_connection().openai_ratelimit_semaphore:
  467.                         await asyncio.sleep(backoff_algo(i))
  468.             if text_embeddings_response is None:
  469.                 raise AITimeoutError("Cannot overcome OpenAI RateLimitError")
  470.         case "cohere":
  471.             for i in range(num_ratelimit_retries):
  472.                 try:
  473.                     async with get_ai_connection().cohere_ratelimit_semaphore:
  474.                         rpm = model.ratelimit_rpm * RATE_LIMIT_RATIO
  475.                         tpm = model.ratelimit_tpm * RATE_LIMIT_RATIO
  476.                         expected_wait = max(60.0 / rpm, num_tokens_input / (tpm / 60))
  477.                         await asyncio.sleep(expected_wait)
  478.                     result = await get_ai_connection().cohere_client.embed(
  479.                         texts=input_texts,
  480.                         model=model.model,
  481.                         input_type=(
  482.                             "search_document"
  483.                             if embedding_type == AIEmbeddingType.DOCUMENT
  484.                             else "search_query"
  485.                         ),
  486.                     )
  487.                     assert isinstance(result.embeddings, list)
  488.                     text_embeddings_response = result.embeddings
  489.                     break
  490.                 except voyageai.error.RateLimitError:
  491.                     logger.warning("Cohere RateLimitError")
  492.                     async with get_ai_connection().cohere_ratelimit_semaphore:
  493.                         await asyncio.sleep(backoff_algo(i))
  494.             if text_embeddings_response is None:
  495.                 raise AITimeoutError("Cannot overcome Cohere RateLimitError")
  496.         case "voyageai":
  497.             for i in range(num_ratelimit_retries):
  498.                 try:
  499.                     async with get_ai_connection().voyageai_ratelimit_semaphore:
  500.                         rpm = model.ratelimit_rpm * RATE_LIMIT_RATIO
  501.                         tpm = model.ratelimit_tpm * RATE_LIMIT_RATIO
  502.                         expected_wait = max(60.0 / rpm, num_tokens_input / (tpm / 60))
  503.                         await asyncio.sleep(expected_wait)
  504.                     result = await get_ai_connection().voyageai_client.embed(
  505.                         input_texts,
  506.                         model=model.model,
  507.                         input_type=(
  508.                             "document"
  509.                             if embedding_type == AIEmbeddingType.DOCUMENT
  510.                             else "query"
  511.                         ),
  512.                     )
  513.                     assert isinstance(result.embeddings, list)
  514.                     text_embeddings_response = result.embeddings
  515.                     break
  516.                 except voyageai.error.RateLimitError:
  517.                     logger.warning("VoyageAI RateLimitError")
  518.                     async with get_ai_connection().voyageai_ratelimit_semaphore:
  519.                         await asyncio.sleep(backoff_algo(i))
  520.             if text_embeddings_response is None:
  521.                 raise AITimeoutError("Cannot overcome VoyageAI RateLimitError")
  522.  
  523.     assert len(text_embeddings_response) == len(required_text_embeddings_indices)
  524.     for index, embedding in zip(
  525.         required_text_embeddings_indices, text_embeddings_response
  526.     ):
  527.         cache_key = get_embeddings_cache_key(model, texts[index], embedding_type)
  528.         cache.set(cache_key, embedding)
  529.         text_embeddings[index] = embedding
  530.         callback()
  531.     assert all(embedding is not None for embedding in text_embeddings)
  532.     return cast(list[list[float]], text_embeddings)
  533.  
  534.  
  535. def get_rerank_cache_key(
  536.     model: AIRerankModel, query: str, texts: list[str], top_k: int | None
  537. ) -> str:
  538.     # Hash the array of texts
  539.     md5_hasher = hashlib.md5()
  540.     md5_hasher.update(query.encode())
  541.     for text in texts:
  542.         md5_hasher.update(md5_hasher.hexdigest().encode())
  543.         md5_hasher.update(text.encode())
  544.     texts_hash = md5_hasher.hexdigest()
  545.  
  546.     key = f"{model.company}||||{model.model}||||{top_k}||||{texts_hash}"
  547.     return key
  548.  
  549.  
  550. # Gets the list of indices that reranks the original texts
  551. async def ai_rerank(
  552.     model: AIRerankModel,
  553.     query: str,
  554.     texts: list[str],
  555.     *,
  556.     top_k: int | None = None,
  557.     # Throw an AITimeoutError after this many retries fail
  558.     num_ratelimit_retries: int = 10,
  559.     # Backoff function (Receives index of attempt)
  560.     backoff_algo: Callable[[int], float] = lambda i: min(2**i, 5),
  561. ) -> list[int]:
  562.     cache_key = get_rerank_cache_key(model, query, texts, top_k)
  563.     cached_reranking = cache.get(cache_key)
  564.  
  565.     if cached_reranking is not None:
  566.         return cached_reranking
  567.  
  568.     indices: list[int] | None = None
  569.     match model.company:
  570.         case "cohere":
  571.             for i in range(num_ratelimit_retries):
  572.                 try:
  573.                     async with get_ai_connection().cohere_ratelimit_semaphore:
  574.                         rpm = model.ratelimit_rpm * RATE_LIMIT_RATIO
  575.                         await asyncio.sleep(60.0 / rpm)
  576.                     response = await get_ai_connection().cohere_client.rerank(
  577.                         model=model.model,
  578.                         query=query,
  579.                         documents=texts,
  580.                         top_n=top_k,
  581.                     )
  582.                     indices = [result.index for result in response.results]
  583.                     break
  584.                 except (
  585.                     cohere.errors.TooManyRequestsError,
  586.                     httpx.ConnectError,
  587.                     httpx.RemoteProtocolError,
  588.                 ):
  589.                     logger.warning("Cohere RateLimitError")
  590.                     async with get_ai_connection().cohere_ratelimit_semaphore:
  591.                         await asyncio.sleep(backoff_algo(i))
  592.             if indices is None:
  593.                 raise AITimeoutError("Cannot overcome Cohere RateLimitError")
  594.         case "voyageai":
  595.             for i in range(num_ratelimit_retries):
  596.                 try:
  597.                     async with get_ai_connection().voyageai_ratelimit_semaphore:
  598.                         rpm = model.ratelimit_rpm * RATE_LIMIT_RATIO
  599.                         await asyncio.sleep(60.0 / rpm)
  600.                     voyageai_response = (
  601.                         await get_ai_connection().voyageai_client.rerank(
  602.                             query=query,
  603.                             documents=texts,
  604.                             model=model.model,
  605.                             top_k=top_k,
  606.                         )
  607.                     )
  608.                     indices = [
  609.                         int(result.index) for result in voyageai_response.results
  610.                     ]
  611.                     break
  612.                 except voyageai.error.RateLimitError:
  613.                     logger.warning("VoyageAI RateLimitError")
  614.                     async with get_ai_connection().voyageai_ratelimit_semaphore:
  615.                         await asyncio.sleep(backoff_algo(i))
  616.             if indices is None:
  617.                 raise AITimeoutError("Cannot overcome VoyageAI RateLimitError")
  618.     cache.set(cache_key, indices)
  619.     return indices
  620.  
Tags: ai.py
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement