forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtask.py
More file actions
194 lines (148 loc) · 5.51 KB
/
Copy pathtask.py
File metadata and controls
194 lines (148 loc) · 5.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""Utility functions for creating A2A Task objects."""
import binascii
import uuid
from base64 import b64decode, b64encode
from typing import Literal, Protocol, runtime_checkable
from a2a.types.a2a_pb2 import (
Artifact,
Message,
Task,
TaskState,
TaskStatus,
)
from a2a.utils.constants import MAX_LIST_TASKS_PAGE_SIZE
from a2a.utils.errors import InvalidParamsError
def new_task(request: Message) -> Task:
    """Builds a freshly submitted Task from the user's opening message.

    The task ID and context ID are taken from the message when it carries
    them; otherwise a random UUID4 string is generated for each.

    Args:
        request: The initial `Message` object from the user.

    Returns:
        A new `Task` in the 'submitted' state whose history contains only
        `request`.

    Raises:
        TypeError: If the message role is unset (falsy).
        ValueError: If the message has no parts, or if any part has its
            text field set to an empty string.
    """
    if not request.role:
        raise TypeError('Message role cannot be None')

    if not request.parts:
        raise ValueError('Message parts cannot be empty')

    # Only parts whose 'text' field is populated are checked here; other
    # kinds of part content are not inspected.
    if any(p.HasField('text') and not p.text for p in request.parts):
        raise ValueError('Message.text cannot be empty')

    # Reuse the caller's identifiers when given, otherwise mint new ones.
    task_id = request.task_id or str(uuid.uuid4())
    context_id = request.context_id or str(uuid.uuid4())

    submitted = TaskStatus(state=TaskState.TASK_STATE_SUBMITTED)
    return Task(
        status=submitted,
        id=task_id,
        context_id=context_id,
        history=[request],
    )
def completed_task(
    task_id: str,
    context_id: str,
    artifacts: list[Artifact],
    history: list[Message] | None = None,
) -> Task:
    """Builds a Task that is already in the 'completed' state.

    Intended for the point where an agent has finished its work and wants
    to hand back a final Task carrying the artifacts it produced.

    Args:
        task_id: The ID of the task.
        context_id: The context ID of the task.
        artifacts: The `Artifact` objects produced by the task; must be
            non-empty.
        history: Optional `Message` objects forming the task history.
            When omitted, the task is created with an empty history.

    Returns:
        A `Task` whose status is 'completed'.

    Raises:
        ValueError: If `artifacts` is empty or contains anything that is
            not an `Artifact`.
    """
    # Valid input is a non-empty list in which every item is an Artifact.
    if not (artifacts and all(isinstance(a, Artifact) for a in artifacts)):
        raise ValueError(
            'artifacts must be a non-empty list of Artifact objects'
        )

    return Task(
        status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
        id=task_id,
        context_id=context_id,
        artifacts=artifacts,
        history=[] if history is None else history,
    )
@runtime_checkable
class HistoryLengthConfig(Protocol):
    """Structural type for configuration objects exposing history_length.

    Any object that provides a `history_length` attribute together with a
    `HasField` method satisfies this protocol. Because the class is marked
    `runtime_checkable`, it can also be used with `isinstance`.
    """

    # Integer field read by the history-length helpers in this module.
    history_length: int

    def HasField(self, field_name: Literal['history_length']) -> bool:  # noqa: N802 -- Protobuf generated code
        """Reports whether the named field has been set.

        The method name mirrors the one found on generated Protobuf code.
        """
def validate_history_length(config: HistoryLengthConfig | None) -> None:
    """Rejects configurations whose history_length is negative.

    A missing (falsy) config is accepted without any further checks.

    Args:
        config: Configuration object containing a 'history_length' field,
            or None.

    Raises:
        InvalidParamsError: If `config.history_length` is below zero.
    """
    if not config:
        return
    if config.history_length < 0:
        raise InvalidParamsError(message='history length must be non-negative')
def apply_history_length(
    task: Task, config: HistoryLengthConfig | None
) -> Task:
    """Returns the task with its history limited according to the config.

    The input task is never modified. When no trimming is required the
    very same object is returned; otherwise a copy with a shortened
    history is returned.

    Args:
        task: The original task object with complete history.
        config: Configuration object containing a 'history_length' field
            and a HasField method, or None.

    Returns:
        `task` itself when `config` is None, when 'history_length' is not
        set, when it is negative, or when the history already fits within
        the limit. Otherwise a new task holding only the most recent
        `history_length` messages (none at all when the limit is 0).

    See Also:
        https://a2a-protocol.org/latest/specification/#324-history-length-semantics
    """
    if config is None or not config.HasField('history_length'):
        return task

    limit = config.history_length

    # A negative limit is left untouched here, and an empty history has
    # nothing to trim whatever the limit is.
    if limit < 0 or not task.history:
        return task

    # A positive limit that already covers the whole history is a no-op.
    if limit > 0 and len(task.history) <= limit:
        return task

    trimmed = Task()
    trimmed.CopyFrom(task)
    if limit == 0:
        # Slicing with [:-0] would select nothing, so a zero limit has to
        # clear the field explicitly.
        trimmed.ClearField('history')
    else:
        # Drop everything except the last `limit` messages.
        del trimmed.history[:-limit]
    return trimmed
def validate_page_size(page_size: int) -> None:
    """Checks that page_size lies in [1, MAX_LIST_TASKS_PAGE_SIZE].

    Args:
        page_size: The requested number of tasks per page.

    Raises:
        InvalidParamsError: If `page_size` is smaller than 1 or larger
            than `MAX_LIST_TASKS_PAGE_SIZE`.

    See Also:
        https://a2a-protocol.org/latest/specification/#314-list-tasks
    """
    if page_size < 1:
        problem = 'minimum page size is 1'
    elif page_size > MAX_LIST_TASKS_PAGE_SIZE:
        problem = f'maximum page size is {MAX_LIST_TASKS_PAGE_SIZE}'
    else:
        # Within bounds: nothing to report.
        return
    raise InvalidParamsError(message=problem)
# Text encoding used when converting page tokens between str and bytes.
_ENCODING = 'utf-8'
def encode_page_token(task_id: str) -> str:
    """Turns a task ID into a page token for tasks pagination.

    The token is the base64 form of the task ID's encoded bytes.

    Args:
        task_id: The ID of the task.

    Returns:
        The encoded page token.
    """
    raw_id = task_id.encode(_ENCODING)
    token_bytes = b64encode(raw_id)
    return token_bytes.decode(_ENCODING)
def decode_page_token(page_token: str) -> str:
    """Decodes page token for tasks pagination.

    Trailing '=' padding that is missing from the token is restored before
    decoding, so tokens are accepted with or without their padding.

    Args:
        page_token: The encoded page token.

    Returns:
        The decoded task ID.

    Raises:
        ValueError: If the token is not valid base64 or its decoded bytes
            are not valid text in the module's encoding.
    """
    padded = page_token
    # base64 text is consumed in groups of four characters; top the token
    # up with '=' when its length is not a multiple of four.
    remainder = len(padded) % 4
    if remainder:
        padded += '=' * (4 - remainder)

    # Tokens come from the caller, so they are deliberately not echoed to
    # stdout or logs here.
    try:
        return b64decode(padded.encode(_ENCODING)).decode(_ENCODING)
    except (binascii.Error, UnicodeError) as e:
        # UnicodeError covers both a token that cannot be encoded to bytes
        # and decoded bytes that are not valid text.
        raise ValueError('Token is not a valid base64-encoded cursor.') from e