forked from mistralai/client-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfunction_calling.py
More file actions
135 lines (113 loc) · 3.94 KB
/
function_calling.py
File metadata and controls
135 lines (113 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import functools
import json
import os
from typing import Dict, List
from mistralai import Mistral
from mistralai.models.assistantmessage import AssistantMessage
from mistralai.models.function import Function
from mistralai.models.toolmessage import ToolMessage
from mistralai.models.usermessage import UserMessage
# Sample payments table kept column-wise: every field maps to a list, and the
# i-th entry of each list describes the same transaction.
data = dict(
    transaction_id=["T1001", "T1002", "T1003", "T1004", "T1005"],
    customer_id=["C001", "C002", "C003", "C002", "C001"],
    payment_amount=[125.50, 89.99, 120.00, 54.30, 210.20],
    payment_date=[
        "2021-10-05",
        "2021-10-06",
        "2021-10-07",
        "2021-10-05",
        "2021-10-08",
    ],
    payment_status=["Paid", "Unpaid", "Paid", "Paid", "Pending"],
)
def retrieve_payment_status(data: Dict[str, List], transaction_id: str) -> str:
    """Look up the payment status of one transaction.

    Args:
        data: Column-oriented table containing parallel ``"transaction_id"``
            and ``"payment_status"`` lists.
        transaction_id: The transaction id to look up.

    Returns:
        A JSON string ``{"status": <status>}`` for the first row whose id
        matches, or ``{"status": "Error - transaction id not found"}`` when no
        row matches (including when the table is empty).
    """
    try:
        # list.index returns the first match, so duplicate ids resolve to the
        # earliest row.
        row = data["transaction_id"].index(transaction_id)
    except ValueError:
        # Not-found is returned as a JSON payload rather than raised, because
        # the result is handed back to the model as tool-message content.
        return json.dumps({"status": "Error - transaction id not found"})
    return json.dumps({"status": data["payment_status"][row]})
def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
    """Look up the payment date of one transaction.

    Args:
        data: Column-oriented table containing parallel ``"transaction_id"``
            and ``"payment_date"`` lists.
        transaction_id: The transaction id to look up.

    Returns:
        A JSON string ``{"date": <date>}`` for the first row whose id matches.
        When no row matches (including an empty table) it returns
        ``{"status": "Error - transaction id not found"}``. Note that the
        error payload uses the ``"status"`` key, not ``"date"``.
    """
    try:
        # list.index returns the first match, so duplicate ids resolve to the
        # earliest row.
        row = data["transaction_id"].index(transaction_id)
    except ValueError:
        # Not-found is returned as a JSON payload rather than raised, because
        # the result is handed back to the model as tool-message content.
        return json.dumps({"status": "Error - transaction id not found"})
    return json.dumps({"date": data["payment_date"][row]})
# Dispatch table: tool name (as the model will request it) -> callable with the
# payments table already bound via ``data=``, so only the model-supplied
# keyword arguments remain to be passed.
names_to_functions = {
    tool_name: functools.partial(tool_fn, data=data)
    for tool_name, tool_fn in (
        ("retrieve_payment_status", retrieve_payment_status),
        ("retrieve_payment_date", retrieve_payment_date),
    )
}
# Tool specifications sent with every chat request. Both tools take a single
# required string argument, ``transaction_id``; only the name and description
# differ. A fresh ``parameters`` dict is built for each tool.
tools = [
    {
        "type": "function",
        "function": Function(
            name=tool_name,
            description=tool_description,
            parameters={
                "type": "object",
                "required": ["transaction_id"],
                "properties": {
                    "transaction_id": {
                        "type": "string",
                        "description": "The transaction id.",
                    }
                },
            },
        ),
    }
    for tool_name, tool_description in (
        ("retrieve_payment_status", "Get payment status of a transaction id"),
        ("retrieve_payment_date", "Get payment date of a transaction id"),
    )
]
api_key = os.environ["MISTRAL_API_KEY"]
model = "mistral-small-latest"
client = Mistral(api_key=api_key)

# Turn 1: ask about "my transaction" without giving an id. The model presumably
# answers in text here (e.g. asking for the id); its reply is kept in history.
messages = [UserMessage(content="What's the status of my transaction?")]
response = client.chat.complete(
    model=model, messages=messages, tools=tools, temperature=0
)
print(response.choices[0].message.content)
messages.append(AssistantMessage(content=response.choices[0].message.content))

# Turn 2: supply the id; the model is expected to request one of the tools.
messages.append(UserMessage(content="My transaction ID is T1001."))
response = client.chat.complete(
    model=model, messages=messages, tools=tools, temperature=0
)

assistant_message = response.choices[0].message
# tool_calls may be None/empty if the model replied in plain text; indexing it
# blindly would fail with an opaque TypeError/IndexError.
if not assistant_message.tool_calls:
    raise SystemExit(
        f"Model did not request a tool call; it replied: {assistant_message.content!r}"
    )
tool_call = assistant_message.tool_calls[0]
function_name = tool_call.function.name
# Arguments may arrive as a JSON string or as an already-decoded dict;
# only the string form needs decoding.
raw_arguments = tool_call.function.arguments
function_params = (
    json.loads(raw_arguments) if isinstance(raw_arguments, str) else raw_arguments
)
print(
    f"calling function_name: {function_name}, with function_params: {function_params}"
)
if function_name not in names_to_functions:
    raise SystemExit(f"Model requested an unknown tool: {function_name!r}")
function_result = names_to_functions[function_name](**function_params)

# Record the assistant's tool-call turn, then the tool output, linked back to
# the request through tool_call_id.
messages.append(
    AssistantMessage(
        content=assistant_message.content,
        tool_calls=assistant_message.tool_calls,
    )
)
messages.append(
    ToolMessage(
        name=function_name,
        content=function_result,
        tool_call_id=tool_call.id,
    )
)
print(messages)

# Turn 3: let the model phrase the final answer using the tool result.
response = client.chat.complete(
    model=model, messages=messages, tools=tools, temperature=0
)
print(f"{response.choices[0].message.content}")