22from typing import List , Optional , Tuple
33from uuid import uuid4 as uuid
44
5+ from codegate .db import models as db_models
56from codegate .db .connection import DbReader , DbRecorder
6- from codegate .db .models import (
7- ActiveWorkspace ,
8- MuxRule ,
9- Session ,
10- WorkspaceRow ,
11- WorkspaceWithSessionInfo ,
12- )
7+ from codegate .muxing import models as mux_models
138from codegate .muxing import rulematcher
149
1510
@@ -40,7 +35,7 @@ class WorkspaceCrud:
4035 def __init__ (self ):
4136 self ._db_reader = DbReader ()
4237
43- async def add_workspace (self , new_workspace_name : str ) -> WorkspaceRow :
38+ async def add_workspace (self , new_workspace_name : str ) -> db_models . WorkspaceRow :
4439 """
4540 Add a workspace
4641
@@ -57,7 +52,7 @@ async def add_workspace(self, new_workspace_name: str) -> WorkspaceRow:
5752
5853 async def rename_workspace (
5954 self , old_workspace_name : str , new_workspace_name : str
60- ) -> WorkspaceRow :
55+ ) -> db_models . WorkspaceRow :
6156 """
6257 Rename a workspace
6358
@@ -79,33 +74,33 @@ async def rename_workspace(
7974 if not ws :
8075 raise WorkspaceDoesNotExistError (f"Workspace { old_workspace_name } does not exist." )
8176 db_recorder = DbRecorder ()
82- new_ws = WorkspaceRow (
77+ new_ws = db_models . WorkspaceRow (
8378 id = ws .id , name = new_workspace_name , custom_instructions = ws .custom_instructions
8479 )
8580 workspace_renamed = await db_recorder .update_workspace (new_ws )
8681 return workspace_renamed
8782
88- async def get_workspaces (self ) -> List [WorkspaceWithSessionInfo ]:
83+ async def get_workspaces (self ) -> List [db_models . WorkspaceWithSessionInfo ]:
8984 """
9085 Get all workspaces
9186 """
9287 return await self ._db_reader .get_workspaces ()
9388
94- async def get_archived_workspaces (self ) -> List [WorkspaceRow ]:
89+ async def get_archived_workspaces (self ) -> List [db_models . WorkspaceRow ]:
9590 """
9691 Get all archived workspaces
9792 """
9893 return await self ._db_reader .get_archived_workspaces ()
9994
100- async def get_active_workspace (self ) -> Optional [ActiveWorkspace ]:
95+ async def get_active_workspace (self ) -> Optional [db_models . ActiveWorkspace ]:
10196 """
10297 Get the active workspace
10398 """
10499 return await self ._db_reader .get_active_workspace ()
105100
106101 async def _is_workspace_active (
107102 self , workspace_name : str
108- ) -> Tuple [bool , Optional [Session ], Optional [WorkspaceRow ]]:
103+ ) -> Tuple [bool , Optional [db_models . Session ], Optional [db_models . WorkspaceRow ]]:
109104 """
110105 Check if the workspace is active alongside the session and workspace objects
111106 """
@@ -155,13 +150,13 @@ async def recover_workspace(self, workspace_name: str):
155150
156151 async def update_workspace_custom_instructions (
157152 self , workspace_name : str , custom_instr_lst : List [str ]
158- ) -> WorkspaceRow :
153+ ) -> db_models . WorkspaceRow :
159154 selected_workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
160155 if not selected_workspace :
161156 raise WorkspaceDoesNotExistError (f"Workspace { workspace_name } does not exist." )
162157
163158 custom_instructions = " " .join (custom_instr_lst )
164- workspace_update = WorkspaceRow (
159+ workspace_update = db_models . WorkspaceRow (
165160 id = selected_workspace .id ,
166161 name = selected_workspace .name ,
167162 custom_instructions = custom_instructions ,
@@ -217,17 +212,13 @@ async def hard_delete_workspace(self, workspace_name: str):
217212 raise WorkspaceCrudError (f"Error deleting workspace { workspace_name } " )
218213 return
219214
220- async def get_workspace_by_name (self , workspace_name : str ) -> WorkspaceRow :
215+ async def get_workspace_by_name (self , workspace_name : str ) -> db_models . WorkspaceRow :
221216 workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
222217 if not workspace :
223218 raise WorkspaceDoesNotExistError (f"Workspace { workspace_name } does not exist." )
224219 return workspace
225220
226- # Can't use type hints since the models are not yet defined
227- # Note that I'm explicitly importing the models here to avoid circular imports.
228- async def get_muxes (self , workspace_name : str ):
229- from codegate .api import v1_models
230-
221+ async def get_muxes (self , workspace_name : str ) -> List [mux_models .MuxRule ]:
231222 # Verify if workspace exists
232223 workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
233224 if not workspace :
@@ -239,7 +230,7 @@ async def get_muxes(self, workspace_name: str):
239230 # These are already sorted by priority
240231 for dbmux in dbmuxes :
241232 muxes .append (
242- v1_models .MuxRule (
233+ mux_models .MuxRule (
243234 provider_id = dbmux .provider_endpoint_id ,
244235 model = dbmux .provider_model_name ,
245236 matcher_type = dbmux .matcher_type ,
@@ -249,10 +240,7 @@ async def get_muxes(self, workspace_name: str):
249240
250241 return muxes
251242
252- # Can't use type hints since the models are not yet defined
253- async def set_muxes (self , workspace_name : str , muxes ):
254- from codegate .api import v1_models
255-
243+ async def set_muxes (self , workspace_name : str , muxes : mux_models .MuxRule ) -> None :
256244 # Verify if workspace exists
257245 workspace = await self ._db_reader .get_workspace_by_name (workspace_name )
258246 if not workspace :
@@ -265,7 +253,7 @@ async def set_muxes(self, workspace_name: str, muxes):
265253 # Add the new muxes
266254 priority = 0
267255
268- muxes_with_routes : List [Tuple [v1_models .MuxRule , rulematcher .ModelRoute ]] = []
256+ muxes_with_routes : List [Tuple [mux_models .MuxRule , rulematcher .ModelRoute ]] = []
269257
270258 # Verify all models are valid
271259 for mux in muxes :
@@ -275,7 +263,7 @@ async def set_muxes(self, workspace_name: str, muxes):
275263 matchers : List [rulematcher .MuxingRuleMatcher ] = []
276264
277265 for mux , route in muxes_with_routes :
278- new_mux = MuxRule (
266+ new_mux = db_models . MuxRule (
279267 id = str (uuid ()),
280268 provider_endpoint_id = mux .provider_id ,
281269 provider_model_name = mux .model ,
@@ -294,7 +282,7 @@ async def set_muxes(self, workspace_name: str, muxes):
294282 mux_registry = await rulematcher .get_muxing_rules_registry ()
295283 await mux_registry .set_ws_rules (workspace_name , matchers )
296284
297- async def get_routing_for_mux (self , mux ) -> rulematcher .ModelRoute :
285+ async def get_routing_for_mux (self , mux : mux_models . MuxRule ) -> rulematcher .ModelRoute :
298286 """Get the routing for a mux
299287
300288 Note that this particular mux object is the API model, not the database model.
@@ -322,7 +310,7 @@ async def get_routing_for_mux(self, mux) -> rulematcher.ModelRoute:
322310 auth_material = dbauth ,
323311 )
324312
325- async def get_routing_for_db_mux (self , mux : MuxRule ) -> rulematcher .ModelRoute :
313+ async def get_routing_for_db_mux (self , mux : db_models . MuxRule ) -> rulematcher .ModelRoute :
326314 """Get the routing for a mux
327315
328316 Note that this particular mux object is the database model, not the API model.
0 commit comments