1+ from fastapi import HTTPException
2+ from typing import List , Literal
13from .model_interfaces import DiffgramFile , Prediction , Attribute , Instance
24class DiffgramBaseModel ():
3- def init (self ):
4- pass
5+ diffgram_allowed_types = Literal ['image' , 'frame' , 'video' , 'text' , 'audio' , 'sensor_fusion' , 'geospatial' ]
6+
7+ def __init__ (
8+ self ,
9+ allowed_types : list = None
10+ ):
11+ if allowed_types is not None :
12+ if not isinstance (allowed_types , list ):
13+ raise ValueError ('allowed_types must be of type list' )
14+
15+ for allowed_type in allowed_types :
16+ if allowed_type not in self .diffgram_allowed_types :
17+ raise ValueError (f"{ allowed_type } is not valid Diffgram file type" )
18+
19+ self .allowed_types = allowed_types
520
621 def infere (self , file : DiffgramFile ) -> Prediction :
722 raise NotImplementedError
@@ -14,7 +29,11 @@ def ping(self):
1429
1530 def serve (self , app ):
1631 @app .post ("/infere" )
17- async def predict (file : DiffgramFile ):
32+ async def infere_route (file : DiffgramFile ):
33+ if self .allowed_types is not None :
34+ if file .type not in self .allowed_types :
35+ raise HTTPException (status_code = 404 , detail = f"This model does not support { file .type } files" )
36+
1837 predictions = self .infere (file )
1938
2039 if not isinstance (predictions , Prediction ):
@@ -38,7 +57,7 @@ async def predict(file: DiffgramFile):
3857 }
3958
4059 @app .get ("/get_schema" )
41- async def schema ():
60+ async def get_schema_route ():
4261 return {
4362 "message" : "Get schema here"
4463 }
0 commit comments