Skip to content

Commit c816c47

Browse files
WIP: add typechecking for allowed files
1 parent 34ed4fa commit c816c47

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

sdk/diffgram/models/base_model.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
1+
from fastapi import HTTPException
2+
from typing import List, Literal
13
from .model_interfaces import DiffgramFile, Prediction, Attribute, Instance
24
class DiffgramBaseModel():
3-
def init(self):
4-
pass
5+
diffgram_allowed_types = Literal['image', 'frame', 'video', 'text', 'audio', 'sensor_fusion', 'geospatial']
6+
7+
def __init__(
8+
self,
9+
allowed_types: list = None
10+
):
11+
if allowed_types is not None:
12+
if not isinstance(allowed_types, list):
13+
raise ValueError('allowed_types must be of type list')
14+
15+
for allowed_type in allowed_types:
16+
if allowed_type not in self.diffgram_allowed_types:
17+
raise ValueError(f"{allowed_type} is not valid Diffgram file type")
18+
19+
self.allowed_types = allowed_types
520

621
def infere(self, file: DiffgramFile) -> Prediction:
722
raise NotImplementedError
@@ -14,7 +29,11 @@ def ping(self):
1429

1530
def serve(self, app):
1631
@app.post("/infere")
17-
async def predict(file: DiffgramFile):
32+
async def infere_route(file: DiffgramFile):
33+
if self.allowed_types is not None:
34+
if file.type not in self.allowed_types:
35+
raise HTTPException(status_code=404, detail=f"This model does not support {file.type} files")
36+
1837
predictions = self.infere(file)
1938

2039
if not isinstance(predictions, Prediction):
@@ -38,7 +57,7 @@ async def predict(file: DiffgramFile):
3857
}
3958

4059
@app.get("/get_schema")
41-
async def schema():
60+
async def get_schema_route():
4261
return {
4362
"message": "Get schema here"
4463
}

sdk/diffgram/models/model_interfaces.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from pydantic import BaseModel
2-
from typing import Optional
3-
from typing import List
4-
2+
from typing import Optional, List, Literal
53
class DiffgramFile(BaseModel):
64
id: int
7-
type: str
5+
type: Literal['image', 'frame', 'video', 'text', 'audio', 'sensor_fusion', 'geospatial']
86

97
class Attribute(BaseModel):
108
id: int

sdk/samples/model_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def get_schema(self):
1111
return super().get_schema()
1212

1313
app = FastAPI()
14-
MyTestModel().serve(app)
14+
MyTestModel(allowed_types=["image"]).serve(app)

0 commit comments

Comments
 (0)