Skip to content

Commit 4cdc8aa

Browse files
authored
Merge pull request google#41 from google/force-index
Add force_index condition
2 parents 0133ac0 + 774df9c commit 4cdc8aa

9 files changed

Lines changed: 73 additions & 14 deletions

File tree

spanner_orm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
Timestamp = field.Timestamp
4545

4646
equal_to = condition.equal_to
47+
force_index = condition.force_index
4748
greater_than = condition.greater_than
4849
greater_than_or_equal_to = condition.greater_than_or_equal_to
4950
includes = condition.includes

spanner_orm/admin/metadata.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,12 @@ def indexes(cls):
113113
indexes = collections.defaultdict(dict)
114114
for schema in index_schemas:
115115
key = (schema.table_name, schema.index_name)
116-
indexes[schema.table_name][schema.index_name] = index.Index(
116+
new_index = index.Index(
117117
index_columns[key],
118-
schema.index_name,
119118
parent=schema.parent_table_name,
120119
null_filtered=schema.is_null_filtered,
121120
unique=schema.is_unique,
122121
storing_columns=storing_columns[key])
122+
new_index.name = schema.index_name
123+
indexes[schema.table_name][schema.index_name] = new_index
123124
return indexes

spanner_orm/admin/update.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,10 @@ def validate(self):
289289
if not self._model:
290290
raise error.SpannerError('Table {} does not exist'.format(self._table))
291291

292-
if self._index not in self._model.indexes:
292+
db_index = self._model.indexes.get(self._index)
293+
if not db_index:
293294
raise error.SpannerError('Index {} does not exist'.format(self._index))
294-
if self._index == index.Index.PRIMARY_INDEX:
295+
if db_index.primary_index:
295296
raise error.SpannerError('Index {} is the primary index'.format(
296297
self._index))
297298

spanner_orm/condition.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424

2525
class Segment(enum.Enum):
2626
"""The segment of the SQL query that a Condition belongs to."""
27-
WHERE = 1
28-
ORDER_BY = 2
29-
LIMIT = 3
30-
JOIN = 4
27+
FROM = 1
28+
JOIN = 2
29+
WHERE = 3
30+
ORDER_BY = 4
31+
LIMIT = 5
3132

3233

3334
class Condition(abc.ABC):
@@ -118,6 +119,35 @@ def _validate(self, model):
118119
assert (origin.field_type() == dest.field_type() and
119120
origin.nullable() == dest.nullable())
120121

122+
class ForceIndexCondition(Condition):
123+
"""Used to indicate which index should be used in a Spanner query."""
124+
125+
def __init__(self, name):
126+
super().__init__()
127+
self.name = name
128+
self.index = None
129+
130+
def bind(self, model):
131+
super().bind(model)
132+
self.index = self.model.indexes[self.name]
133+
134+
def _params(self):
135+
return {}
136+
137+
@staticmethod
138+
def segment():
139+
return Segment.FROM
140+
141+
def _sql(self):
142+
return '@{{FORCE_INDEX={}}}'.format(self.index.name)
143+
144+
def _types(self):
145+
return {}
146+
147+
def _validate(self, model):
148+
assert self.name in model.indexes
149+
assert not model.indexes[self.name].primary
150+
121151

122152
class IncludesCondition(Condition):
123153
"""Used to include related models via a relation in a Spanner query."""
@@ -439,6 +469,10 @@ def equal_to(column, value):
439469
return EqualityCondition(column, value)
440470

441471

472+
def force_index(index):
473+
return ForceIndexCondition(index)
474+
475+
442476
def greater_than(column, value):
443477
return ComparisonCondition('>', column, value)
444478

spanner_orm/index.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ class Index(object):
2020

2121
def __init__(self,
2222
columns,
23-
name,
2423
parent=None,
2524
null_filtered=False,
2625
unique=False,
2726
storing_columns=None):
2827
assert len(columns) > 0, 'An index must have at least one column'
2928
self.columns = columns
30-
self.name = name
29+
self.name = None
3130
self.parent = parent
3231
self.null_filtered = null_filtered
3332
self.unique = unique
3433
self.storing_columns = storing_columns or []
34+
35+
@property
36+
def primary(self):
37+
return self.name == self.PRIMARY_INDEX

spanner_orm/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def finalize(self):
6060

6161
if index.Index.PRIMARY_INDEX not in self.indexes:
6262
primary_keys = [f.name for f in sorted_fields if f.primary_key()]
63-
primary_index = index.Index(primary_keys, index.Index.PRIMARY_INDEX)
64-
self.indexes[primary_index.name] = primary_index
63+
primary_index = index.Index(primary_keys)
64+
primary_index.name = index.Index.PRIMARY_INDEX
65+
self.indexes[index.Index.PRIMARY_INDEX] = primary_index
6566
self.primary_keys = self.indexes[index.Index.PRIMARY_INDEX].columns
6667

6768
self.columns = [f.name for f in sorted_fields]
@@ -87,7 +88,7 @@ def add_relation(self, name, new_relation):
8788

8889
def add_index(self, name, new_index):
8990
new_index.name = name
90-
self.relations[name] = new_index
91+
self.indexes[name] = new_index
9192

9293

9394
class ModelBase(type):

spanner_orm/query.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,17 @@ def _select(self):
7777

7878
def _from(self):
7979
"""Processes the FROM segment of the SQL query."""
80-
return (' FROM {}'.format(self._model.table), {}, {})
80+
froms = self._segments(condition.Segment.FROM)
81+
index_sql = ''
82+
if froms:
83+
if len(froms) != 1:
84+
raise error.SpannerError('Only one index can be specified')
85+
force_index = froms[0]
86+
index_sql = force_index.sql()
87+
88+
sql = ' FROM {}{}'.format(self._model.table, index_sql)
89+
90+
return (sql, {}, {})
8191

8292
def _where(self):
8393
"""Processes the WHERE segment of the SQL query."""

spanner_orm/tests/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Models used by unit tests."""
1616

1717
from spanner_orm import field
18+
from spanner_orm import index
1819
from spanner_orm import model
1920
from spanner_orm import relationship
2021

@@ -57,3 +58,5 @@ class UnittestModel(model.Model):
5758
string_2 = field.Field(field.String, nullable=True)
5859
timestamp = field.Field(field.Timestamp)
5960
string_array = field.Field(field.StringArray, nullable=True)
61+
62+
test_index = index.Index(['string_2'])

spanner_orm/tests/query_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def test_only_one_limit_allowed(self):
149149
with self.assertRaises(error.SpannerError):
150150
self.select(condition.limit(2), condition.limit(2))
151151

152+
def test_force_index(self):
153+
select_query = self.select(condition.force_index('test_index'))
154+
expected_sql = 'FROM table@{FORCE_INDEX=test_index}'
155+
self.assertEndsWith(select_query.sql(), expected_sql)
156+
152157
def includes(self, relation, *conditions):
153158
include_condition = condition.includes(relation, list(conditions))
154159
return query.SelectQuery(models.ChildTestModel, [include_condition])

0 commit comments

Comments
 (0)