Advertisement
Nickpips

test_chat.py

Dec 27th, 2024
67
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.08 KB | None | 0 0
  1. import asyncio
  2. from uuid import uuid4
  3.  
  4. from httpx import AsyncClient
  5.  
  6. from tests.conftest import USER_TOKEN
  7. from zeroentropy.logger import get_logger
  8. from zeroentropy.routes.chat_interaction import (
  9.     APIChatInteractionStatus,
  10.     CreateChatInteractionRequest,
  11.     GetListChatInteractionRequest,
  12.     GetListChatInteractionResponse,
  13. )
  14. from zeroentropy.routes.chat_session import (
  15.     CreateChatSessionRequest,
  16.     GetListChatSessionResponse,
  17. )
  18. from zeroentropy.routes.helpers.http_utils import GenericMessageResponse
  19.  
  20. logger = get_logger()
  21.  
  22.  
  23. async def test_chat(client: AsyncClient) -> None:
  24.     headers = {
  25.         "Authorization": f"Bearer {USER_TOKEN}",
  26.     }
  27.  
  28.     # Create a random name
  29.     TEST_NAME = f"Example Chat Session <{uuid4()}>"
  30.  
  31.     # Create a new chat session
  32.     chat_session_id = uuid4()
  33.     response = await client.post(
  34.         "/chat-sessions/create",
  35.         json=CreateChatSessionRequest(
  36.             id=chat_session_id,
  37.             title=TEST_NAME,
  38.         ).model_dump(mode="json"),
  39.         headers=headers,
  40.     )
  41.     logger.info(response.json())
  42.     assert response.status_code == 200
  43.  
  44.     # Get all the chat sessions
  45.     response = await client.post(
  46.         "/chat-sessions/get-list",
  47.         json={},
  48.         headers=headers,
  49.     )
  50.     assert response.status_code == 200
  51.     chat_sessions = GetListChatSessionResponse.model_validate(response.json())
  52.  
  53.     assert len(chat_sessions.chat_sessions) == 1
  54.     assert chat_sessions.chat_sessions[0].title == TEST_NAME
  55.  
  56.     # Add a chat interaction
  57.     chat_interaction_id = uuid4()
  58.     response = await client.post(
  59.         "/chat-interactions/create",
  60.         json=CreateChatInteractionRequest(
  61.             id=chat_interaction_id,
  62.             chat_session_id=chat_session_id,
  63.             message="What is 2+2?",
  64.         ).model_dump(mode="json"),
  65.         headers=headers,
  66.     )
  67.     logger.info(response.json())
  68.     assert response.status_code == 200
  69.     GenericMessageResponse.model_validate(response.json())
  70.  
  71.     # Wait for the chat interaction to be processed
  72.     response_data = None
  73.     for _ in range(10):  # Try for up to 10 seconds
  74.         response = await client.post(
  75.             "/chat-interactions/get-list",
  76.             json=GetListChatInteractionRequest(
  77.                 chat_session_id=chat_session_id
  78.             ).model_dump(mode="json"),
  79.             headers=headers,
  80.         )
  81.         assert response.status_code == 200
  82.         response_data = GetListChatInteractionResponse.model_validate(response.json())
  83.         if (
  84.             len(response_data.chat_interactions) > 0
  85.             and response_data.chat_interactions[0].status
  86.             == APIChatInteractionStatus.DONE
  87.         ):
  88.             break
  89.         await asyncio.sleep(1)
  90.  
  91.     assert response_data is not None
  92.     assert len(response_data.chat_interactions) > 0
  93.     assert response_data.chat_interactions[0].status == APIChatInteractionStatus.DONE
  94.     assert response_data.chat_interactions[0].ai_message is not None
  95.     # 2+2 = 4
  96.     assert "4" in response_data.chat_interactions[0].ai_message.message
  97.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement