Skip to content

Commit 894c73f

Browse files
yeesiancopybara-github
authored andcommitted
chore: Add unit tests for langchain template.
PiperOrigin-RevId: 624236094
1 parent 3842d26 commit 894c73f

4 files changed

Lines changed: 212 additions & 0 deletions

File tree

noxfile.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
nox.options.sessions = [
6969
"unit",
7070
"unit_ray",
71+
"unit_langchain",
7172
"system",
7273
"cover",
7374
"lint",
@@ -181,6 +182,7 @@ def default(session):
181182
"--cov-report=",
182183
"--cov-fail-under=0",
183184
"--ignore=tests/unit/vertex_ray",
185+
"--ignore=tests/unit/vertex_langchain",
184186
os.path.join("tests", "unit"),
185187
*session.posargs,
186188
)
@@ -219,6 +221,32 @@ def unit_ray(session, ray):
219221
)
220222

221223

224+
@nox.session(python=UNIT_TEST_PYTHON_VERSIONS)
225+
def unit_langchain(session):
226+
# Install all test dependencies, then install this package in-place.
227+
228+
constraints_path = str(CURRENT_DIRECTORY / "testing" / f"constraints-langchain.txt")
229+
standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES
230+
session.install(*standard_deps, "-c", constraints_path)
231+
232+
# Install langchain extras
233+
session.install("-e", ".[langchain_testing]", "-c", constraints_path)
234+
235+
# Run py.test against the unit tests.
236+
session.run(
237+
"py.test",
238+
"--quiet",
239+
f"--junitxml=unit_langchain_sponge_log.xml",
240+
"--cov=google",
241+
"--cov-append",
242+
"--cov-config=.coveragerc",
243+
"--cov-report=",
244+
"--cov-fail-under=0",
245+
os.path.join("tests", "unit", "vertex_langchain"),
246+
*session.posargs,
247+
)
248+
249+
222250
def install_systemtest_dependencies(session, *constraints):
223251
# Use pre-release gRPC for system tests.
224252
# Exclude version 1.52.0rc1 which has a known issue.

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@
260260
"reasoningengine": reasoning_engine_extra_require,
261261
"rapid_evaluation": rapid_evaluation_extra_require,
262262
"langchain": langchain_extra_require,
263+
"langchain_testing": langchain_extra_require,
263264
},
264265
python_requires=">=3.8",
265266
classifiers=[

testing/constraints-langchain.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
langchain==0.1.15
2+
langchain-core==0.1.40
3+
langchain-google-vertexai==0.1.2
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
import importlib
16+
import json
17+
from typing import Optional
18+
from unittest import mock
19+
20+
from google import auth
21+
from google.auth import credentials as auth_credentials
22+
import vertexai
23+
from google.cloud.aiplatform import initializer
24+
from vertexai.preview import reasoning_engines
25+
import pytest
26+
27+
from langchain_core import agents
28+
from langchain_core import messages
29+
from langchain_core import outputs
30+
from langchain_core import tools as lc_tools
31+
32+
33+
_DEFAULT_PLACE_TOOL_ACTIVITY = "museums"
34+
_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3
35+
_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400
36+
_TEST_LOCATION = "us-central1"
37+
_TEST_PROJECT = "test-project"
38+
_TEST_MODEL = "gemini-1.0-pro"
39+
40+
41+
def place_tool_query(
42+
city: str,
43+
activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY,
44+
page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE,
45+
):
46+
"""Searches the city for recommendations on the activity."""
47+
return {"city": city, "activity": activity, "page_size": page_size}
48+
49+
50+
def place_photo_query(
51+
photo_reference: str,
52+
maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH,
53+
maxheight: Optional[int] = None,
54+
):
55+
"""Returns the photo for a given reference."""
56+
result = {"photo_reference": photo_reference, "maxwidth": maxwidth}
57+
if maxheight:
58+
result["maxheight"] = maxheight
59+
return result
60+
61+
62+
@pytest.fixture(scope="module")
63+
def google_auth_mock():
64+
with mock.patch.object(auth, "default") as google_auth_mock:
65+
google_auth_mock.return_value = (
66+
auth_credentials.AnonymousCredentials(),
67+
_TEST_PROJECT,
68+
)
69+
yield google_auth_mock
70+
71+
72+
@pytest.fixture
73+
def vertexai_init_mock():
74+
with mock.patch.object(vertexai, "init") as vertexai_init_mock:
75+
yield vertexai_init_mock
76+
77+
78+
@pytest.mark.usefixtures("google_auth_mock")
79+
class TestLangchainAgent:
80+
def setup_method(self):
81+
importlib.reload(initializer)
82+
importlib.reload(vertexai)
83+
vertexai.init(
84+
project=_TEST_PROJECT,
85+
location=_TEST_LOCATION,
86+
)
87+
88+
def teardown_method(self):
89+
initializer.global_pool.shutdown(wait=True)
90+
91+
def test_initialization(self):
92+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
93+
assert agent._model_name == _TEST_MODEL
94+
assert agent._project == _TEST_PROJECT
95+
assert agent._location == _TEST_LOCATION
96+
assert agent._runnable is None
97+
98+
def test_initialization_with_tools(self):
99+
agent = reasoning_engines.LangchainAgent(
100+
model=_TEST_MODEL,
101+
tools=[
102+
place_tool_query,
103+
place_photo_query,
104+
],
105+
)
106+
for tool in agent._tools:
107+
assert isinstance(tool, lc_tools.BaseTool)
108+
109+
def test_set_up(self, vertexai_init_mock):
110+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
111+
assert agent._runnable is None
112+
agent.set_up()
113+
assert agent._runnable is not None
114+
115+
def test_query(self):
116+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
117+
agent._runnable = mock.Mock()
118+
mocks = mock.Mock()
119+
mocks.attach_mock(mock=agent._runnable, attribute="invoke")
120+
agent.query(input="test query")
121+
mocks.assert_has_calls(
122+
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
123+
)
124+
125+
126+
class TestDefaultOutputParser:
127+
def test_parse_result_function_call(self, vertexai_init_mock):
128+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
129+
agent.set_up()
130+
tool_input = {
131+
"photo_reference": "abcd1234",
132+
"maxwidth": _DEFAULT_PLACE_PHOTO_MAXWIDTH,
133+
}
134+
result = agent._output_parser.parse_result(
135+
[
136+
outputs.ChatGeneration(
137+
message=messages.AIMessage(
138+
content="",
139+
additional_kwargs={
140+
"function_call": {
141+
"name": "place_tool_query",
142+
"arguments": json.dumps(tool_input),
143+
},
144+
},
145+
)
146+
)
147+
]
148+
)
149+
assert isinstance(result, agents.AgentActionMessageLog)
150+
assert result.tool == "place_tool_query"
151+
assert result.tool_input == tool_input
152+
153+
def test_parse_result_not_function_call(self, vertexai_init_mock):
154+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
155+
agent.set_up()
156+
content = "test content"
157+
result = agent._output_parser.parse_result(
158+
[
159+
outputs.ChatGeneration(
160+
message=messages.AIMessage(content=content),
161+
)
162+
]
163+
)
164+
assert isinstance(result, agents.AgentFinish)
165+
assert result.return_values == {"output": content}
166+
assert result.log == content
167+
168+
169+
class TestDefaultOutputParserErrors:
170+
def test_parse_result_non_chat_generation_errors(self, vertexai_init_mock):
171+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
172+
agent.set_up()
173+
with pytest.raises(ValueError, match=r"only works on ChatGeneration"):
174+
agent._output_parser.parse_result(["text"])
175+
176+
def test_parse_text_errors(self, vertexai_init_mock):
177+
agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
178+
agent.set_up()
179+
with pytest.raises(ValueError, match=r"Can only parse messages"):
180+
agent._output_parser.parse("text")

0 commit comments

Comments
 (0)