Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import asyncio
- from uuid import uuid4
- from httpx import AsyncClient
- from tests.conftest import USER_TOKEN
- from zeroentropy.logger import get_logger
- from zeroentropy.routes.chat_interaction import (
- APIChatInteractionStatus,
- CreateChatInteractionRequest,
- GetListChatInteractionRequest,
- GetListChatInteractionResponse,
- )
- from zeroentropy.routes.chat_session import (
- CreateChatSessionRequest,
- GetListChatSessionResponse,
- )
- from zeroentropy.routes.helpers.http_utils import GenericMessageResponse
logger = get_logger()


async def test_chat(client: AsyncClient) -> None:
    """End-to-end test: create a chat session, send a message, await the AI reply.

    Steps exercised against the API reached through ``client``:

    1. ``/chat-sessions/create`` creates a session with a unique title.
    2. ``/chat-sessions/get-list`` must list exactly one session with that title.
    3. ``/chat-interactions/create`` posts the question "What is 2+2?".
    4. ``/chat-interactions/get-list`` is polled (bounded number of attempts)
       until the first interaction reaches ``APIChatInteractionStatus.DONE``;
       its AI message must then contain "4".

    Args:
        client: HTTP client for the application under test (presumably a
            pytest fixture provided by ``tests.conftest`` -- confirm).
    """
    headers = {
        "Authorization": f"Bearer {USER_TOKEN}",
    }
    # A unique title lets us identify our session even if others already exist.
    test_name = f"Example Chat Session <{uuid4()}>"

    # Create a new chat session.
    chat_session_id = uuid4()
    response = await client.post(
        "/chat-sessions/create",
        json=CreateChatSessionRequest(
            id=chat_session_id,
            title=test_name,
        ).model_dump(mode="json"),
        headers=headers,
    )
    # Check the status before decoding: a non-JSON error body would make
    # response.json() raise and hide the real failure.
    assert response.status_code == 200, response.text
    logger.info(response.json())

    # The new session must be listed exactly once.
    response = await client.post(
        "/chat-sessions/get-list",
        json={},
        headers=headers,
    )
    assert response.status_code == 200, response.text
    chat_sessions = GetListChatSessionResponse.model_validate(response.json())
    matching_sessions = [
        session
        for session in chat_sessions.chat_sessions
        if session.title == test_name
    ]
    assert len(matching_sessions) == 1, chat_sessions

    # Add a chat interaction (a user message) to the session.
    chat_interaction_id = uuid4()
    response = await client.post(
        "/chat-interactions/create",
        json=CreateChatInteractionRequest(
            id=chat_interaction_id,
            chat_session_id=chat_session_id,
            message="What is 2+2?",
        ).model_dump(mode="json"),
        headers=headers,
    )
    assert response.status_code == 200, response.text
    logger.info(response.json())
    GenericMessageResponse.model_validate(response.json())

    # Processing finishes some time after creation, so poll for completion,
    # allowing roughly max_attempts * poll_interval_s seconds in total.
    max_attempts = 10
    poll_interval_s = 1.0
    response_data = None
    for attempt in range(max_attempts):
        response = await client.post(
            "/chat-interactions/get-list",
            json=GetListChatInteractionRequest(
                chat_session_id=chat_session_id
            ).model_dump(mode="json"),
            headers=headers,
        )
        assert response.status_code == 200, response.text
        response_data = GetListChatInteractionResponse.model_validate(response.json())
        interactions = response_data.chat_interactions
        if (
            len(interactions) > 0
            and interactions[0].status == APIChatInteractionStatus.DONE
        ):
            break
        # No point sleeping after the last attempt; we fail right after it.
        if attempt < max_attempts - 1:
            await asyncio.sleep(poll_interval_s)

    assert response_data is not None
    assert len(response_data.chat_interactions) > 0
    interaction = response_data.chat_interactions[0]
    assert interaction.status == APIChatInteractionStatus.DONE, (
        f"interaction not DONE after {max_attempts} polls: {interaction.status}"
    )
    assert interaction.ai_message is not None
    # 2+2 = 4, so the AI answer should mention "4".
    assert "4" in interaction.ai_message.message
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement