Skip to content
This repository was archived by the owner on Apr 23, 2026. It is now read-only.

Commit 16807d8

Browse files
yeesiancopybara-github
authored andcommitted
chore: add unit tests for ModuleAgent template
PiperOrigin-RevId: 753638103
1 parent 3289d92 commit 16807d8

2 files changed

Lines changed: 113 additions & 0 deletions

File tree

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
from vertexai import agent_engines
16+
from test_constants import test_agent
17+
18+
_TEST_MODULE_NAME = "test_constants"
19+
_TEST_AGENT_NAME = "test_agent"
20+
_TEST_REGISTER_OPERATIONS = {"": ["query"], "stream": ["stream_query"]}
21+
_TEST_QUERY_INPUT = "test query"
22+
_TEST_STREAM_QUERY_INPUT = 5
23+
24+
25+
class TestModuleAgent:
26+
def test_initialization(self):
27+
agent = agent_engines.ModuleAgent(
28+
module_name=_TEST_MODULE_NAME,
29+
agent_name=_TEST_AGENT_NAME,
30+
register_operations=_TEST_REGISTER_OPERATIONS,
31+
)
32+
assert agent._tmpl_attrs.get("module_name") == _TEST_MODULE_NAME
33+
assert agent._tmpl_attrs.get("agent_name") == _TEST_AGENT_NAME
34+
assert agent._tmpl_attrs.get("register_operations") == _TEST_REGISTER_OPERATIONS
35+
36+
def test_set_up(self):
37+
agent = agent_engines.ModuleAgent(
38+
module_name=_TEST_MODULE_NAME,
39+
agent_name=_TEST_AGENT_NAME,
40+
register_operations=_TEST_REGISTER_OPERATIONS,
41+
)
42+
assert agent._tmpl_attrs.get("agent") is None
43+
agent.set_up()
44+
assert agent._tmpl_attrs.get("agent") is not None
45+
46+
def test_clone(self):
47+
agent = agent_engines.ModuleAgent(
48+
module_name=_TEST_MODULE_NAME,
49+
agent_name=_TEST_AGENT_NAME,
50+
register_operations=_TEST_REGISTER_OPERATIONS,
51+
)
52+
agent.set_up()
53+
assert agent._tmpl_attrs.get("agent") is not None
54+
agent_clone = agent.clone()
55+
assert agent._tmpl_attrs.get("agent") is not None
56+
assert agent_clone._tmpl_attrs.get("agent") is None
57+
agent_clone.set_up()
58+
assert agent_clone._tmpl_attrs.get("agent") is not None
59+
60+
def test_query(self):
61+
agent = agent_engines.ModuleAgent(
62+
module_name=_TEST_MODULE_NAME,
63+
agent_name=_TEST_AGENT_NAME,
64+
register_operations=_TEST_REGISTER_OPERATIONS,
65+
)
66+
agent.set_up()
67+
got_result = agent.query(input=_TEST_QUERY_INPUT)
68+
expected_result = agent._tmpl_attrs.get("agent").query(input=_TEST_QUERY_INPUT)
69+
assert got_result == expected_result
70+
expected_result = test_agent.query(input=_TEST_QUERY_INPUT)
71+
assert got_result == expected_result
72+
73+
def test_stream_query(self):
74+
agent = agent_engines.ModuleAgent(
75+
module_name=_TEST_MODULE_NAME,
76+
agent_name=_TEST_AGENT_NAME,
77+
register_operations=_TEST_REGISTER_OPERATIONS,
78+
)
79+
agent.set_up()
80+
for got_result, expected_result in zip(
81+
agent.stream_query(n=_TEST_STREAM_QUERY_INPUT),
82+
agent._tmpl_attrs.get("agent").stream_query(n=_TEST_STREAM_QUERY_INPUT),
83+
):
84+
assert got_result == expected_result
85+
for got_result, expected_result in zip(
86+
agent.stream_query(n=_TEST_STREAM_QUERY_INPUT),
87+
test_agent.stream_query(n=_TEST_STREAM_QUERY_INPUT),
88+
):
89+
assert got_result == expected_result
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
class _CustomAgent:
16+
def query(self, input: str):
17+
return input
18+
19+
def stream_query(self, n: int):
20+
for i in range(n):
21+
yield i
22+
23+
24+
test_agent = _CustomAgent()

0 commit comments

Comments
 (0)