Skip to content

Commit 49387d2

Browse files
committed
django: fix pattern lookups with bilateral transform
fixes #419
1 parent fccdacc commit 49387d2

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

packages/django-google-spanner/django_spanner/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
157157
'model_fields.test_decimalfield.DecimalFieldTests.test_roundtrip_with_trailing_zeros',
158158
# No CHECK constraints in Spanner.
159159
'model_fields.test_integerfield.PositiveIntegerFieldTests.test_negative_values',
160-
# contains lookup crashes with bilateral transform:
161-
# https://github.com/googleapis/python-spanner-django/issues/419
162-
'custom_lookups.tests.BilateralTransformTests.test_bilateral_upper',
163160
# Spanner doesn't support the variance the standard deviation database
164161
# functions:
165162
'aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_numerical_aggregates',

packages/django-google-spanner/django_spanner/lookups.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def contains(self, compiler, connection):
1717
lhs_sql, params = self.process_lhs(compiler, connection)
1818
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
1919
params.extend(rhs_params)
20-
rhs_sql = self.get_rhs_op(connection, rhs_sql)
2120
is_icontains = self.lookup_name.startswith('i')
2221
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
22+
rhs_sql = self.get_rhs_op(connection, rhs_sql)
2323
# Chop the leading and trailing percent signs that Django adds to the
2424
# param since this isn't a LIKE query as Django expects.
2525
params[0] = params[0][1:-1]
@@ -29,12 +29,11 @@ def contains(self, compiler, connection):
2929
# rhs_sql is REGEXP_CONTAINS(%s, %%s), and lhs_sql is the column name.
3030
return rhs_sql % lhs_sql, params
3131
else:
32-
# rhs is the expression/column to use as the base of the regular
32+
# rhs_sql is the expression/column to use as the base of the regular
3333
# expression.
34-
rhs = compiler.compile(self.rhs)[0]
3534
if is_icontains:
36-
rhs = "CONCAT('(?i)', " + rhs + ")"
37-
return 'REGEXP_CONTAINS(%s, %s)' % (lhs_sql, connection.pattern_esc.format(rhs)), params
35+
rhs_sql = "CONCAT('(?i)', " + rhs_sql + ")"
36+
return 'REGEXP_CONTAINS(%s, %s)' % (lhs_sql, connection.pattern_esc.format(rhs_sql)), params
3837

3938

4039
def iexact(self, compiler, connection):
@@ -59,36 +58,35 @@ def regex(self, compiler, connection):
5958
lhs_sql, params = self.process_lhs(compiler, connection)
6059
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
6160
params.extend(rhs_params)
62-
rhs_sql = self.get_rhs_op(connection, rhs_sql)
6361
is_iregex = self.lookup_name.startswith('i')
6462
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
63+
rhs_sql = self.get_rhs_op(connection, rhs_sql)
6564
if is_iregex:
6665
params[0] = '(?i)%s' % params[0]
6766
else:
6867
params[0] = str(params[0])
6968
# rhs_sql is REGEXP_CONTAINS(%s, %%s), and lhs_sql is the column name.
7069
return rhs_sql % lhs_sql, params
7170
else:
72-
# rhs is the expression/column to use as the base of the regular
71+
# rhs_sql is the expression/column to use as the base of the regular
7372
# expression.
74-
rhs = compiler.compile(self.rhs)[0]
7573
if is_iregex:
76-
rhs = "CONCAT('(?i)', " + rhs + ")"
77-
return 'REGEXP_CONTAINS(%s, %s)' % (lhs_sql, rhs), params
74+
rhs_sql = "CONCAT('(?i)', " + rhs_sql + ")"
75+
return 'REGEXP_CONTAINS(%s, %s)' % (lhs_sql, rhs_sql), params
7876

7977

8078
def startswith_endswith(self, compiler, connection):
8179
"""startswith, endswith, istartswith, and iendswith lookups."""
8280
lhs_sql, params = self.process_lhs(compiler, connection)
8381
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
8482
params.extend(rhs_params)
85-
rhs_sql = self.get_rhs_op(connection, rhs_sql)
8683
is_startswith = 'startswith' in self.lookup_name
8784
is_endswith = 'endswith' in self.lookup_name
8885
is_insensitive = self.lookup_name.startswith('i')
8986
# Chop the leading (endswith) or trailing (startswith) percent sign that
9087
# Django adds to the param since this isn't a LIKE query as Django expects.
9188
if self.rhs_is_direct_value() and params and not self.bilateral_transforms:
89+
rhs_sql = self.get_rhs_op(connection, rhs_sql)
9290
if is_endswith:
9391
params[0] = str(params[0][1:]) + '$'
9492
else:
@@ -99,14 +97,14 @@ def startswith_endswith(self, compiler, connection):
9997
# rhs_sql is REGEXP_CONTAINS(%s, %%s), and lhs_sql is the column name.
10098
return rhs_sql % lhs_sql, params
10199
else:
102-
# rhs is the expression/column to use as the base of the regular
100+
# rhs_sql is the expression/column to use as the base of the regular
103101
# expression.
104102
sql = "CONCAT('"
105103
if is_startswith:
106104
sql += '^'
107105
if is_insensitive:
108106
sql += '(?i)'
109-
sql += "', " + compiler.compile(self.rhs)[0]
107+
sql += "', " + rhs_sql
110108
if is_endswith:
111109
sql += ", '$'"
112110
sql += ")"

0 commit comments

Comments
 (0)