Skip to content

Commit 2f731d7

Browse files
committed
added tests for endpoints
1 parent 9b5d2c5 commit 2f731d7

13 files changed

Lines changed: 715 additions & 3 deletions

.github/workflows/build_publish.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ jobs:
4545
run: |
4646
poetry run mypy .
4747
48+
# Tests
49+
- name: Run Tests
50+
run: |
51+
poetry run pytest .
52+
4853
publish:
4954
if: startsWith(github.ref, 'refs/tags')
5055
runs-on: ubuntu-latest

examples/async_chat_no_streaming.py

100644100755
File mode changed.

examples/async_chat_with_streaming.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ async def main():
1212
client = MistralAsyncClient(api_key=api_key)
1313

1414
print("Chat response:")
15-
async for chunk in client.chat_stream(
15+
response = client.chat_stream(
1616
model=model,
1717
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
18-
):
18+
)
19+
20+
async for chunk in response:
1921
if chunk.choices[0].delta.content is not None:
2022
print(chunk.choices[0].delta.content, end="")
2123

poetry.lock

Lines changed: 89 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ httpx = "^0.25.2"
3333
ruff = "^0.1.6"
3434
mypy = "^1.7.1"
3535
types-requests = "^2.31.0.10"
36+
pytest = "^7.4.3"
37+
pytest-asyncio = "^0.23.2"
3638

3739
[build-system]
3840
requires = ["poetry-core"]

tests/__init__.py

Whitespace-only changes.

tests/test_chat.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
import unittest.mock as mock
3+
from mistralai.client import MistralClient
4+
from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ChatCompletionStreamResponse
5+
from .utils import mock_response, mock_stream_response, mock_chat_response_payload, mock_chat_response_streaming_payload
6+
7+
@pytest.fixture()
8+
def client():
9+
client = MistralClient()
10+
client._client = mock.MagicMock()
11+
return client
12+
13+
14+
class TestChat:
15+
def test_chat(self, client):
16+
client._client.request.return_value = mock_response(
17+
200,
18+
mock_chat_response_payload(),
19+
)
20+
21+
result = client.chat(
22+
model="mistral-small",
23+
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
24+
)
25+
26+
client._client.request.assert_called_once_with(
27+
"post",
28+
"https://api.mistral.ai/v1/chat/completions",
29+
headers={
30+
"Accept": "application/json",
31+
"Authorization": "Bearer None",
32+
"Content-Type": "application/json",
33+
},
34+
json={'model': 'mistral-small', 'messages': [{'role': 'user', 'content': 'What is the best French cheese?'}], 'safe_prompt': False, 'stream': False},
35+
)
36+
37+
38+
assert isinstance(
39+
result, ChatCompletionResponse
40+
), "Should return an ChatCompletionResponse"
41+
assert len(result.choices) == 1
42+
assert result.choices[0].index == 0
43+
assert result.object == "chat.completion"
44+
45+
46+
def test_chat_streaming(self, client):
47+
client._client.stream.return_value = mock_stream_response(
48+
200,
49+
mock_chat_response_streaming_payload(),
50+
)
51+
52+
result = client.chat_stream(
53+
model="mistral-small",
54+
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
55+
)
56+
57+
results = list(result)
58+
59+
client._client.stream.assert_called_once_with(
60+
"post",
61+
"https://api.mistral.ai/v1/chat/completions",
62+
headers={
63+
"Accept": "application/json",
64+
"Authorization": "Bearer None",
65+
"Content-Type": "application/json",
66+
},
67+
json={'model': 'mistral-small', 'messages': [{'role': 'user', 'content': 'What is the best French cheese?'}], 'safe_prompt': False, 'stream': True},
68+
)
69+
70+
for i, result in enumerate(results):
71+
if i == 0:
72+
assert isinstance(
73+
result, ChatCompletionStreamResponse
74+
), "Should return an ChatCompletionStreamResponse"
75+
assert len(result.choices) == 1
76+
assert result.choices[0].index == 0
77+
assert result.choices[0].delta.role == "assistant"
78+
else:
79+
assert isinstance(
80+
result, ChatCompletionStreamResponse
81+
), "Should return an ChatCompletionStreamResponse"
82+
assert len(result.choices) == 1
83+
assert result.choices[0].index == i-1
84+
assert result.choices[0].delta.content == f"stream response {i-1}"
85+
assert result.object == "chat.completion.chunk"

tests/test_chat_async.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import pytest
2+
import unittest.mock as mock
3+
from mistralai.async_client import MistralAsyncClient
4+
from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage, ChatCompletionStreamResponse
5+
from .utils import mock_response, mock_async_stream_response, mock_chat_response_payload, mock_chat_response_streaming_payload
6+
7+
@pytest.fixture()
8+
def client():
9+
client = MistralAsyncClient()
10+
client._client = mock.AsyncMock()
11+
client._client.stream = mock.Mock()
12+
return client
13+
14+
15+
class TestAsyncChat:
16+
@pytest.mark.asyncio
17+
async def test_chat(self, client):
18+
client._client.request.return_value = mock_response(
19+
200,
20+
mock_chat_response_payload(),
21+
)
22+
23+
result = await client.chat(
24+
model="mistral-small",
25+
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
26+
)
27+
28+
client._client.request.assert_awaited_once_with(
29+
"post",
30+
"https://api.mistral.ai/v1/chat/completions",
31+
headers={
32+
"Accept": "application/json",
33+
"Authorization": "Bearer None",
34+
"Content-Type": "application/json",
35+
},
36+
json={'model': 'mistral-small', 'messages': [{'role': 'user', 'content': 'What is the best French cheese?'}], 'safe_prompt': False, 'stream': False},
37+
)
38+
39+
40+
assert isinstance(
41+
result, ChatCompletionResponse
42+
), "Should return an ChatCompletionResponse"
43+
assert len(result.choices) == 1
44+
assert result.choices[0].index == 0
45+
assert result.object == "chat.completion"
46+
47+
@pytest.mark.asyncio
48+
async def test_chat_streaming(self, client):
49+
client._client.stream.return_value = mock_async_stream_response(
50+
200,
51+
mock_chat_response_streaming_payload(),
52+
)
53+
54+
result = client.chat_stream(
55+
model="mistral-small",
56+
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
57+
)
58+
59+
results = [r async for r in result]
60+
61+
client._client.stream.assert_called_once_with(
62+
"post",
63+
"https://api.mistral.ai/v1/chat/completions",
64+
headers={
65+
"Accept": "application/json",
66+
"Authorization": "Bearer None",
67+
"Content-Type": "application/json",
68+
},
69+
json={'model': 'mistral-small', 'messages': [{'role': 'user', 'content': 'What is the best French cheese?'}], 'safe_prompt': False, 'stream': True},
70+
)
71+
72+
for i, result in enumerate(results):
73+
if i == 0:
74+
assert isinstance(
75+
result, ChatCompletionStreamResponse
76+
), "Should return an ChatCompletionStreamResponse"
77+
assert len(result.choices) == 1
78+
assert result.choices[0].index == 0
79+
assert result.choices[0].delta.role == "assistant"
80+
else:
81+
assert isinstance(
82+
result, ChatCompletionStreamResponse
83+
), "Should return an ChatCompletionStreamResponse"
84+
assert len(result.choices) == 1
85+
assert result.choices[0].index == i-1
86+
assert result.choices[0].delta.content == f"stream response {i-1}"
87+
assert result.object == "chat.completion.chunk"

0 commit comments

Comments
 (0)