forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathservice_registry.py
More file actions
428 lines (347 loc) · 14.2 KB
/
service_registry.py
File metadata and controls
428 lines (347 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ADK Service Registry.
This module manages pluggable backend services for sessions, artifacts, and memory.
ADK includes built-in support for common backends like SQLite, PostgreSQL,
GCS, and Vertex AI Agent Engine. You can also extend ADK by registering
custom services.
There are two ways to register custom services:
1. YAML Configuration (Recommended for simple cases)
If your custom service can be instantiated with `MyService(uri="...", **kwargs)`,
you can register it without writing Python code by creating a `services.yaml`
or `services.yml` file in your agent directory (e.g., `my_agent/services.yaml`).
Example `services.yaml`:
```yaml
services:
- scheme: mysession
type: session
class: my_package.my_module.MyCustomSessionService
- scheme: mymemory
type: memory
class: my_package.other_module.MyCustomMemoryService
```
2. Python Registration (`services.py`)
For more complex initialization logic, create a `services.py` file in your
agent directory (e.g., `my_agent/services.py`). In this file, get the
registry instance and register your custom factory functions. This file can
be used for registration in addition to, or instead of, `services.yaml`.
Example `services.py`:
```python
from google.adk.cli.service_registry import get_service_registry
from my_package.my_module import MyCustomSessionService
def my_session_factory(uri: str, **kwargs):
# custom logic
return MyCustomSessionService(...)
get_service_registry().register_session_service("mysession", my_session_factory)
```
Note: If both `services.yaml` (or `.yml`) and `services.py` are present in the
same directory, services from **both** files will be loaded. YAML files are
processed first, then `services.py`. If the same service scheme is defined in
both, the definition in `services.py` will overwrite the one from YAML.
"""
from __future__ import annotations
import importlib
import logging
import os
from pathlib import Path
import sys
from typing import Any
from typing import Optional
from typing import Protocol
from urllib.parse import unquote
from urllib.parse import urlparse
from ..artifacts.base_artifact_service import BaseArtifactService
from ..memory.base_memory_service import BaseMemoryService
from ..sessions.base_session_service import BaseSessionService
from ..utils import yaml_utils
logger = logging.getLogger("google_adk." + __name__)
class ServiceFactory(Protocol):
"""Protocol for service factory functions."""
def __call__(
self, uri: str, **kwargs
) -> BaseSessionService | BaseArtifactService | BaseMemoryService:
...
class ServiceRegistry:
"""Registry for custom service URI schemes."""
def __init__(self):
self._session_factories: dict[str, ServiceFactory] = {}
self._artifact_factories: dict[str, ServiceFactory] = {}
self._memory_factories: dict[str, ServiceFactory] = {}
def register_session_service(
self, scheme: str, factory: ServiceFactory
) -> None:
"""Register a factory for a custom session service URI scheme.
Args:
scheme: URI scheme (e.g., 'custom')
factory: Callable that takes (uri, **kwargs) and returns
BaseSessionService
"""
self._session_factories[scheme] = factory
def register_artifact_service(
self, scheme: str, factory: ServiceFactory
) -> None:
"""Register a factory for a custom artifact service URI scheme."""
self._artifact_factories[scheme] = factory
def register_memory_service(
self, scheme: str, factory: ServiceFactory
) -> None:
"""Register a factory for a custom memory service URI scheme."""
self._memory_factories[scheme] = factory
def create_session_service(
self, uri: str, **kwargs
) -> BaseSessionService | None:
"""Create session service from URI using registered factories."""
scheme = urlparse(uri).scheme
if scheme and scheme in self._session_factories:
return self._session_factories[scheme](uri, **kwargs)
return None
def create_artifact_service(
self, uri: str, **kwargs
) -> BaseArtifactService | None:
"""Create artifact service from URI using registered factories."""
scheme = urlparse(uri).scheme
if scheme and scheme in self._artifact_factories:
return self._artifact_factories[scheme](uri, **kwargs)
return None
def create_memory_service(
self, uri: str, **kwargs
) -> BaseMemoryService | None:
"""Create memory service from URI using registered factories."""
scheme = urlparse(uri).scheme
if scheme and scheme in self._memory_factories:
return self._memory_factories[scheme](uri, **kwargs)
return None
def get_service_registry() -> ServiceRegistry:
"""Gets the singleton ServiceRegistry instance, initializing it if needed."""
global _service_registry_instance
if _service_registry_instance is None:
_service_registry_instance = ServiceRegistry()
_register_builtin_services(_service_registry_instance)
return _service_registry_instance
def load_services_module(agents_dir: str) -> None:
"""Load services.py or services.yaml from agents_dir for custom service registration.
If services.yaml or services.yml is found, it will be loaded first,
followed by services.py if it exists.
Skip if neither services.yaml/yml nor services.py is not found.
"""
if not os.path.isdir(agents_dir):
logger.debug(
"agents_dir %s is not a valid directory, skipping service loading.",
agents_dir,
)
return
if agents_dir not in sys.path:
sys.path.insert(0, agents_dir)
# Try loading services.yaml or services.yml first
for yaml_file in ["services.yaml", "services.yml"]:
yaml_path = os.path.join(agents_dir, yaml_file)
if os.path.exists(yaml_path):
try:
config = yaml_utils.load_yaml_file(yaml_path)
_register_services_from_yaml_config(config, get_service_registry())
logger.debug(
"Loaded custom services from %s in %s.", yaml_file, agents_dir
)
except Exception as e:
logger.warning(
"Failed to load %s from %s: %s",
yaml_file,
agents_dir,
e,
)
return # If yaml exists but fails to load, stop.
try:
importlib.import_module("services")
logger.debug(
"Loaded services.py from %s for custom service registration.",
agents_dir,
)
except ModuleNotFoundError:
logger.debug("services.py not found in %s, skipping.", agents_dir)
except Exception as e:
logger.warning(
"Failed to load services.py from %s: %s",
agents_dir,
e,
)
_service_registry_instance: ServiceRegistry | None = None
def _register_builtin_services(registry: ServiceRegistry) -> None:
"""Register built-in service implementations."""
# -- Session Services --
def memory_session_factory(uri: str, **kwargs):
from ..sessions.in_memory_session_service import InMemorySessionService
return InMemorySessionService()
def agentengine_session_factory(uri: str, **kwargs):
from ..sessions.vertex_ai_session_service import VertexAiSessionService
parsed = urlparse(uri)
params = _parse_agent_engine_kwargs(
parsed.netloc + parsed.path, kwargs.get("agents_dir")
)
return VertexAiSessionService(**params)
def database_session_factory(uri: str, **kwargs):
from ..sessions.database_session_service import DatabaseSessionService
kwargs_copy = kwargs.copy()
kwargs_copy.pop("agents_dir", None)
return DatabaseSessionService(db_url=uri, **kwargs_copy)
def sqlite_session_factory(uri: str, **kwargs):
from ..sessions.sqlite_session_service import SqliteSessionService
parsed = urlparse(uri)
db_path = parsed.path
if not db_path:
# Treat sqlite:// without a path as an in-memory session service.
return memory_session_factory("memory://", **kwargs)
elif db_path.startswith("/"):
db_path = db_path[1:]
kwargs_copy = kwargs.copy()
kwargs_copy.pop("agents_dir", None)
return SqliteSessionService(db_path=db_path, **kwargs_copy)
registry.register_session_service("memory", memory_session_factory)
registry.register_session_service("agentengine", agentengine_session_factory)
registry.register_session_service("sqlite", sqlite_session_factory)
for scheme in ["postgresql", "mysql"]:
registry.register_session_service(scheme, database_session_factory)
# -- Artifact Services --
def memory_artifact_factory(uri: str, **kwargs):
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
return InMemoryArtifactService()
def gcs_artifact_factory(uri: str, **kwargs):
from ..artifacts.gcs_artifact_service import GcsArtifactService
kwargs_copy = kwargs.copy()
kwargs_copy.pop("agents_dir", None)
kwargs_copy.pop("per_agent", None)
parsed_uri = urlparse(uri)
bucket_name = parsed_uri.netloc
return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy)
def file_artifact_factory(uri: str, **_):
from ..artifacts.file_artifact_service import FileArtifactService
parsed_uri = urlparse(uri)
if parsed_uri.netloc not in ("", "localhost"):
raise ValueError(
"file:// artifact URIs must reference the local filesystem."
)
if not parsed_uri.path:
raise ValueError("file:// artifact URIs must include a path component.")
artifact_path = Path(unquote(parsed_uri.path))
return FileArtifactService(root_dir=artifact_path)
registry.register_artifact_service("memory", memory_artifact_factory)
registry.register_artifact_service("gs", gcs_artifact_factory)
registry.register_artifact_service("file", file_artifact_factory)
# -- Memory Services --
def rag_memory_factory(uri: str, **kwargs):
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
rag_corpus = urlparse(uri).netloc
if not rag_corpus:
raise ValueError("Rag corpus can not be empty.")
agents_dir = kwargs.get("agents_dir")
project, location = _load_gcp_config(agents_dir, "RAG memory service")
return VertexAiRagMemoryService(
rag_corpus=(
f"projects/{project}/locations/{location}/ragCorpora/{rag_corpus}"
)
)
def agentengine_memory_factory(uri: str, **kwargs):
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
parsed = urlparse(uri)
params = _parse_agent_engine_kwargs(
parsed.netloc + parsed.path, kwargs.get("agents_dir")
)
return VertexAiMemoryBankService(**params)
registry.register_memory_service("rag", rag_memory_factory)
registry.register_memory_service("agentengine", agentengine_memory_factory)
def _load_gcp_config(
agents_dir: Optional[str], service_name: str
) -> tuple[str, str]:
"""Loads GCP project and location from environment."""
if not agents_dir:
raise ValueError(f"agents_dir must be provided for {service_name}")
from .utils import envs
envs.load_dotenv_for_agent("", agents_dir)
project = os.environ.get("GOOGLE_CLOUD_PROJECT")
location = os.environ.get("GOOGLE_CLOUD_LOCATION")
if not project or not location:
raise ValueError("GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION not set.")
return project, location
def _parse_agent_engine_kwargs(
uri_part: str, agents_dir: Optional[str]
) -> dict[str, Any]:
"""Helper to parse agent engine resource name."""
if not uri_part:
raise ValueError(
"Agent engine resource name or resource id cannot be empty."
)
# If uri_part is just an ID, load project/location from env
if "/" not in uri_part:
project, location = _load_gcp_config(
agents_dir, "short-form agent engine IDs"
)
return {
"project": project,
"location": location,
"agent_engine_id": uri_part,
}
# If uri_part is a full resource name, parse it
parts = uri_part.split("/")
if not (
len(parts) == 6
and parts[0] == "projects"
and parts[2] == "locations"
and parts[4] == "reasoningEngines"
):
raise ValueError(
"Agent engine resource name is mal-formatted. It should be of"
" format :"
" projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
)
return {
"project": parts[1],
"location": parts[3],
"agent_engine_id": parts[5],
}
def _get_class_from_string(class_path: str) -> Any:
"""Dynamically import a class from a string path."""
try:
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
except Exception as e:
raise ImportError(f"Could not import class {class_path}: {e}") from e
def _create_generic_factory(class_path: str) -> ServiceFactory:
"""Create a generic factory for a service class."""
cls = _get_class_from_string(class_path)
def factory(uri: str, **kwargs):
return cls(uri=uri, **kwargs)
return factory
def _register_services_from_yaml_config(
config: dict[str, Any], registry: ServiceRegistry
) -> None:
"""Register services defined in a YAML configuration."""
if not config or "services" not in config:
return
for service_config in config["services"]:
scheme = service_config.get("scheme")
service_type = service_config.get("type")
class_path = service_config.get("class")
if not all([scheme, service_type, class_path]):
logger.warning("Invalid service config in YAML: %s", service_config)
continue
factory = _create_generic_factory(class_path)
if service_type == "session":
registry.register_session_service(scheme, factory)
elif service_type == "artifact":
registry.register_artifact_service(scheme, factory)
elif service_type == "memory":
registry.register_memory_service(scheme, factory)
else:
logger.warning("Unknown service type in YAML: %s", service_type)