Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Update the muxing rules to v3
Closes: #1060

Right now the muxing rules can only match FIM or Chat requests
globally. This PR extends them so that a rule can match on both the
file and the request type, i.e. it enables:
- Chat request of main.py -> model 1
- FIM request of main.py -> model 2
- Any type of v1.py -> model 3
  • Loading branch information
aponcedeleonch committed Feb 21, 2025
commit 161737acbff37c9004446b3c094a75feab7b9d3c
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""update matcher types

Revision ID: 5e5cd2288147
Revises: 0c3539f66339
Create Date: 2025-02-19 14:52:39.126196+00:00

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "5e5cd2288147"
down_revision: Union[str, None] = "0c3539f66339"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """
    Rewrite the matcher types stored in the `muxes` table.

    - `request_type_match` rows whose blob is 'fim' become type 'fim' with an empty blob.
    - `request_type_match` rows whose blob is 'chat' become type 'chat' with an empty blob.
    - `filename_match` rows with a non-empty blob become type 'catch_all'; their blob is
      not modified by that statement.
    """
    # Begin transaction
    op.execute("BEGIN TRANSACTION;")

    # Update the matcher types. We need to do this every time we change the matcher types
    # in /muxing/models.py.
    op.execute(
        """
        UPDATE muxes
        SET matcher_type = 'fim', matcher_blob = ''
        WHERE matcher_type = 'request_type_match' AND matcher_blob = 'fim';
        """
    )
    op.execute(
        """
        UPDATE muxes
        SET matcher_type = 'chat', matcher_blob = ''
        WHERE matcher_type = 'request_type_match' AND matcher_blob = 'chat';
        """
    )
    # NOTE(review): `filename_match` rows whose matcher_blob is empty (or NULL) are not
    # selected by this statement and keep their old type - confirm such rows cannot exist.
    op.execute(
        """
        UPDATE muxes
        SET matcher_type = 'catch_all'
        WHERE matcher_type = 'filename_match' AND matcher_blob != '';
        """
    )

    # Finish transaction
    op.execute("COMMIT;")


def downgrade() -> None:
    """
    Restore the matcher types used before this revision in the `muxes` table.

    - 'fim' rows become `request_type_match` rows with blob 'fim'.
    - 'chat' rows become `request_type_match` rows with blob 'chat'.
    - 'catch_all' rows that carry a non-empty blob become `filename_match` rows and keep
      that blob. 'catch_all' rows without a blob are left as they are.

    The first two statements overwrite matcher_blob, so whatever value a 'fim' or 'chat'
    row held there is replaced by the request type.
    """
    # Begin transaction
    op.execute("BEGIN TRANSACTION;")

    op.execute(
        """
        UPDATE muxes
        SET matcher_blob = 'fim', matcher_type = 'request_type_match'
        WHERE matcher_type = 'fim';
        """
    )
    op.execute(
        """
        UPDATE muxes
        SET matcher_blob = 'chat', matcher_type = 'request_type_match'
        WHERE matcher_type = 'chat';
        """
    )
    # Exact inverse of the upgrade step: only rows that carry a matcher blob are turned
    # back into filename_match, and the blob is preserved. Overwriting the blob with the
    # literal 'catch_all' would discard the stored value and would also convert rules
    # that were catch_all before the upgrade.
    op.execute(
        """
        UPDATE muxes
        SET matcher_type = 'filename_match'
        WHERE matcher_type = 'catch_all' AND matcher_blob != '';
        """
    )

    # Finish transaction
    op.execute("COMMIT;")
35 changes: 28 additions & 7 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
from enum import Enum
from typing import Optional
from typing import Optional, Self

import pydantic

from codegate.clients.clients import ClientType
from codegate.db.models import MuxRule as DBMuxRule


class MuxMatcherType(str, Enum):
    """
    Represents the different types of matchers we support.

    The catch_all, fim and chat rules match filenames and request types. They're used in
    conjunction with the matcher field in the MuxRule model.
    E.g.
    - catch_all and match: None -> Always match
    - fim and match: requests.py -> Match the request if the filename is requests.py and FIM
    - chat and match: None -> Match the request if it's a chat request
    - chat and match: .js -> Match the request if the filename has a .js extension and is chat

    NOTE: Removing or updating fields from this enum will require a migration.
    """

    # Always match this prompt
    catch_all = "catch_all"
    # Match based on the filename. It will match if there is a filename
    # in the request that matches the matcher either extension or full name (*.py or main.py)
    # NOTE(review): looks like a legacy value superseded by catch_all/fim/chat combined
    # with a matcher - confirm whether this member should still be present.
    filename_match = "filename_match"
    # Match based on the request type. It will match if the request type
    # matches the matcher (e.g. FIM or chat)
    # NOTE(review): looks like a legacy value superseded by the fim and chat members -
    # confirm whether this member should still be present.
    request_type_match = "request_type_match"
    # Match based on fim request type. It will match if the request type is fim
    fim = "fim"
    # Match based on chat request type. It will match if the request type is chat
    chat = "chat"


class MuxRule(pydantic.BaseModel):
Expand All @@ -36,6 +45,18 @@ class MuxRule(pydantic.BaseModel):
# this depends on the matcher type.
matcher: Optional[str] = None

@classmethod
def from_db_mux_rule(cls, db_mux_rule: DBMuxRule) -> Self:
"""
Convert a DBMuxRule to a MuxRule.
"""
return MuxRule(
provider_id=db_mux_rule.id,
model=db_mux_rule.provider_model_name,
matcher_type=db_mux_rule.matcher_type,
matcher=db_mux_rule.matcher_blob,
)


class ThingToMatchMux(pydantic.BaseModel):
"""
Expand Down
5 changes: 4 additions & 1 deletion src/codegate/muxing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ async def _get_model_route(
# Try to get a model route for the active workspace
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
return model_route
except rulematcher.MuxMatchingError as e:
logger.exception(f"Error matching rule and getting model route: {e}")
raise HTTPException(detail=str(e), status_code=404)
except Exception as e:
logger.error(f"Error getting active workspace muxes: {e}")
logger.exception(f"Error getting active workspace muxes: {e}")
raise HTTPException(detail=str(e), status_code=404)

def _setup_routes(self):
Expand Down
94 changes: 48 additions & 46 deletions src/codegate/muxing/rulematcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
_singleton_lock = Lock()


class MuxMatchingError(Exception):
    """Signals a failure while matching a request against the muxing rules."""


async def get_muxing_rules_registry():
"""Returns a singleton instance of the muxing rules registry."""

Expand Down Expand Up @@ -48,9 +54,9 @@ def __init__(
class MuxingRuleMatcher(ABC):
"""Base class for matching muxing rules."""

def __init__(self, route: ModelRoute, matcher_blob: str):
def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule):
self._route = route
self._matcher_blob = matcher_blob
self._mux_rule = mux_rule

@abstractmethod
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
Expand All @@ -67,32 +73,24 @@ class MuxingMatcherFactory:
"""Factory for creating muxing matchers."""

@staticmethod
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
"""Create a muxing matcher for the given endpoint and model."""

factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
mux_models.MuxMatcherType.catch_all: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.fim: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.chat: RequestTypeAndFileMuxingRuleMatcher,
Comment thread
aponcedeleonch marked this conversation as resolved.
Outdated
}

try:
# Initialize the MuxingRuleMatcher
return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
mux_rule = mux_models.MuxRule.from_db_mux_rule(db_mux_rule)
return factory[mux_rule.matcher_type](route, mux_rule)
except KeyError:
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")


class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
logger.info("Catch all rule matched")
return True


class FileMuxingRuleMatcher(MuxingRuleMatcher):
"""A file muxing rule matcher."""
class RequestTypeAndFileMuxingRuleMatcher(MuxingRuleMatcher):

def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
"""
Expand All @@ -103,47 +101,51 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
return body_extractor.extract_unique_filenames(data)
except BodyCodeSnippetExtractorError as e:
logger.error(f"Error extracting filenames from request: {e}")
return set()
raise MuxMatchingError("Error extracting filenames from request")

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> bool:
"""
Retun True if there is a filename in the request that matches the matcher_blob.
The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
Check if the matcher is in the request filenames.
"""
# If there is no matcher_blob, we don't match
if not self._matcher_blob:
return False
filenames_to_match = self._extract_request_filenames(
thing_to_match.client_type, thing_to_match.body
# Empty matcher_blob means we match everything
if not self._mux_rule.matcher:
return True
filenames_to_match = self._extract_request_filenames(detected_client, data)
# _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
# match the rule.
is_filename_match = any(
self._mux_rule.matcher == filename or filename.endswith(self._mux_rule.matcher)
for filename in filenames_to_match
)
is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match)
if is_filename_match:
logger.info(
"Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob
)
return is_filename_match


class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""
def _is_request_type_match(self, is_fim_request: bool) -> bool:
"""
Check if the request type matches the MuxMatcherType.
"""
# Catch all rule matches both chat and FIM requests
if self._mux_rule.matcher_type == mux_models.MuxMatcherType.catch_all:
return True
incoming_request_type = "fim" if is_fim_request else "chat"
if incoming_request_type == self._mux_rule.matcher_type:
return True
return False

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the request type matches the matcher_blob.
The matcher_blob is either "fim" or "chat".
Return True if the matcher is in one of the request filenames and
if the request type matches the MuxMatcherType.
"""
# If there is no matcher_blob, we don't match
if not self._matcher_blob:
return False
incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
is_request_type_match = self._matcher_blob == incoming_request_type
if is_request_type_match:
is_rule_matched = self._is_matcher_in_filenames(
thing_to_match.client_type, thing_to_match.body
) and self._is_request_type_match(thing_to_match.is_fim_request)
if is_rule_matched:
logger.info(
"Request type rule matched",
matcher=self._matcher_blob,
request_type=incoming_request_type,
"Request type and rule matched",
matcher=self._mux_rule.matcher,
is_fim_request=thing_to_match.is_fim_request,
)
return is_request_type_match
return is_rule_matched


class MuxingRulesinWorkspaces:
Expand Down
Loading