11from typing import TYPE_CHECKING , Any , Optional , Type , Union , cast
22
3- from pypika import Table , functions
3+ from pypika import Case , Table , functions
44from pypika .functions import DistinctOptionFunction
55from pypika .terms import ArithmeticExpression
66from pypika .terms import Function as BaseFunction
77
88from tortoise .exceptions import ConfigurationError
99from tortoise .expressions import F
1010from tortoise .fields .relational import BackwardFKRelation , ForeignKeyFieldInstance , RelationalField
11+ from tortoise .query_utils import Q , QueryModifier
1112
1213if TYPE_CHECKING : # pragma: nocoverage
1314 from tortoise .models import Model
@@ -100,7 +101,7 @@ def _resolve_field_for_model(
100101 if func :
101102 field = func (self .field_object , field )
102103
103- return {"joins" : joins , "field" : self . _get_function_field ( field , * default_values ) }
104+ return {"joins" : joins , "field" : field }
104105
105106 def resolve (self , model : "Type[Model]" , table : Table ) -> dict :
106107 """
@@ -114,6 +115,7 @@ def resolve(self, model: "Type[Model]", table: Table) -> dict:
114115
115116 if isinstance (self .field , str ):
116117 function = self ._resolve_field_for_model (model , table , self .field , * self .default_values )
118+ function ["field" ] = self ._get_function_field (function ["field" ], * self .default_values )
117119 return function
118120 else :
119121 field , field_object = F .resolver_arithmetic_expression (model , self .field )
@@ -134,10 +136,15 @@ class Aggregate(Function):
134136 database_func = DistinctOptionFunction
135137
136138 def __init__ (
137- self , field : Union [str , F , ArithmeticExpression ], * default_values : Any , distinct = False
139+ self ,
140+ field : Union [str , F , ArithmeticExpression ],
141+ * default_values : Any ,
142+ distinct = False ,
143+ _filter : Optional [Q ] = None ,
138144 ) -> None :
139145 super ().__init__ (field , * default_values )
140146 self .distinct = distinct
147+ self .filter = _filter
141148
142149 def _get_function_field (
143150 self , field : "Union[ArithmeticExpression, Field, str]" , * default_values
@@ -147,6 +154,18 @@ def _get_function_field(
147154 else :
148155 return self .database_func (field , * default_values )
149156
157+ def _resolve_field_for_model (
158+ self , model : "Type[Model]" , table : Table , field : str , * default_values : Any
159+ ) -> dict :
160+ ret = super ()._resolve_field_for_model (model , table , field , default_values )
161+ if self .filter :
162+ modifier = QueryModifier ()
163+ modifier &= self .filter .resolve (model , {}, {}, model ._meta .basetable )
164+ where_criterion , joins , having_criterion = modifier .get_query_modifiers ()
165+ ret ["field" ] = Case ().when (where_criterion , ret ["field" ]).else_ (None )
166+
167+ return ret
168+
150169
151170##############################################################################
152171# Standard functions
0 commit comments