Skip to content

Commit 0d55818

Browse files
author
GitHub Actions
committed
Update version to 0.1.8
1 parent 63be8a2 commit 0d55818

5 files changed

Lines changed: 239 additions & 483 deletions

File tree

examples/function_calling.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import functools
22
import json
33
import os
4+
from typing import Dict, List
45

5-
import pandas as pd
66
from mistralai.client import MistralClient
77
from mistralai.models.chat_completion import ChatMessage, Function
88

@@ -15,30 +15,27 @@
1515
"payment_status": ["Paid", "Unpaid", "Paid", "Paid", "Pending"],
1616
}
1717

18-
# Create DataFrame
19-
df = pd.DataFrame(data)
18+
n_rows = len(data["transaction_id"])
2019

20+
def retrieve_payment_status(data: Dict[str,List], transaction_id: str) -> str:
21+
for i, r in enumerate(data["transaction_id"]):
22+
if r == transaction_id:
23+
return json.dumps({"status": data["payment_status"][i]})
24+
else:
25+
return json.dumps({"status": "Error - transaction id not found"})
2126

22-
def retrieve_payment_status(df: pd.DataFrame, transaction_id: str) -> str:
23-
if transaction_id in df.transaction_id.values:
24-
return json.dumps({"status": df[df.transaction_id == transaction_id].payment_status.item()})
25-
else:
26-
return json.dumps({"status": "error - transaction id not found."})
27-
28-
29-
def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:
30-
if transaction_id in df.transaction_id.values:
31-
return json.dumps({"date": df[df.transaction_id == transaction_id].payment_date.item()})
32-
else:
33-
return json.dumps({"status": "error - transaction id not found."})
34-
27+
def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
28+
for i, r in enumerate(data["transaction_id"]):
29+
if r == transaction_id:
30+
return json.dumps({"date": data["payment_date"][i]})
31+
else:
32+
return json.dumps({"status": "Error - transaction id not found"})
3533

3634
names_to_functions = {
37-
"retrieve_payment_status": functools.partial(retrieve_payment_status, df=df),
38-
"retrieve_payment_date": functools.partial(retrieve_payment_date, df=df),
35+
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
36+
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data)
3937
}
4038

41-
4239
tools = [
4340
{
4441
"type": "function",
@@ -66,7 +63,6 @@ def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:
6663
},
6764
]
6865

69-
7066
api_key = os.environ["MISTRAL_API_KEY"]
7167
model = "mistral-large-latest"
7268

examples/json_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def main():
1515
chat_response = client.chat(
1616
model=model,
1717
response_format={"type": "json_object"},
18-
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
18+
messages=[ChatMessage(role="user", content="What is the best French cheese? Answer shortly in JSON.")],
1919

2020
)
2121
print(chat_response.choices[0].message.content)

0 commit comments

Comments
 (0)