Skip to content

Commit a071d29

Browse files
authored
Merge pull request mistralai#39 from mistralai/bam4d/safe_prompt
adding safe_prompt to calls and deprecating safe_mode
2 parents 92c16de + 3ee19fd commit a071d29

3 files changed

Lines changed: 18 additions & 10 deletions

File tree

src/mistralai/async_client.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ async def chat(
133133
top_p: Optional[float] = None,
134134
random_seed: Optional[int] = None,
135135
safe_mode: bool = False,
136+
safe_prompt: bool = False,
136137
) -> ChatCompletionResponse:
137138
"""A asynchronous chat endpoint that returns a single response.
138139
@@ -145,7 +146,8 @@ async def chat(
145146
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
146147
Defaults to None.
147148
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
148-
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
149+
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
150+
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
149151
150152
Returns:
151153
ChatCompletionResponse: a response object containing the generated text.
@@ -158,7 +160,7 @@ async def chat(
158160
top_p=top_p,
159161
random_seed=random_seed,
160162
stream=False,
161-
safe_mode=safe_mode,
163+
safe_prompt=safe_mode or safe_prompt,
162164
)
163165

164166
single_response = self._request("post", request, "v1/chat/completions")
@@ -177,6 +179,7 @@ async def chat_stream(
177179
top_p: Optional[float] = None,
178180
random_seed: Optional[int] = None,
179181
safe_mode: bool = False,
182+
safe_prompt: bool = False,
180183
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
181184
"""An Asynchronous chat endpoint that streams responses.
182185
@@ -189,7 +192,8 @@ async def chat_stream(
189192
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
190193
Defaults to None.
191194
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
192-
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
195+
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
196+
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
193197
194198
Returns:
195199
AsyncGenerator[ChatCompletionStreamResponse, None]:
@@ -204,7 +208,7 @@ async def chat_stream(
204208
top_p=top_p,
205209
random_seed=random_seed,
206210
stream=True,
207-
safe_mode=safe_mode,
211+
safe_prompt=safe_mode or safe_prompt,
208212
)
209213
async_response = self._request(
210214
"post", request, "v1/chat/completions", stream=True

src/mistralai/client.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def chat(
125125
top_p: Optional[float] = None,
126126
random_seed: Optional[int] = None,
127127
safe_mode: bool = False,
128+
safe_prompt: bool = False,
128129
) -> ChatCompletionResponse:
129130
"""A chat endpoint that returns a single response.
130131
@@ -137,7 +138,8 @@ def chat(
137138
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
138139
Defaults to None.
139140
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
140-
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
141+
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
142+
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
141143
142144
Returns:
143145
ChatCompletionResponse: a response object containing the generated text.
@@ -150,7 +152,7 @@ def chat(
150152
top_p=top_p,
151153
random_seed=random_seed,
152154
stream=False,
153-
safe_mode=safe_mode,
155+
safe_prompt=safe_mode or safe_prompt,
154156
)
155157

156158
single_response = self._request("post", request, "v1/chat/completions")
@@ -169,6 +171,7 @@ def chat_stream(
169171
top_p: Optional[float] = None,
170172
random_seed: Optional[int] = None,
171173
safe_mode: bool = False,
174+
safe_prompt: bool = False,
172175
) -> Iterable[ChatCompletionStreamResponse]:
173176
"""A chat endpoint that streams responses.
174177
@@ -181,7 +184,8 @@ def chat_stream(
181184
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
182185
Defaults to None.
183186
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
184-
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
187+
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
188+
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
185189
186190
Returns:
187191
Iterable[ChatCompletionStreamResponse]:
@@ -195,7 +199,7 @@ def chat_stream(
195199
top_p=top_p,
196200
random_seed=random_seed,
197201
stream=True,
198-
safe_mode=safe_mode,
202+
safe_prompt=safe_mode or safe_prompt,
199203
)
200204

201205
response = self._request("post", request, "v1/chat/completions", stream=True)

src/mistralai/client_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ def _make_chat_request(
4747
top_p: Optional[float] = None,
4848
random_seed: Optional[int] = None,
4949
stream: Optional[bool] = None,
50-
safe_mode: Optional[bool] = False,
50+
safe_prompt: Optional[bool] = False,
5151
) -> Dict[str, Any]:
5252
request_data: Dict[str, Any] = {
5353
"model": model,
5454
"messages": [msg.model_dump() for msg in messages],
55-
"safe_prompt": safe_mode,
55+
"safe_prompt": safe_prompt,
5656
}
5757
if temperature is not None:
5858
request_data["temperature"] = temperature

0 commit comments

Comments
 (0)