Skip to content

Commit 803115d

Browse files
authored
Allow to filter on an aggregate function (tortoise#362)
1 parent fc21bed commit 803115d

5 files changed

Lines changed: 70 additions & 3 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Changelog
1212

1313
0.16.8
1414
------
15+
- Allow `Q` expression to function with `_filter` parameter
1516
- Add ``group by`` support
1617
- Fixed regression where ``GROUP BY`` class is missing for an aggregate with a specified order.
1718

examples/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from tortoise import Tortoise, fields, run_async
22
from tortoise.functions import Coalesce, Count, Length, Lower, Min, Sum, Trim, Upper
33
from tortoise.models import Model
4+
from tortoise.query_utils import Q
45

56

67
class Tournament(Model):
@@ -57,6 +58,11 @@ async def run():
5758
await event.participants.add(participants[0], participants[1])
5859

5960
print(await Tournament.all().annotate(events_count=Count("events")).filter(events_count__gte=1))
61+
print(
62+
await Tournament.all()
63+
.annotate(events_count_with_filter=Count("events", _filter=Q(name="New Tournament")))
64+
.filter(events_count_with_filter__gte=1)
65+
)
6066

6167
print(await Event.filter(id=event.id).first().annotate(lowest_team_id=Min("participants__id")))
6268

tests/test_aggregation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from tortoise.contrib import test
33
from tortoise.exceptions import ConfigurationError
44
from tortoise.functions import Avg, Count, Min, Sum
5+
from tortoise.query_utils import Q
56

67

78
class TestAggregation(test.TestCase):
@@ -91,6 +92,26 @@ async def test_aggregation_with_distinct(self):
9192
self.assertEqual(school_with_distinct_count.events_count, 3)
9293
self.assertEqual(school_with_distinct_count.minrelations_count, 2)
9394

95+
async def test_aggregation_with_filter(self):
96+
tournament = await Tournament.create(name="New Tournament")
97+
await Event.create(name="Event 1", tournament=tournament)
98+
await Event.create(name="Event 2", tournament=tournament)
99+
await Event.create(name="Event 3", tournament=tournament)
100+
101+
tournament_with_filter = (
102+
await Tournament.all()
103+
.annotate(
104+
all=Count("events", _filter=Q(name="New Tournament")),
105+
one=Count("events", _filter=Q(events__name="Event 1")),
106+
two=Count("events", _filter=Q(events__name__not="Event 1")),
107+
)
108+
.first()
109+
)
110+
111+
self.assertEqual(tournament_with_filter.all, 3)
112+
self.assertEqual(tournament_with_filter.one, 1)
113+
self.assertEqual(tournament_with_filter.two, 2)
114+
94115
async def test_group_aggregation(self):
95116
author = await Author.create(name="Some One")
96117
await Book.create(name="First!", author=author, rating=4)

tests/test_source_field.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tortoise.contrib import test
99
from tortoise.expressions import F
1010
from tortoise.functions import Coalesce, Count, Length, Lower, Trim, Upper
11+
from tortoise.query_utils import Q
1112

1213

1314
class StraightFieldTests(test.TestCase):
@@ -181,6 +182,25 @@ async def test_function(self):
181182
obj2 = await self.model.get(eyedee=obj1.eyedee)
182183
self.assertEqual(obj2.chars, "aaa")
183184

185+
async def test_aggregation_with_filter(self):
186+
obj1 = await self.model.create(chars="aaa")
187+
await self.model.create(chars="bbb", fk=obj1)
188+
await self.model.create(chars="ccc", fk=obj1)
189+
190+
obj = (
191+
await self.model.filter(chars="aaa")
192+
.annotate(
193+
all=Count("fkrev", _filter=Q(chars="aaa")),
194+
one=Count("fkrev", _filter=Q(fkrev__chars="bbb")),
195+
no=Count("fkrev", _filter=Q(fkrev__chars="aaa")),
196+
)
197+
.first()
198+
)
199+
200+
self.assertEqual(obj.all, 2)
201+
self.assertEqual(obj.one, 1)
202+
self.assertEqual(obj.no, 0)
203+
184204
async def test_filter_by_aggregation_field_coalesce(self):
185205
await self.model.create(chars="aaa", nullable="null")
186206
await self.model.create(chars="bbb")

tortoise/functions.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import TYPE_CHECKING, Any, Optional, Type, Union, cast
22

3-
from pypika import Table, functions
3+
from pypika import Case, Table, functions
44
from pypika.functions import DistinctOptionFunction
55
from pypika.terms import ArithmeticExpression
66
from pypika.terms import Function as BaseFunction
77

88
from tortoise.exceptions import ConfigurationError
99
from tortoise.expressions import F
1010
from tortoise.fields.relational import BackwardFKRelation, ForeignKeyFieldInstance, RelationalField
11+
from tortoise.query_utils import Q, QueryModifier
1112

1213
if TYPE_CHECKING: # pragma: nocoverage
1314
from tortoise.models import Model
@@ -100,7 +101,7 @@ def _resolve_field_for_model(
100101
if func:
101102
field = func(self.field_object, field)
102103

103-
return {"joins": joins, "field": self._get_function_field(field, *default_values)}
104+
return {"joins": joins, "field": field}
104105

105106
def resolve(self, model: "Type[Model]", table: Table) -> dict:
106107
"""
@@ -114,6 +115,7 @@ def resolve(self, model: "Type[Model]", table: Table) -> dict:
114115

115116
if isinstance(self.field, str):
116117
function = self._resolve_field_for_model(model, table, self.field, *self.default_values)
118+
function["field"] = self._get_function_field(function["field"], *self.default_values)
117119
return function
118120
else:
119121
field, field_object = F.resolver_arithmetic_expression(model, self.field)
@@ -134,10 +136,15 @@ class Aggregate(Function):
134136
database_func = DistinctOptionFunction
135137

136138
def __init__(
137-
self, field: Union[str, F, ArithmeticExpression], *default_values: Any, distinct=False
139+
self,
140+
field: Union[str, F, ArithmeticExpression],
141+
*default_values: Any,
142+
distinct=False,
143+
_filter: Optional[Q] = None,
138144
) -> None:
139145
super().__init__(field, *default_values)
140146
self.distinct = distinct
147+
self.filter = _filter
141148

142149
def _get_function_field(
143150
self, field: "Union[ArithmeticExpression, Field, str]", *default_values
@@ -147,6 +154,18 @@ def _get_function_field(
147154
else:
148155
return self.database_func(field, *default_values)
149156

157+
def _resolve_field_for_model(
158+
self, model: "Type[Model]", table: Table, field: str, *default_values: Any
159+
) -> dict:
160+
ret = super()._resolve_field_for_model(model, table, field, default_values)
161+
if self.filter:
162+
modifier = QueryModifier()
163+
modifier &= self.filter.resolve(model, {}, {}, model._meta.basetable)
164+
where_criterion, joins, having_criterion = modifier.get_query_modifiers()
165+
ret["field"] = Case().when(where_criterion, ret["field"]).else_(None)
166+
167+
return ret
168+
150169

151170
##############################################################################
152171
# Standard functions

0 commit comments

Comments
 (0)