-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathentities.py
More file actions
106 lines (91 loc) · 3.26 KB
/
entities.py
File metadata and controls
106 lines (91 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import logging
from typing import Optional

from fastapi import APIRouter, Depends, Query

from feast.api.registry.rest.rest_utils import (
    aggregate_across_projects,
    create_grpc_pagination_params,
    create_grpc_sorting_params,
    get_object_relationships,
    get_pagination_params,
    get_relationships_for_objects,
    get_sorting_params,
    grpc_call,
)
from feast.protos.feast.registry import RegistryServer_pb2
# Module-level logger named after this module.
# NOTE(review): not referenced by any handler in this module as shown;
# presumably kept for future diagnostics.
logger = logging.getLogger(__name__)
def get_entity_router(grpc_handler) -> APIRouter:
    """Build the FastAPI router that exposes registry entities over REST.

    Three GET endpoints are registered, each forwarding to a method of
    ``grpc_handler`` through ``grpc_call``:

    * ``/entities``        -- list the entities of one project, with
      pagination/sorting taken from the shared query-param dependencies.
    * ``/entities/all``    -- list entities across projects, delegated to
      ``aggregate_across_projects``.
    * ``/entities/{name}`` -- fetch a single entity by name.

    Each endpoint can optionally attach relationship data when the
    ``include_relationships`` query parameter is true.

    Args:
        grpc_handler: Registry gRPC servicer whose ``ListEntities`` and
            ``GetEntity`` methods are called directly.
            NOTE(review): untyped here -- presumably the registry server
            servicer; confirm against the caller.

    Returns:
        An ``APIRouter`` with the entity routes registered.
    """
    router = APIRouter()

    # List entities of a single project, optionally with relationships.
    @router.get("/entities")
    def list_entities(
        project: str = Query(...),
        allow_cache: bool = Query(default=True),
        include_relationships: bool = Query(
            False, description="Include relationships for each entity"
        ),
        pagination_params: dict = Depends(get_pagination_params),
        sorting_params: dict = Depends(get_sorting_params),
    ):
        req = RegistryServer_pb2.ListEntitiesRequest(
            project=project,
            allow_cache=allow_cache,
            pagination=create_grpc_pagination_params(pagination_params),
            sorting=create_grpc_sorting_params(sorting_params),
        )
        # grpc_call's result is used as a dict (``.get`` lookups below).
        response = grpc_call(grpc_handler.ListEntities, req)
        entities = response.get("entities", [])
        result = {
            "entities": entities,
            "pagination": response.get("pagination", {}),
        }
        if include_relationships:
            result["relationships"] = get_relationships_for_objects(
                grpc_handler, entities, "entity", project, allow_cache
            )
        return result

    # List entities across projects. This route MUST be registered before
    # "/entities/{name}": routes are matched in registration order, so
    # registering it later would make "all" be captured as an entity name.
    @router.get("/entities/all")
    def list_all_entities(
        allow_cache: bool = Query(default=True),
        page: int = Query(1, ge=1),
        limit: int = Query(50, ge=1, le=100),
        # Optional: the value is None when the client omits ``sort_by``.
        sort_by: Optional[str] = Query(None),
        sort_order: str = Query("asc"),
        include_relationships: bool = Query(
            False, description="Include relationships for each entity"
        ),
    ):
        return aggregate_across_projects(
            grpc_handler=grpc_handler,
            list_method=grpc_handler.ListEntities,
            request_cls=RegistryServer_pb2.ListEntitiesRequest,
            response_key="entities",
            object_type="entity",
            allow_cache=allow_cache,
            page=page,
            limit=limit,
            sort_by=sort_by,
            sort_order=sort_order,
            include_relationships=include_relationships,
        )

    # Fetch a single entity by name, optionally with its relationships.
    @router.get("/entities/{name}")
    def get_entity(
        name: str,
        project: str = Query(...),
        include_relationships: bool = Query(
            False, description="Include relationships for this entity"
        ),
        allow_cache: bool = Query(default=True),
    ):
        req = RegistryServer_pb2.GetEntityRequest(
            name=name,
            project=project,
            allow_cache=allow_cache,
        )
        # NOTE(review): assumes grpc_call returns a fresh mutable dict per
        # call, since relationships are added to it in place.
        result = grpc_call(grpc_handler.GetEntity, req)
        if include_relationships:
            result["relationships"] = get_object_relationships(
                grpc_handler, "entity", name, project, allow_cache
            )
        return result

    return router