forked from Zipstack/unstract
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path adapter_processor.py
More file actions
257 lines (224 loc) · 9.89 KB
/
adapter_processor.py
File metadata and controls
257 lines (224 loc) · 9.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import json
import logging
from typing import Any, Optional
from account.models import User
from adapter_processor.constants import AdapterKeys
from adapter_processor.exceptions import (
InternalServiceError,
InValidAdapterId,
TestAdapterError,
TestAdapterInputError,
)
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from platform_settings.platform_auth_service import PlatformAuthenticationService
from unstract.sdk.adapters.adapterkit import Adapterkit
from unstract.sdk.adapters.base import Adapter
from unstract.sdk.adapters.enums import AdapterTypes
from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.x2text.constants import X2TextConstants
from .models import AdapterInstance, UserDefaultAdapter
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class AdapterProcessor:
    """Static helpers for adapter catalogue lookups, connection tests and defaults.

    Catalogue data (adapter definitions and JSON schemas) comes from the
    SDK's ``Adapterkit``. Configured adapters and per-user defaults come from
    the ``AdapterInstance`` and ``UserDefaultAdapter`` models.
    """

    @staticmethod
    def get_json_schema(adapter_id: str) -> dict[str, Any]:
        """Return the JSON schema of a catalogue adapter.

        Args:
            adapter_id: ID of the adapter in the Adapterkit catalogue.

        Returns:
            A dict with a single ``AdapterKeys.JSON_SCHEMA`` entry that holds
            the schema parsed with ``json.loads``.

        Raises:
            InValidAdapterId: If no catalogue adapter has ``adapter_id``.
        """
        matching_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            AdapterKeys.ID, adapter_id
        )
        if not matching_adapters:
            logger.error(
                f"Invalid adapter Id : {adapter_id} while fetching JSON Schema"
            )
            raise InValidAdapterId()
        # The catalogue stores the schema as a JSON string; parse it for callers.
        return {
            AdapterKeys.JSON_SCHEMA: json.loads(
                matching_adapters[0].get(AdapterKeys.JSON_SCHEMA)
            )
        }

    @staticmethod
    def get_all_supported_adapters(type: str) -> list[dict[Any, Any]]:
        """Return a summary of every catalogue adapter of the given type.

        Args:
            type: Adapter type to filter on. It is compared against each
                adapter's ``AdapterKeys.ADAPTER_TYPE`` entry. The name shadows
                the builtin, but it is kept because callers may pass it by
                keyword.

        Returns:
            One dict per matching adapter with its ID, name, description,
            icon and adapter type. A missing entry maps to ``None``.
        """
        summary_keys = (
            AdapterKeys.ID,
            AdapterKeys.NAME,
            AdapterKeys.DESCRIPTION,
            AdapterKeys.ICON,
            AdapterKeys.ADAPTER_TYPE,
        )
        matching_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            AdapterKeys.ADAPTER_TYPE, type
        )
        return [
            {key: adapter.get(key) for key in summary_keys}
            for adapter in matching_adapters
        ]

    @staticmethod
    def get_adapter_data_with_key(adapter_id: str, key_value: str) -> Any:
        """Return one field of a catalogue adapter's definition.

        Args:
            adapter_id: ID of the adapter in the Adapterkit catalogue.
            key_value: Name of the field to read from the adapter definition.

        Returns:
            The field's value, or ``None`` if the definition lacks it.

        Raises:
            InValidAdapterId: If no catalogue adapter has ``adapter_id``.
        """
        # NOTE(review): other lookups use AdapterKeys.ID here; presumably the
        # same string as "id" — confirm before unifying.
        matching_adapters = AdapterProcessor.__fetch_adapters_by_key_value(
            "id", adapter_id
        )
        if not matching_adapters:
            logger.error(f"Invalid adapter ID {adapter_id} while invoking utility")
            raise InValidAdapterId()
        return matching_adapters[0].get(key_value)

    @staticmethod
    def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool:
        """Instantiate an adapter from its settings and test its connection.

        The caller's ``adapter_metadata`` is left unchanged. The adapter type
        is removed from a copy, and X2Text connection settings (host, port,
        platform API key) are added to that copy only.

        Args:
            adapter_id: ID of the adapter class in the Adapterkit catalogue.
            adapter_metadata: Adapter settings, including an
                ``AdapterKeys.ADAPTER_TYPE`` entry.

        Returns:
            The result of the adapter's ``test_connection()``.

        Raises:
            TestAdapterInputError: If the settings contain invalid JSON.
            TestAdapterError: If the adapter reports an ``AdapterError``.
        """
        logger.info(f"Testing adapter: {adapter_id}")
        # Work on a copy. Otherwise the caller's dict would lose its adapter
        # type and would keep the platform API key injected below.
        adapter_settings = adapter_metadata.copy()
        try:
            adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id)
            adapter_type = adapter_settings.pop(AdapterKeys.ADAPTER_TYPE, None)
            if adapter_type == AdapterKeys.X2TEXT:
                # X2Text adapters call the platform's X2Text service, so they
                # need its address and a platform key to authenticate.
                adapter_settings[X2TextConstants.X2TEXT_HOST] = settings.X2TEXT_HOST
                adapter_settings[X2TextConstants.X2TEXT_PORT] = settings.X2TEXT_PORT
                platform_key = PlatformAuthenticationService.get_active_platform_key()
                adapter_settings[X2TextConstants.PLATFORM_SERVICE_API_KEY] = str(
                    platform_key.key
                )
            adapter_instance = adapter_class(adapter_settings)
            test_result: bool = adapter_instance.test_connection()
            logger.info(f"{adapter_id} test result: {test_result}")
            return test_result
        # HACK: Remove after error is explicitly handled in VertexAI adapter
        except json.JSONDecodeError as e:
            raise TestAdapterInputError(
                "Credentials is not a valid service account JSON, "
                "please provide a valid JSON."
            ) from e
        except AdapterError as e:
            raise TestAdapterError(str(e)) from e

    @staticmethod
    def __fetch_adapters_by_key_value(key: str, value: Any) -> list[dict[str, Any]]:
        """Return catalogue adapters whose ``key`` entry equals ``value``.

        Args:
            key: Field of each adapter definition to compare.
            value: Value that the field must equal.

        Returns:
            The matching adapter definitions, in catalogue order.
        """
        logger.info(f"Fetching adapter list for {key} with {value}")
        adapters = Adapterkit().get_adapters_list()
        return [adapter for adapter in adapters if adapter[key] == value]

    @staticmethod
    def set_default_triad(default_triad: dict[str, str], user: User) -> None:
        """Update a user's default adapters.

        Only the kinds with a truthy ID in ``default_triad`` are changed. The
        ``UserDefaultAdapter`` row is created if it does not exist.

        Args:
            default_triad: Maps ``AdapterKeys.*_DEFAULT`` keys (LLM, embedding,
                vector DB, X2Text) to ``AdapterInstance`` primary keys.
            user: User whose defaults are updated.

        Raises:
            InValidAdapterId: If a referenced adapter instance does not exist.
            InternalServiceError: For any other failure while saving.
        """
        # Maps each accepted key to the UserDefaultAdapter field it sets.
        field_by_key = {
            AdapterKeys.LLM_DEFAULT: "default_llm_adapter",
            AdapterKeys.EMBEDDING_DEFAULT: "default_embedding_adapter",
            AdapterKeys.VECTOR_DB_DEFAULT: "default_vector_db_adapter",
            AdapterKeys.X2TEXT_DEFAULT: "default_x2text_adapter",
        }
        try:
            user_default_adapter, _ = UserDefaultAdapter.objects.get_or_create(
                user=user
            )
            for key, field_name in field_by_key.items():
                adapter_instance_id = default_triad.get(key, None)
                if not adapter_instance_id:
                    continue
                try:
                    adapter_instance = AdapterInstance.objects.get(
                        pk=adapter_instance_id
                    )
                except ObjectDoesNotExist as e:
                    # An unknown ID is a client error, not an internal one.
                    raise InValidAdapterId() from e
                setattr(user_default_adapter, field_name, adapter_instance)
            user_default_adapter.save()
            logger.info("Changed defaults successfully")
        except Exception as e:
            logger.error(f"Unable to save defaults because: {e}")
            if isinstance(e, InValidAdapterId):
                raise
            raise InternalServiceError() from e

    @staticmethod
    def get_adapter_instance_by_id(adapter_instance_id: str) -> str:
        """Return the name of the adapter instance with the given ID.

        Parameters:
        - adapter_instance_id (str): The ID of the adapter instance.
        Returns:
        - str: The ``adapter_name`` of the matching AdapterInstance.
        Raises:
        - InValidAdapterId: If no adapter instance has that ID.
        """
        try:
            adapter = AdapterInstance.objects.get(id=adapter_instance_id)
        except ObjectDoesNotExist as e:
            # The old code fell through here and crashed on the unbound name.
            logger.error(f"Unable to fetch adapter: {e}")
            raise InValidAdapterId() from e
        return adapter.adapter_name

    @staticmethod
    def get_adapters_by_type(
        adapter_type: AdapterTypes, user: User
    ) -> list[AdapterInstance]:
        """Get the adapters of one type that are visible to a user.

        Parameters:
        - adapter_type (AdapterTypes): The type of adapters to retrieve.
        - user: Logged in User
        Returns:
        - The AdapterInstance objects from ``for_user(user)`` whose
          ``adapter_type`` equals ``adapter_type.value``. This is the
          (lazy) result of ``.filter()``, not a materialized list.
        """
        adapters: list[AdapterInstance] = AdapterInstance.objects.for_user(user).filter(
            adapter_type=adapter_type.value,
        )
        return adapters

    @staticmethod
    def get_adapter_by_name_and_type(
        adapter_type: AdapterTypes,
        adapter_name: Optional[str] = None,
    ) -> Optional[AdapterInstance]:
        """Get an adapter instance by name and type, or the default of that type.

        Parameters:
        - adapter_type (AdapterTypes): The type of the adapter instance.
        - adapter_name (str, optional): The name of the adapter instance. If
          omitted, the instance of this type with ``is_default=True`` is used.
        Returns:
        - AdapterInstance: The matching adapter, or ``None`` when no name is
          given and no default adapter of this type exists.
        Raises:
        - AdapterInstance.DoesNotExist: If a name is given and no adapter with
          that name and type exists.
        - AdapterInstance.MultipleObjectsReturned: If a lookup matches more
          than one row (standard ``QuerySet.get`` behaviour).
        """
        if adapter_name:
            return AdapterInstance.objects.get(
                adapter_name=adapter_name, adapter_type=adapter_type.value
            )
        try:
            return AdapterInstance.objects.get(
                adapter_type=adapter_type.value, is_default=True
            )
        except AdapterInstance.DoesNotExist:
            return None

    @staticmethod
    def get_default_adapters(user: User) -> list[AdapterInstance]:
        """Return the default adapters configured for a user.

        Reads the user's ``UserDefaultAdapter`` row. It returns the adapters
        that are set, in this order: embedding, LLM, vector DB, X2Text.

        Args:
            user: User whose defaults are read.

        Returns:
            The configured default AdapterInstance objects. Unset ones are
            skipped.

        Raises:
            InternalServiceError: If the user has no defaults row, or a
                referenced adapter no longer exists.
        """
        try:
            default_adapter = UserDefaultAdapter.objects.get(user=user)
            # Accessing the FK attributes stays inside the try: a dangling
            # reference raises an ObjectDoesNotExist subclass.
            candidates = (
                default_adapter.default_embedding_adapter,
                default_adapter.default_llm_adapter,
                default_adapter.default_vector_db_adapter,
                default_adapter.default_x2text_adapter,
            )
        except ObjectDoesNotExist as e:
            logger.error(f"No default adapters found: {e}")
            raise InternalServiceError(
                "No default adapters found, configure them through Platform Settings"
            ) from e
        return [adapter for adapter in candidates if adapter]