Skip to content

Commit 01de0ff

Browse files
Add typing to the model wrapper
1 parent a92ebed commit 01de0ff

5 files changed

Lines changed: 39 additions & 10 deletions

File tree

sdk/diffgram.egg-info/SOURCES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ diffgram/label/test_label.py
3636
diffgram/member/__init__.py
3737
diffgram/models/__init__.py
3838
diffgram/models/base_model.py
39+
diffgram/models/model_interfaces.py
3940
diffgram/pytorch_diffgram/__init__.py
4041
diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py
4142
diffgram/regular/__init__.py

sdk/diffgram/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
from diffgram.core.core import Project
66
from diffgram.file.file import File
77
from diffgram.task.task import Task
8-
from diffgram.models.base_model import DiffgramBaseModel
8+
from diffgram.models.base_model import DiffgramBaseModel, Instance

sdk/diffgram/models/base_model.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from fastapi import FastAPI
2-
1+
from .model_interfaces import DiffgramFile, Instance
2+
from typing import List
33
class DiffgramBaseModel():
44
def init(self):
55
pass
66

7-
def infere(self):
7+
def infere(self, file: DiffgramFile) -> List[Instance]:
88
raise NotImplementedError
99

1010
def get_schema(self):
@@ -14,10 +14,21 @@ def ping(self):
1414
pass
1515

1616
def serve(self, app):
17-
@app.get("/infere")
18-
async def predict():
17+
@app.post("/infere")
18+
async def predict(file: DiffgramFile):
19+
predictions = self.infere(file)
20+
21+
if not isinstance(predictions, List):
22+
raise ValueError('infere should return List of type Instance')
23+
24+
for prediction in predictions:
25+
res = isinstance(prediction, Instance)
26+
if not res:
27+
raise ValueError('infere should return List of type Instance')
28+
1929
return {
20-
"message": "Infere route"
30+
"file": file,
31+
"predictions": predictions
2132
}
2233

2334
@app.get("/get_schema")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pydantic import BaseModel
2+
from typing import List
3+
4+
class DiffgramFile(BaseModel):
5+
id: int
6+
type: str
7+
8+
class Instance(BaseModel):
9+
id: int

sdk/samples/model_server.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1-
from diffgram import DiffgramBaseModel
1+
from diffgram import DiffgramBaseModel, Instance
22
from fastapi import FastAPI
33

4-
app = FastAPI()
4+
class MyTestModel(DiffgramBaseModel):
5+
def infere(self, file):
6+
return [
7+
Instance(id=1)
8+
]
9+
10+
def get_schema(self):
11+
return super().get_schema()
512

6-
DiffgramBaseModel().serve(app)
13+
app = FastAPI()
14+
MyTestModel().serve(app)

0 commit comments

Comments
 (0)