Skip to content

Commit d4ac2c2

Browse files
authored
fix(example): support multi-step Responses tool streaming (abetlen#2288)
* fix(example): support multi-step Responses tool streaming * docs: add Responses tool streaming changelog
1 parent 7eb494d commit d4ac2c2

2 files changed

Lines changed: 112 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- fix(example): support multi-step Responses tool streaming by @abetlen in #2288
1011
- fix(ci): Repair Linux accelerator wheels for manylinux publishing
1112

1213
## [0.3.28]

examples/server/server.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2812,6 +2812,7 @@ def to_chat_template_tool(self) -> ChatTemplateTool:
28122812
class ResponsesCustomToolFormat(BaseModel):
28132813
model_config = ConfigDict(extra="ignore")
28142814

2815+
type: Optional[str] = None
28152816
syntax: Optional[str] = None
28162817
definition: Optional[str] = None
28172818

@@ -2880,10 +2881,24 @@ class ResponsesWebSearchTool(BaseModel):
28802881
type: Literal["web_search"]
28812882

28822883

2884+
class ResponsesNamespaceTool(BaseModel):
2885+
model_config = ConfigDict(extra="ignore")
2886+
2887+
type: Literal["namespace"]
2888+
2889+
2890+
class ResponsesImageGenerationTool(BaseModel):
2891+
model_config = ConfigDict(extra="ignore")
2892+
2893+
type: Literal["image_generation"]
2894+
2895+
28832896
ResponsesToolDefinition = Union[
28842897
ResponsesFunctionTool,
28852898
ResponsesCustomTool,
28862899
ResponsesWebSearchTool,
2900+
ResponsesNamespaceTool,
2901+
ResponsesImageGenerationTool,
28872902
]
28882903

28892904

@@ -5069,6 +5084,68 @@ def _tool_content_type(self, tool_name: str) -> Optional[str]:
50695084
return content_type
50705085
return None
50715086

5087+
def _raw_string_tool_arguments(self, tool_name: str, value: str) -> Optional[Dict[str, str]]:
5088+
if self._tools is None:
5089+
return None
5090+
for tool in self._tools:
5091+
if tool.get("type") != "function":
5092+
continue
5093+
function = tool.get("function", {})
5094+
if function.get("name") != tool_name:
5095+
continue
5096+
parameters = function.get("parameters")
5097+
if not isinstance(parameters, dict):
5098+
return None
5099+
required = parameters.get("required")
5100+
if not isinstance(required, list) or len(required) != 1:
5101+
return None
5102+
argument_name = required[0]
5103+
if not isinstance(argument_name, str):
5104+
return None
5105+
properties = parameters.get("properties")
5106+
if not isinstance(properties, dict):
5107+
return None
5108+
argument_schema = properties.get(argument_name)
5109+
if not isinstance(argument_schema, dict):
5110+
return None
5111+
argument_type = argument_schema.get("type")
5112+
if argument_type == "string" or (
5113+
isinstance(argument_type, list) and "string" in argument_type
5114+
):
5115+
return {argument_name: value}
5116+
return None
5117+
return None
5118+
5119+
@classmethod
5120+
def _raw_object_tool_arguments(cls, value: str) -> Optional[Dict[str, Any]]:
5121+
candidates = [value]
5122+
stripped = value.strip()
5123+
if stripped.startswith("{{") and stripped.endswith("}}"):
5124+
candidates.append(stripped[1:-1])
5125+
for candidate in candidates:
5126+
normalized = cls._gemma4_tool_call_to_json(candidate)
5127+
for allow_partial in (False, True):
5128+
try:
5129+
parsed = from_json(normalized, allow_partial=allow_partial)
5130+
except ValueError:
5131+
continue
5132+
if isinstance(parsed, dict):
5133+
return {
5134+
key: cls._trim_partial_gemma_quote_marker(value)
5135+
if isinstance(value, str)
5136+
else value
5137+
for key, value in parsed.items()
5138+
}
5139+
return None
5140+
5141+
@staticmethod
5142+
def _trim_partial_gemma_quote_marker(value: str) -> str:
5143+
quote_marker = '<|"|>'
5144+
for prefix_length in range(len(quote_marker) - 1, 0, -1):
5145+
if value.endswith(quote_marker[:prefix_length]):
5146+
return value[:-prefix_length]
5147+
return value
5148+
50725149
def _has_text_tools(self) -> bool:
50735150
return any(
50745151
isinstance(tool_schema, dict) and tool_schema.get("content_type") == "text"
@@ -5637,6 +5714,18 @@ def _advance_direct_stream_state(self, text: str) -> Tuple[bool, List[Dict[str,
56375714
self._direct.saw_tool_calls = saw_tool_calls
56385715
self._direct.done = done
56395716
return True, deltas
5717+
if leading_capture_field is not None:
5718+
if buffer.startswith(leading_capture_start):
5719+
buffer = buffer[len(leading_capture_start) :]
5720+
mode = self.DIRECT_MODE_LEADING_CAPTURE
5721+
continue
5722+
if leading_capture_start.startswith(buffer):
5723+
self._direct.pending = buffer
5724+
self._direct.mode = mode
5725+
self._direct.tool_call_count = tool_call_count
5726+
self._direct.saw_tool_calls = saw_tool_calls
5727+
self._direct.done = done
5728+
return True, deltas
56405729
if buffer.startswith(iterator_start):
56415730
saw_tool_calls = True
56425731
self._start_direct_tool_call(tool_call_count)
@@ -6302,6 +6391,16 @@ def _advance_stream_state(self, text: str) -> Tuple[bool, List[Dict[str, Any]]]:
63026391
if not buffer:
63036392
state.pending = ""
63046393
return True, deltas
6394+
leading_capture = plan.get("leading_capture")
6395+
if leading_capture is not None:
6396+
capture_start = leading_capture["start"]
6397+
if buffer.startswith(capture_start):
6398+
buffer = buffer[len(capture_start) :]
6399+
state.mode = "leading-capture"
6400+
continue
6401+
if capture_start.startswith(buffer):
6402+
state.pending = buffer
6403+
return True, deltas
63056404
if buffer.startswith(iterator_start):
63066405
item_state = self._new_tool_call_state(plan["iterator"]["item"])
63076406
state.saw_tool_calls = True
@@ -6866,6 +6965,10 @@ def _normalize_tool_call_item(
68666965
},
68676966
}
68686967
arguments = function.get("arguments", {})
6968+
if isinstance(arguments, str):
6969+
arguments = self._raw_object_tool_arguments(arguments) or self._raw_string_tool_arguments(
6970+
tool_name, arguments
6971+
)
68696972
if not isinstance(arguments, (dict, ResponseParser.PartialJsonObject)):
68706973
if partial:
68716974
return None
@@ -8009,7 +8112,14 @@ def _responses_tools_to_chat_tools(
80098112
return None
80108113
chat_tools: List[ChatTemplateTool] = []
80118114
for tool in tools:
8012-
if isinstance(tool, ResponsesWebSearchTool):
8115+
if isinstance(
8116+
tool,
8117+
(
8118+
ResponsesWebSearchTool,
8119+
ResponsesNamespaceTool,
8120+
ResponsesImageGenerationTool,
8121+
),
8122+
):
80138123
continue
80148124
if isinstance(tool, ResponsesFunctionTool):
80158125
chat_tools.append(tool.to_chat_template_tool())

0 commit comments

Comments
 (0)