forked from Netflix/dispatch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
191 lines (143 loc) · 5.91 KB
/
Copy pathdatabase.py
File metadata and controls
191 lines (143 loc) · 5.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import re
from typing import Any, List
from itertools import groupby
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import Query, sessionmaker
from sqlalchemy_filters import apply_pagination, apply_sort, apply_filters
from sqlalchemy_searchable import make_searchable
from sqlalchemy_searchable import search as search_db
from starlette.requests import Request
from dispatch.common.utils.composite_search import CompositeSearch
from .config import SQLALCHEMY_DATABASE_URI
# Module-level engine and session factory, created once when this module is
# imported. str() is applied to the configured URI, presumably because the
# config value may be a URL/secret object rather than a plain str — confirm.
engine = create_engine(str(SQLALCHEMY_DATABASE_URI))
# Factory for ORM sessions bound to the engine above.
SessionLocal = sessionmaker(bind=engine)
def resolve_table_name(name):
    """Convert a CamelCase model name into its snake_case table name.

    Every ASCII capital letter starts a new word, and the words are lowered
    and joined with underscores, e.g. ``"IncidentType"`` -> ``"incident_type"``.
    A name that is already snake_case is returned unchanged.
    """
    # Zero-width lookahead split: the capitals stay at the start of each piece.
    pieces = re.split("(?=[A-Z])", name)  # noqa
    # Splitting before a leading capital yields an empty first piece; skip it.
    return "_".join(piece.lower() for piece in pieces if piece)
class CustomBase:
    """Parent class handed to ``declarative_base`` for every model.

    Derives each model's ``__tablename__`` from its class name with
    ``resolve_table_name`` (``IncidentType`` -> ``incident_type``), unless the
    model defines its own ``__tablename__``.
    """

    @declared_attr
    def __tablename__(cls):
        # declared_attr invokes this with the model class, not an instance.
        return resolve_table_name(cls.__name__)
# Shared declarative base for all models; CustomBase supplies __tablename__.
Base = declarative_base(cls=CustomBase)
# Configure sqlalchemy_searchable on the shared metadata (its search helper is
# used by search() in this module).
make_searchable(Base.metadata)
def get_db(request: Request):
    """Return the database session stored on the incoming request.

    The session is read from ``request.state.db``; this function does not
    create or close it (presumably a middleware elsewhere does — confirm).
    """
    state = request.state
    return state.db
def get_model_name_by_tablename(table_fullname: str) -> str:
    """Return the class name of the model mapped to ``table_fullname``.

    Propagates the exception raised by ``get_class_by_tablename`` when no
    mapped model matches.
    """
    model_cls = get_class_by_tablename(table_fullname=table_fullname)
    return model_cls.__name__
def get_class_by_tablename(table_fullname: str) -> Any:
    """Return the mapped model class whose table matches ``table_fullname``.

    The argument is normalised with ``resolve_table_name`` first, so both a
    CamelCase model name (``"IncidentType"``) and a snake_case table name
    (``"incident_type"``) resolve to the same class. The normalised name is
    compared against each registered class's ``__table__.fullname``.

    Raises:
        Exception: if no mapped class has a matching table.
    """
    mapped_name = resolve_table_name(table_fullname)
    for candidate in _mapped_registry_entries():
        # The registry may also hold non-model entries, so only consider
        # entries that actually carry a ``__table__``.
        if hasattr(candidate, "__table__") and candidate.__table__.fullname == mapped_name:
            return candidate
    raise Exception(f"Incorrect tablename '{mapped_name}'. Check the name of your model.")


def _mapped_registry_entries():
    """Return the values of ``Base``'s declarative class registry."""
    # SQLAlchemy < 1.4 keeps the registry on the base class as the private
    # ``_decl_class_registry``; 1.4 removed it in favour of
    # ``Base.registry._class_registry``. Prefer the legacy attribute so
    # behaviour on older versions is unchanged.
    legacy_registry = getattr(Base, "_decl_class_registry", None)
    if legacy_registry is not None:
        return legacy_registry.values()
    return Base.registry._class_registry.values()
def paginate(query: Query, page: int, items_per_page: int):
    """Fetch one page of ``query`` along with the total row count.

    Args:
        query: query to page through.
        page: 1-based page number; values <= 0 are treated as page 1.
        items_per_page: maximum number of rows on the page.

    Returns:
        A ``(items, total)`` tuple: the page's rows and the row count of
        ``query`` with its ordering removed.
    """
    # Clamp to the first page so a negative OFFSET is never sent to SQL.
    if page <= 0:
        page_index = 0
    else:
        page_index = page - 1

    page_query = query.limit(items_per_page).offset(page_index * items_per_page)
    items = page_query.all()
    # Ordering does not affect the count, so drop it for the COUNT query.
    total = query.order_by(None).count()
    return items, total
def composite_search(*, db_session, query_str: str, models: List[Base]):
    """Search several models in one pass using ``CompositeSearch``.

    Args:
        db_session: session handed to ``CompositeSearch``.
        query_str: raw search text.
        models: model classes to search across.

    Returns:
        Whatever ``CompositeSearch.search`` returns for the built query.
    """
    searcher = CompositeSearch(db_session, models)
    combined_query = searcher.build_query(query_str, sort=True)
    return searcher.search(query=combined_query)
def search(*, db_session, query_str: str, model: str):
    """Full-text search the rows of a single model.

    Args:
        db_session: session used to build the base query.
        query_str: raw search text.
        model: model class name or table name, resolved with
            ``get_class_by_tablename``.

    Returns:
        The query produced by sqlalchemy_searchable's ``search`` with
        ``sort=True`` (relevance ordering, per that library).
    """
    model_cls = get_class_by_tablename(model)
    base_query = db_session.query(model_cls)
    return search_db(base_query, query_str, sort=True)
def create_filter_spec(model, fields, ops, values):
    """Build a ``sqlalchemy_filters`` filter spec from parallel lists.

    ``fields``, ``ops`` and ``values`` are zipped together, so the shortest
    list wins. A field of the form ``"table.column"`` targets a related model,
    whose class name is looked up with ``get_model_name_by_tablename``. Any
    other field targets ``model``.

    Filters are grouped by model. Filters on the same model are combined with
    ``or``, and the per-model groups are combined with ``and``.

    Returns:
        ``{"and": [{"or": [...]}, ...]}``, or an empty list when there is
        nothing to filter on (any of the three lists empty or ``None``).
    """
    clauses = []
    if fields and ops and values:
        for field, op, value in zip(fields, ops, values):
            if "." in field:
                # "table.column" points at a related model; a possible join
                # is handled separately by join_required_attrs.
                table_name, column = field.split(".")
                target = get_model_name_by_tablename(table_name)
            else:
                target, column = model, field
            clauses.append({"model": target, "field": column, "op": op, "value": value})

    # groupby only merges adjacent items, so order by the grouping key first.
    ordered = sorted(clauses, key=lambda clause: clause["model"])
    groups = [
        {"or": list(members)}
        for _, members in groupby(ordered, key=lambda clause: clause["model"])
    ]
    return {"and": groups} if groups else groups
def create_sort_spec(model, sort_by, descending):
    """Build a ``sqlalchemy_filters`` sort spec from parallel lists.

    Args:
        model: model name used for plain (undotted) fields.
        sort_by: field names. A ``"table.column"`` entry targets a related
            model, whose class name is resolved with
            ``get_model_name_by_tablename``.
        descending: per-field flags aligned with ``sort_by``. Missing flags
            (``None``, or a list shorter than ``sort_by``) default to
            ascending. Extra flags beyond ``len(sort_by)`` are ignored.

    Returns:
        A list of ``{"model", "field", "direction"}`` dicts. The list is empty
        when ``sort_by`` is empty or ``None``.
    """
    sort_spec = []
    if not sort_by:
        return sort_spec

    fields = list(sort_by)
    # Pad the direction flags with False so every requested field is sorted,
    # ascending unless explicitly flagged; zip would otherwise silently drop
    # fields that have no matching flag.
    flags = list(descending or [])
    flags.extend([False] * (len(fields) - len(flags)))

    for field, is_descending in zip(fields, flags):
        direction = "desc" if is_descending else "asc"
        if "." in field:
            # "table.column" points at a related model.
            complex_model, complex_field = field.split(".")
            sort_spec.append(
                {
                    "model": get_model_name_by_tablename(complex_model),
                    "field": complex_field,
                    "direction": direction,
                }
            )
        else:
            sort_spec.append({"model": model, "field": field, "direction": direction})
    return sort_spec
def get_all(*, db_session, model):
    """Return an unfiltered query over the model mapped to ``model``.

    ``model`` is a model class name or table name, resolved with
    ``get_class_by_tablename``. The query is returned without being executed.
    """
    model_cls = get_class_by_tablename(model)
    return db_session.query(model_cls)
def join_required_attrs(query, model, join_attrs, fields):
    """Join the relationships needed by the requested filter fields.

    Args:
        query: query to extend.
        model: mapped class whose relationship attributes are joined.
        join_attrs: iterable of ``(field, attr)`` pairs. ``attr`` names the
            attribute on ``model`` to join when any requested field contains
            ``field`` as a substring.
        fields: requested filter field names (e.g. ``"tag.name"``).

    Returns:
        ``query`` with at most one ``join`` per matching ``(field, attr)``
        pair. ``query`` is returned unchanged when either list is empty or
        ``None``.
    """
    if not fields or not join_attrs:
        return query

    for field, attr in join_attrs:
        # Join each relationship once even if several requested fields match
        # it (e.g. "tag.name" and "tag.id"). Repeating an un-aliased join
        # emits a duplicate JOIN of the same table, which databases
        # typically reject.
        if any(field in f for f in fields):
            query = query.join(getattr(model, attr))
    return query
def search_filter_sort_paginate(
    db_session,
    model,
    query_str: str = None,
    page: int = 1,
    items_per_page: int = 5,
    sort_by: List[str] = None,
    descending: List[bool] = None,
    fields: List[str] = None,
    ops: List[str] = None,
    values: List[str] = None,
    join_attrs: List[str] = None,
):
    """Search, filter, sort and paginate the rows of a single model.

    Args:
        db_session: session used to build the query.
        model: model class name or table name, resolved with
            ``get_class_by_tablename``.
        query_str: optional full-text search text. When empty or ``None``, the
            starting query covers all rows of the model.
        page: page number passed to ``apply_pagination``.
        items_per_page: page size. ``-1`` is translated to ``page_size=None``
            (presumably "no limit" in sqlalchemy_filters; confirm).
        sort_by, descending: parallel lists forwarded to ``create_sort_spec``.
        fields, ops, values: parallel lists forwarded to
            ``create_filter_spec``.
        join_attrs: forwarded to ``join_required_attrs``, which unpacks each
            entry as a ``(field, attr)`` pair despite the annotation.

    Returns:
        A dict with ``"items"``, ``"itemsPerPage"``, ``"page"`` and
        ``"total"`` keys.
    """
    # Resolve the class up front so an unknown model fails before any query
    # work is done.
    model_cls = get_class_by_tablename(model)

    query = (
        search(db_session=db_session, query_str=query_str, model=model)
        if query_str
        else db_session.query(model_cls)
    )
    query = join_required_attrs(query, model_cls, join_attrs, fields)
    query = apply_filters(query, create_filter_spec(model, fields, ops, values))
    query = apply_sort(query, create_sort_spec(model, sort_by, descending))

    page_size = None if items_per_page == -1 else items_per_page
    query, pagination = apply_pagination(query, page_number=page, page_size=page_size)

    return {
        "items": query.all(),
        "itemsPerPage": pagination.page_size,
        "page": pagination.page_number,
        "total": pagination.total_results,
    }