-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path7_agent_3.py
More file actions
82 lines (64 loc) · 2.58 KB
/
7_agent_3.py
File metadata and controls
82 lines (64 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Annotated, Sequence, TypedDict
from dotenv import load_dotenv
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolMessage
from langchain_core.messages import SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
# Load variables from a local .env file into the environment (presumably the
# OpenAI credentials that ChatOpenAI reads — confirm against your .env).
load_dotenv()


class AgentState(TypedDict):
    """State carried between the nodes of the agent graph."""

    # Conversation history. Annotating with the `add_messages` reducer makes
    # LangGraph merge the messages each node returns into this list rather
    # than replacing the whole list.
    messages: Annotated[Sequence[BaseMessage], add_messages]
@tool
def add(a: int, b: int) -> int:
    """Addition function to add two numbers."""
    # The docstring above is what `@tool` exposes to the model as this tool's
    # description, so its wording is kept exactly as-is.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtraction function to subtract two numbers."""
    # Result is a minus b (argument order matters). The docstring above is the
    # model-facing tool description and is intentionally left unchanged.
    difference = a - b
    return difference
@tool
def multiply(a: int, b: int) -> int:
    """This function is used to multiply two numbers.
    To calculate the square of a number, pass the same value as both arguments."""
    # `@tool` turns the docstring above into the tool description the model
    # reads, so it must be correctly spelled and unambiguous; the squaring hint
    # is how the model is told to compute x**2 with this tool.
    return a * b
@tool
def division(a: int, b: int) -> float:
    """This function is used to calculate the division of a by b.
    b must not be zero."""
    # The docstring above is the model-facing tool description; it states the
    # non-zero divisor constraint so the model can avoid an invalid call.
    if b == 0:
        # Same exception type as a bare `a / b` would raise, but with a message
        # that is actionable when it is relayed back to the model as a tool
        # error.
        raise ZeroDivisionError(
            "Cannot divide by zero: the divisor b must be non-zero."
        )
    # True division always yields a float (e.g. 7 / 2 == 3.5), hence `-> float`.
    return a / b
# The arithmetic tools available to the agent. The same list is bound to the
# chat model here and passed to the ToolNode that executes the calls.
tools = [
    add,
    subtract,
    multiply,
    division,
]
model = ChatOpenAI(model="gpt-4o").bind_tools(tools)
def model_call(state: AgentState) -> AgentState:
    """Agent node: ask the tool-bound chat model for its next step.

    A fixed system prompt is placed in front of the accumulated messages on
    every call. The model's reply is returned as a one-message update, which
    the ``add_messages`` reducer on ``AgentState.messages`` merges into state.
    """
    instructions = SystemMessage(
        content="You are my AI assistant help me choose the proper tool to the best of your ability."
    )
    conversation = [instructions] + state["messages"]
    reply = model.invoke(conversation)
    return {"messages": [reply]}
def should_continue(state: AgentState) -> str:
    """Route the graph after the agent node.

    Returns:
        "continue" when the most recent message requests one or more tool
        calls (the graph then runs the tool node), otherwise "end".
    """
    last_message = state["messages"][-1]
    # Only AI messages carry `tool_calls`; for any other message type treat it
    # as "nothing to execute" instead of raising AttributeError.
    if getattr(last_message, "tool_calls", None):
        return "continue"
    return "end"
# Build the ReAct-style loop: START -> agent -> (tools -> agent)* -> END.
graph = StateGraph(AgentState)
graph.add_node("our_agent", model_call)

# handle_tool_errors=True converts an exception raised inside a tool (such as
# the zero-divisor error from `division`) into a ToolMessage the model can
# read, instead of aborting the whole run.
tool_node = ToolNode(tools=tools, handle_tool_errors=True)
graph.add_node("tools", tool_node)

graph.add_edge(START, "our_agent")
graph.add_conditional_edges(
    "our_agent",
    should_continue,
    {
        "continue": "tools",
        "end": END,
    },
)
# Send tool results back to the agent so it can use them: either call another
# tool (e.g. square the sum it just got) or write the final answer. Without
# this edge the "tools" node has no successor and the run stops right after
# the first tool execution.
graph.add_edge("tools", "our_agent")

app = graph.compile()
def print_stream(stream):
    """Print the newest message from each state snapshot yielded by ``stream``.

    Each item is expected to be a mapping with a ``"messages"`` list. A plain
    tuple (e.g. ``("user", "text")``) is printed with ``print``; any other
    message object is rendered with its own ``pretty_print()`` method.
    """
    for snapshot in stream:
        latest = snapshot["messages"][-1]
        if not isinstance(latest, tuple):
            latest.pretty_print()
            continue
        print(latest)
if __name__ == "__main__":
    # Run the demo only when executed as a script, so merely loading this
    # module does not trigger paid calls to the OpenAI API.
    # Demo prompts: the first is expected to chain tools (add, then square the
    # sum via multiply); the second asks for 0 / 0 to exercise the division
    # tool's zero-divisor error path.
    demo_prompts = (
        "Hello How are you can you add two numbers 5, 6 and calculate square of the result",
        "Hello How are you can you add two numbers 0, 0 and calculate division of the result by same number",
    )
    for prompt in demo_prompts:
        inputs = {"messages": [("user", prompt)]}
        print_stream(app.stream(inputs, stream_mode="values"))