forked from Zipstack/unstract
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool_processor.py
More file actions
137 lines (118 loc) · 5.29 KB
/
tool_processor.py
File metadata and controls
137 lines (118 loc) · 5.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import logging
from typing import Any, Optional
from account.models import User
from adapter_processor.adapter_processor import AdapterProcessor
from prompt_studio.prompt_studio_registry.prompt_studio_registry_helper import (
PromptStudioRegistryHelper,
)
from tool_instance.exceptions import ToolDoesNotExist
from unstract.sdk.adapters.enums import AdapterTypes
from unstract.tool_registry.dto import Spec, Tool
from unstract.tool_registry.tool_registry import ToolRegistry
from unstract.tool_registry.tool_utils import ToolUtils
# Per-module logger named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class ToolProcessor:
    """Resolve tools and build their default settings / JSON schemas.

    Tools are looked up in the ``ToolRegistry`` first and, failing that,
    among tools exported from Prompt Studio.
    """

    TOOL_NOT_IN_REGISTRY_MESSAGE = "Tool does not exist in registry"
    # NOTE(review): instantiated once at import time but not used by any
    # method below (each method builds its own ToolRegistry). Kept so that
    # external code referencing ``ToolProcessor.tool_registry`` still works.
    tool_registry = ToolRegistry()

    @staticmethod
    def get_tool_by_uid(tool_uid: str) -> Tool:
        """Fetch the tool identified by ``tool_uid``.

        The tool registry is consulted first; if it has no match,
        ``tool_uid`` is treated as a Prompt Studio ``prompt_registry_id``.

        Args:
            tool_uid (str): Registry tool UID or Prompt Studio registry ID.

        Returns:
            Tool: The resolved tool.

        Raises:
            ToolDoesNotExist: If neither source knows ``tool_uid``.
        """
        tool_registry = ToolRegistry()
        tool: Optional[Tool] = tool_registry.get_tool_by_uid(tool_uid)
        # HACK: Assume tool_uid is prompt_registry_id for fetching a dynamic
        # tool made with Prompt Studio.
        if not tool:
            tool = PromptStudioRegistryHelper.get_tool_by_prompt_registry_id(
                prompt_registry_id=tool_uid
            )
        if not tool:
            raise ToolDoesNotExist(
                f"{ToolProcessor.TOOL_NOT_IN_REGISTRY_MESSAGE}: {tool_uid}"
            )
        return tool

    @staticmethod
    def get_default_settings(tool: Tool) -> dict[str, str]:
        """Build the tool's settings populated with their default values.

        Args:
            tool (Tool): Tool whose spec supplies the defaults.

        Returns:
            dict[str, str]: Tool settings, as produced by
            ``ToolUtils.get_default_settings``.
        """
        tool_metadata: dict[str, str] = ToolUtils.get_default_settings(tool)
        return tool_metadata

    @staticmethod
    def get_json_schema_for_tool(tool_uid: str, user: User) -> dict[str, Any]:
        """Return the tool's settings JSON schema with adapter choices filled.

        Args:
            tool_uid (str): Registry tool UID or Prompt Studio registry ID.
            user (User): User whose adapters populate the ``enum`` lists.

        Returns:
            dict[str, Any]: The schema serialised via ``Spec.to_dict()``.

        Raises:
            ToolDoesNotExist: If ``tool_uid`` cannot be resolved.
        """
        tool: Tool = ToolProcessor.get_tool_by_uid(tool_uid=tool_uid)
        schema: Spec = ToolUtils.get_json_schema_for_tool(tool)
        ToolProcessor.update_schema_with_adapter_configurations(
            schema=schema, user=user
        )
        schema_json: dict[str, Any] = schema.to_dict()
        return schema_json

    @staticmethod
    def update_schema_with_adapter_configurations(schema: Spec, user: User) -> None:
        """Fill adapter-selection properties with the user's adapter names.

        For each adapter kind (LLM, embedding, vector DB, X2Text, OCR) that
        the schema declares properties for, the user's adapters of that type
        are fetched and their names are set as the ``enum`` of each such
        property.

        Args:
            schema (Spec): The JSON schema object to be updated in place.
            user (User): User whose adapters are listed.

        Returns:
            None. The ``schema`` object is updated in-place.
        """
        # All property-key getters run up front, in this fixed order, before
        # any adapter lookup (same sequencing as the per-type branches).
        adapter_keys_by_type = (
            (AdapterTypes.LLM, schema.get_llm_adapter_properties_keys()),
            (AdapterTypes.EMBEDDING, schema.get_embedding_adapter_properties_keys()),
            (AdapterTypes.VECTOR_DB, schema.get_vector_db_adapter_properties_keys()),
            (AdapterTypes.X2TEXT, schema.get_text_extractor_adapter_properties_keys()),
            (AdapterTypes.OCR, schema.get_ocr_adapter_properties_keys()),
        )
        for adapter_type, keys in adapter_keys_by_type:
            # Skip the adapter query entirely when the schema has no
            # property for this adapter kind.
            if not keys:
                continue
            adapters = AdapterProcessor.get_adapters_by_type(adapter_type, user=user)
            # Materialise the names once so a single-pass iterable of
            # adapters cannot leave later keys with an empty enum.
            adapter_names = [str(adapter.adapter_name) for adapter in adapters]
            for key in keys:
                # Fresh list per key: properties must not share one mutable list.
                schema.properties[key]["enum"] = list(adapter_names)

    @staticmethod
    def get_tool_list(user: User) -> list[dict[str, Any]]:
        """List tool descriptions available to ``user``.

        Args:
            user (User): User whose Prompt Studio tools are included.

        Returns:
            list[dict[str, Any]]: Registry tool descriptions followed by the
            Prompt Studio tool entries returned for ``user``.
        """
        tool_registry = ToolRegistry()
        prompt_studio_tools: list[dict[str, Any]] = (
            PromptStudioRegistryHelper.fetch_json_for_registry(user)
        )
        tool_list: list[dict[str, Any]] = tool_registry.fetch_tools_descriptions()
        tool_list = tool_list + prompt_studio_tools
        return tool_list

    @staticmethod
    def get_registry_tools() -> list[Tool]:
        """Return every tool known to the tool registry.

        Returns:
            list[Tool]: Tools from ``ToolRegistry.fetch_all_tools()``. Prompt
            Studio tools are not included.
        """
        tool_registry = ToolRegistry()
        tool_list: list[Tool] = tool_registry.fetch_all_tools()
        return tool_list