From c9cbf9a0a1a97869a6f8632697e1b515c73bf01b Mon Sep 17 00:00:00 2001 From: anhnh2002 Date: Mon, 13 Oct 2025 14:34:34 +0700 Subject: [PATCH] using native agent stead of pydantic ai agent --- src/be/agent_orchestrator.py | 47 +--- .../generate_sub_module_documentations.py | 43 +++- src/be/agent_tools/read_code_components.py | 25 +- src/be/agent_tools/str_replace_editor.py | 62 ++++- src/be/llm_services.py | 36 +-- src/be/native_agent.py | 235 ++++++++++++++++++ 6 files changed, 351 insertions(+), 97 deletions(-) create mode 100644 src/be/native_agent.py diff --git a/src/be/agent_orchestrator.py b/src/be/agent_orchestrator.py index 7b7455e3..304a3ed8 100644 --- a/src/be/agent_orchestrator.py +++ b/src/be/agent_orchestrator.py @@ -1,5 +1,3 @@ -from pydantic_ai import Agent -import logfire import logging import os from typing import Dict, List, Any @@ -8,38 +6,13 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -try: - # Configure logfire with environment variables for Docker compatibility - logfire_token = os.getenv('LOGFIRE_TOKEN') - logfire_project = os.getenv('LOGFIRE_PROJECT_NAME', 'default') - logfire_service = os.getenv('LOGFIRE_SERVICE_NAME', 'default') - - if logfire_token: - # Configure with explicit token (for Docker) - logfire.configure( - token=logfire_token, - project_name=logfire_project, - service_name=logfire_service, - ) - else: - # Use default configuration (for local development with logfire auth) - logfire.configure( - project_name=logfire_project, - service_name=logfire_service, - ) - - logfire.instrument_pydantic_ai() - logger.info(f"Logfire configured successfully for project: {logfire_project}") - -except Exception as e: - logger.warning(f"Failed to configure logfire: {e}") - # Local imports from .agent_tools.deps import CodeWikiDeps from .agent_tools.read_code_components import read_code_components_tool from .agent_tools.str_replace_editor import str_replace_editor_tool from 
.agent_tools.generate_sub_module_documentations import generate_sub_module_documentation_tool -from .llm_services import fallback_models +from .llm_services import client, MAIN_MODEL, FALLBACK_MODEL_1 +from .native_agent import NativeAgent from .prompt_template import ( SYSTEM_PROMPT, LEAF_SYSTEM_PROMPT, @@ -61,13 +34,14 @@ def __init__(self, config: Config): self.config = config def create_agent(self, module_name: str, components: Dict[str, Any], - core_component_ids: List[str]) -> Agent: + core_component_ids: List[str]) -> NativeAgent: """Create an appropriate agent based on module complexity.""" if is_complex_module(components, core_component_ids): - return Agent( - fallback_models, + return NativeAgent( + client=client, + model=MAIN_MODEL, + fallback_model=FALLBACK_MODEL_1, name=module_name, - deps_type=CodeWikiDeps, tools=[ read_code_components_tool, str_replace_editor_tool, @@ -76,10 +50,11 @@ def create_agent(self, module_name: str, components: Dict[str, Any], system_prompt=SYSTEM_PROMPT.format(module_name=module_name), ) else: - return Agent( - fallback_models, + return NativeAgent( + client=client, + model=MAIN_MODEL, + fallback_model=FALLBACK_MODEL_1, name=module_name, - deps_type=CodeWikiDeps, tools=[read_code_components_tool, str_replace_editor_tool], system_prompt=LEAF_SYSTEM_PROMPT.format(module_name=module_name), ) diff --git a/src/be/agent_tools/generate_sub_module_documentations.py b/src/be/agent_tools/generate_sub_module_documentations.py index 4bc275ae..6634d351 100644 --- a/src/be/agent_tools/generate_sub_module_documentations.py +++ b/src/be/agent_tools/generate_sub_module_documentations.py @@ -1,9 +1,8 @@ -from pydantic_ai import RunContext, Tool, Agent - +from ..native_agent import AgentContext, NativeAgent, create_tool_definition from .deps import CodeWikiDeps from .read_code_components import read_code_components_tool from .str_replace_editor import str_replace_editor_tool -from ..llm_services import fallback_models +from ..llm_services 
import client, MAIN_MODEL, FALLBACK_MODEL_1 from ..prompt_template import SYSTEM_PROMPT, LEAF_SYSTEM_PROMPT, format_user_prompt from ..utils import is_complex_module, count_tokens from ..cluster_modules import format_potential_core_components @@ -12,7 +11,7 @@ async def generate_sub_module_documentation( - ctx: RunContext[CodeWikiDeps], + ctx: AgentContext, sub_module_specs: dict[str, list[str]] ) -> str: """Generate detailed description of a given sub-module specs to the sub-agents @@ -36,18 +35,20 @@ async def generate_sub_module_documentation( num_tokens = count_tokens(format_potential_core_components(core_component_ids, ctx.deps.components)[-1]) if is_complex_module(ctx.deps.components, core_component_ids) and ctx.deps.current_depth < ctx.deps.max_depth and num_tokens >= MAX_TOKEN_PER_LEAF_MODULE: - sub_agent = Agent( - model=fallback_models, + sub_agent = NativeAgent( + client=client, + model=MAIN_MODEL, + fallback_model=FALLBACK_MODEL_1, name=sub_module_name, - deps_type=CodeWikiDeps, system_prompt=SYSTEM_PROMPT.format(module_name=sub_module_name), tools=[read_code_components_tool, str_replace_editor_tool, generate_sub_module_documentation_tool], ) else: - sub_agent = Agent( - model=fallback_models, + sub_agent = NativeAgent( + client=client, + model=MAIN_MODEL, + fallback_model=FALLBACK_MODEL_1, name=sub_module_name, - deps_type=CodeWikiDeps, system_prompt=LEAF_SYSTEM_PROMPT.format(module_name=sub_module_name), tools=[read_code_components_tool, str_replace_editor_tool], ) @@ -78,4 +79,24 @@ async def generate_sub_module_documentation( return f"Generate successfully. Documentations: {', '.join([key + '.md' for key in sub_module_specs.keys()])} are saved in the working directory." 
-generate_sub_module_documentation_tool = Tool(function=generate_sub_module_documentation, name="generate_sub_module_documentation", description="Generate detailed description of a given sub-module specs to the sub-agents", takes_ctx=True) \ No newline at end of file +# Tool definition for native OpenAI function calling +generate_sub_module_documentation_tool = create_tool_definition( + name="generate_sub_module_documentation", + description="Generate detailed description of a given sub-module specs to the sub-agents", + function=generate_sub_module_documentation, + parameters={ + "type": "object", + "properties": { + "sub_module_specs": { + "type": "object", + "description": "The specs of the sub-modules to generate documentation for. E.g. {'sub_module_1': ['core_component_1.1', 'core_component_1.2'], 'sub_module_2': ['core_component_2.1', 'core_component_2.2'], ...}", + "additionalProperties": { + "type": "array", + "items": {"type": "string"} + } + } + }, + "required": ["sub_module_specs"] + }, + takes_ctx=True +) \ No newline at end of file diff --git a/src/be/agent_tools/read_code_components.py b/src/be/agent_tools/read_code_components.py index 10005624..00d3090b 100644 --- a/src/be/agent_tools/read_code_components.py +++ b/src/be/agent_tools/read_code_components.py @@ -1,8 +1,7 @@ -from pydantic_ai import RunContext, Tool -from .deps import CodeWikiDeps +from ..native_agent import AgentContext, create_tool_definition -async def read_code_components(ctx: RunContext[CodeWikiDeps], component_ids: list[str]) -> str: +async def read_code_components(ctx: AgentContext, component_ids: list[str]) -> str: """Read the code of a given component id Args: @@ -19,4 +18,22 @@ async def read_code_components(ctx: RunContext[CodeWikiDeps], component_ids: lis return "\n".join(results) -read_code_components_tool = Tool(function=read_code_components, name="read_code_components", description="Read the code of a given list of component ids", takes_ctx=True) \ No newline at end of 
file + +# Tool definition for native OpenAI function calling +read_code_components_tool = create_tool_definition( + name="read_code_components", + description="Read the code of a given list of component ids", + function=read_code_components, + parameters={ + "type": "object", + "properties": { + "component_ids": { + "type": "array", + "items": {"type": "string"}, + "description": "The ids of the components to read, e.g. ['sweagent.types.AgentRunResult']" + } + }, + "required": ["component_ids"] + }, + takes_ctx=True +) \ No newline at end of file diff --git a/src/be/agent_tools/str_replace_editor.py b/src/be/agent_tools/str_replace_editor.py index 9d0a7e81..80d527cf 100644 --- a/src/be/agent_tools/str_replace_editor.py +++ b/src/be/agent_tools/str_replace_editor.py @@ -19,9 +19,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -from pydantic_ai import RunContext, Tool - -from .deps import CodeWikiDeps +from ..native_agent import AgentContext, create_tool_definition from ..utils import validate_mermaid_diagrams @@ -707,9 +705,9 @@ def _make_output( return f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n" async def str_replace_editor( - ctx: RunContext[CodeWikiDeps], - working_dir: Literal["repo", "docs"], - command: Literal["view", "create", "str_replace", "insert", "undo_edit"], + ctx: AgentContext, + working_dir: str, + command: str, path: str, file_text: Optional[str] = None, view_range: Optional[List[int]] = None, @@ -766,17 +764,57 @@ async def str_replace_editor( return result -str_replace_editor_tool = Tool( - function=str_replace_editor, +# Tool definition for native OpenAI function calling +str_replace_editor_tool = create_tool_definition( name="str_replace_editor", - description=""" -Custom editing tool for viewing, creating and editing files + description="""Custom editing tool for viewing, creating and editing files * State is persistent across command calls and discussions with the 
user * If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep. * The `create` command cannot be used if the specified `path` already exists as a file * If a `command` generates a long output, it will be truncated and marked with `` * The `undo_edit` command will revert the last edit made to the file at `path` - * Only `view` command is allowed when `working_dir` is `repo`. -""".strip(), + * Only `view` command is allowed when `working_dir` is `repo`.""", + function=str_replace_editor, + parameters={ + "type": "object", + "properties": { + "working_dir": { + "type": "string", + "enum": ["repo", "docs"], + "description": "The working directory to use. Choose 'repo' to work with repository files, or 'docs' to work with generated documentation files." + }, + "command": { + "type": "string", + "enum": ["view", "create", "str_replace", "insert", "undo_edit"], + "description": "The command to run" + }, + "path": { + "type": "string", + "description": "Path to file or directory, e.g. './chat_core.md' or './agents/'" + }, + "file_text": { + "type": "string", + "description": "Required parameter of 'create' command, with the content of the file to be created" + }, + "view_range": { + "type": "array", + "items": {"type": "integer"}, + "description": "Optional parameter of 'view' command when path points to a file. 
Line number range [start, end]" + }, + "old_str": { + "type": "string", + "description": "Required parameter of 'str_replace' command containing the string in path to replace" + }, + "new_str": { + "type": "string", + "description": "Optional parameter of 'str_replace' command containing the new string" + }, + "insert_line": { + "type": "integer", + "description": "Required parameter of 'insert' command specifying the line number to insert at" + } + }, + "required": ["working_dir", "command", "path"] + }, takes_ctx=True ) diff --git a/src/be/llm_services.py b/src/be/llm_services.py index 52b28021..ae34f44f 100644 --- a/src/be/llm_services.py +++ b/src/be/llm_services.py @@ -1,39 +1,7 @@ -from pydantic_ai.models.openai import OpenAIModel -from pydantic_ai.providers.openai import OpenAIProvider -from pydantic_ai.models.openai import OpenAIModelSettings -from pydantic_ai.models.fallback import FallbackModel - +from openai import OpenAI from config import MAIN_MODEL, FALLBACK_MODEL_1, LLM_BASE_URL, LLM_API_KEY - -main_model = OpenAIModel( - model_name=MAIN_MODEL, - provider=OpenAIProvider( - base_url=LLM_BASE_URL, - api_key=LLM_API_KEY - ), - settings=OpenAIModelSettings( - temperature=0.0, - max_tokens=32768 - ) -) - -fallback_model_1 = OpenAIModel( - model_name=FALLBACK_MODEL_1, - provider=OpenAIProvider( - base_url=LLM_BASE_URL, - api_key=LLM_API_KEY - ), - settings=OpenAIModelSettings( - temperature=0.0, - max_tokens=32768 - ) -) - -fallback_models = FallbackModel(main_model, fallback_model_1) - -# ------------------------------------------------------------ -from openai import OpenAI +# Native OpenAI client configuration client = OpenAI( base_url=LLM_BASE_URL, diff --git a/src/be/native_agent.py b/src/be/native_agent.py new file mode 100644 index 00000000..93226f8e --- /dev/null +++ b/src/be/native_agent.py @@ -0,0 +1,235 @@ +""" +Native OpenAI Agent Implementation +Replaces pydantic_ai with direct OpenAI API calls +""" + +import json +import logging +from typing 
import asyncio
import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

if TYPE_CHECKING:
    # Needed for annotations only; keeping it out of the runtime import path
    # lets the agent be exercised with a stub client.
    from openai import OpenAI

logger = logging.getLogger(__name__)


@dataclass
class ToolDefinition:
    """Definition of a tool that can be called by the agent."""

    # Function name advertised to the model; also the key in NativeAgent.tool_map.
    name: str
    # Natural-language description sent with the tool schema.
    description: str
    # Implementation; may be a coroutine function or a plain function.
    function: Callable[..., Any]
    # JSON schema describing the keyword arguments of `function`.
    parameters: Dict[str, Any]
    # When True, the AgentContext is passed as the first positional argument.
    takes_ctx: bool = False


class AgentContext:
    """Context object passed to tool functions that declare `takes_ctx`."""

    def __init__(self, deps: Any):
        # Opaque dependency object supplied by the caller of NativeAgent.run().
        self.deps = deps


class NativeAgent:
    """
    Native OpenAI agent implementation using function calling.
    Replaces pydantic_ai Agent class.

    The agent repeatedly calls the chat-completions endpoint, executes every
    tool call the model requests, feeds the results back, and stops as soon as
    the model answers without requesting a tool (or after `max_iterations`).
    """

    def __init__(
        self,
        client: "OpenAI",
        model: str,
        fallback_model: Optional[str] = None,
        name: str = "agent",
        tools: Optional[List[ToolDefinition]] = None,
        system_prompt: str = "",
        temperature: float = 0.0,
        max_tokens: int = 32768,
        max_iterations: int = 50
    ):
        """
        Args:
            client: Synchronous OpenAI-compatible client.
            model: Primary model name.
            fallback_model: Model tried once when a call to `model` raises.
            name: Label for this agent.
            tools: Tools the model may call; None means no tools.
            system_prompt: Optional system message placed first in the chat.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.
            max_iterations: Maximum number of LLM round-trips per run().
        """
        self.client = client
        self.model = model
        self.fallback_model = fallback_model
        self.name = name
        # Copy so later mutation of the caller's list cannot desync tool_map.
        self.tools = list(tools) if tools else []
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_iterations = max_iterations

        # Create tool map for quick lookup
        self.tool_map = {tool.name: tool for tool in self.tools}

    def _get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Convert tool definitions to OpenAI function calling format."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                },
            }
            for tool in self.tools
        ]

    async def _execute_tool(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        ctx: AgentContext
    ) -> str:
        """Execute a tool function and return its result as text.

        Failures are never raised: they are returned as an error string so the
        model can see what went wrong and retry.

        Args:
            tool_name: Name requested by the model.
            arguments: Decoded keyword arguments for the tool.
            ctx: Context handed to tools that declare `takes_ctx`.

        Returns:
            The tool's string result, a JSON (or `str()`) rendering of a
            non-string result, or an "Error ..." message.
        """
        tool = self.tool_map.get(tool_name)
        if tool is None:
            return f"Error: Tool {tool_name} not found"
        if not isinstance(arguments, dict):
            return f"Error executing {tool_name}: arguments must be a JSON object"

        try:
            leading = (ctx,) if tool.takes_ctx else ()
            result = tool.function(*leading, **arguments)
            # Support both coroutine functions and plain functions; awaiting a
            # plain return value would fail *after* the tool already ran.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:  # boundary: report tool failures to the model
            logger.exception("Error executing tool %s: %s", tool_name, e)
            return f"Error executing {tool_name}: {str(e)}"

        if isinstance(result, str):
            return result
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            # The tool succeeded; do not turn a serialization problem into a
            # reported tool failure.
            return str(result)

    def _call_llm(self, messages: List[Dict[str, Any]], use_tools: bool = True) -> Any:
        """Make a blocking call to the LLM with optional tool support.

        Args:
            messages: Conversation so far, in chat-completions format.
            use_tools: Whether to advertise the registered tools.

        Returns:
            The response object returned by `client.chat.completions.create`.

        Raises:
            Exception: Whatever the client raised, when the primary call fails
                and there is no fallback model or the fallback fails as well.
        """
        # Built outside the try block so only API failures trigger the fallback.
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if use_tools and self.tools:
            kwargs["tools"] = self._get_tool_schemas()
            kwargs["tool_choice"] = "auto"

        try:
            return self.client.chat.completions.create(**kwargs)
        except Exception as primary_error:
            if not self.fallback_model:
                raise
            logger.warning(
                "Primary model failed: %s. Trying fallback model: %s",
                primary_error,
                self.fallback_model,
            )
            try:
                return self.client.chat.completions.create(
                    **{**kwargs, "model": self.fallback_model}
                )
            except Exception as fallback_error:
                logger.error("Fallback model also failed: %s", fallback_error)
                raise

    async def run(self, user_message: str, deps: Any) -> str:
        """
        Run the agent with a user message.

        Args:
            user_message: The user's input message
            deps: Dependencies to pass to tool functions

        Returns:
            The final response from the agent, "Task completed." when the
            final message has no content, or a notice that the iteration
            limit was reached.
        """
        ctx = AgentContext(deps=deps)

        # Initialize conversation with the optional system prompt
        messages: List[Dict[str, Any]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": user_message})

        for iteration in range(1, self.max_iterations + 1):
            logger.info(f"Agent iteration {iteration}/{self.max_iterations}")

            # The client is synchronous; run it in a worker thread so the
            # event loop is not blocked for the whole HTTP round-trip.
            response = await asyncio.to_thread(self._call_llm, messages, True)
            assistant_message = response.choices[0].message
            tool_calls = getattr(assistant_message, "tool_calls", None)

            # No tool calls means the model produced its final answer.
            if not tool_calls:
                return assistant_message.content if assistant_message.content else "Task completed."

            messages.append({
                "role": "assistant",
                "content": assistant_message.content,
                "tool_calls": tool_calls,
            })

            for tool_call in tool_calls:
                tool_name = tool_call.function.name
                try:
                    # Missing/empty argument payloads mean "no arguments".
                    arguments = json.loads(tool_call.function.arguments or "{}")
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse tool arguments: {e}")
                    # Tell the model its payload was malformed instead of
                    # invoking the tool with made-up empty arguments.
                    tool_result = f"Error: Invalid JSON arguments for {tool_name}: {e}"
                else:
                    logger.info(f"Executing tool: {tool_name}")
                    logger.debug(f"Tool arguments: {arguments}")
                    tool_result = await self._execute_tool(tool_name, arguments, ctx)

                # Every tool call must be answered by a matching tool message.
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": tool_name,
                    "content": tool_result,
                })

        # If we hit max iterations, return what we have
        logger.warning(f"Agent reached max iterations ({self.max_iterations})")
        return "Task execution reached maximum iterations. Please check the results."


def create_tool_definition(
    name: str,
    description: str,
    function: Callable[..., Any],
    parameters: Dict[str, Any],
    takes_ctx: bool = False
) -> ToolDefinition:
    """
    Helper function to create a tool definition.

    Args:
        name: Name of the tool
        description: Description of what the tool does
        function: The function to call
        parameters: JSON schema for the function parameters
        takes_ctx: Whether the function takes a context as first argument

    Returns:
        A ToolDefinition holding the given values unchanged.
    """
    return ToolDefinition(
        name=name,
        description=description,
        function=function,
        parameters=parameters,
        takes_ctx=takes_ctx
    )