Skip to content
This repository was archived by the owner on Mar 2, 2026. It is now read-only.

Commit de5dbef

Browse files
committed
feat: Add support for Sum and Avg aggregation query
Add .sum() and .avg() functions to aggregation. Refactor limit to be passed in to the nested query's limit. Add unit tests.
1 parent b31a944 commit de5dbef

File tree

3 files changed

+264
-13
lines changed

3 files changed

+264
-13
lines changed

google/cloud/datastore/aggregation.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class BaseAggregation(ABC):
3939
Base class representing an Aggregation operation in Datastore
4040
"""
4141

42+
def __init__(self, alias=None):
43+
self.alias = alias
44+
4245
@abc.abstractmethod
4346
def _to_pb(self):
4447
"""
@@ -59,7 +62,7 @@ class CountAggregation(BaseAggregation):
5962
"""
6063

6164
def __init__(self, alias=None):
62-
self.alias = alias
65+
super(CountAggregation, self).__init__(alias=alias)
6366

6467
def _to_pb(self):
6568
"""
@@ -71,6 +74,61 @@ def _to_pb(self):
7174
return aggregation_pb
7275

7376

77+
class SumAggregation(BaseAggregation):
78+
"""
79+
Representation of a "Sum" aggregation query.
80+
81+
:type property_ref: str
82+
:param property_ref: The property_ref for the aggregation.
83+
84+
:type value: int
85+
:param value: The resulting value from the aggregation.
86+
87+
"""
88+
89+
def __init__(self, property_ref, alias=None):
90+
self.property_ref = property_ref
91+
super(SumAggregation, self).__init__(alias=alias)
92+
93+
def _to_pb(self):
94+
"""
95+
Convert this instance to the protobuf representation
96+
"""
97+
aggregation_pb = query_pb2.AggregationQuery.Aggregation()
98+
aggregation_pb.alias = self.alias
99+
aggregation_pb.sum_ = query_pb2.AggregationQuery.Aggregation.Sum()
100+
aggregation_pb.sum_.property.name = self.property_ref
101+
aggregation_pb.alias = self.alias
102+
return aggregation_pb
103+
104+
105+
class AvgAggregation(BaseAggregation):
106+
"""
107+
Representation of an "Avg" aggregation query.
108+
109+
:type property_ref: str
110+
:param property_ref: The property_ref for the aggregation.
111+
112+
:type value: int
113+
:param value: The resulting value from the aggregation.
114+
115+
"""
116+
117+
def __init__(self, property_ref, alias=None):
118+
self.property_ref = property_ref
119+
super(AvgAggregation, self).__init__(alias=alias)
120+
121+
def _to_pb(self):
122+
"""
123+
Convert this instance to the protobuf representation
124+
"""
125+
aggregation_pb = query_pb2.AggregationQuery.Aggregation()
126+
aggregation_pb.avg = query_pb2.AggregationQuery.Aggregation.Avg()
127+
aggregation_pb.avg.property.name = self.property_ref
128+
aggregation_pb.alias = self.alias
129+
return aggregation_pb
130+
131+
74132
class AggregationResult(object):
75133
"""
76134
A class representing result from Aggregation Query
@@ -154,6 +212,28 @@ def count(self, alias=None):
154212
self._aggregations.append(count_aggregation)
155213
return self
156214

215+
def sum(self, property_ref, alias=None):
216+
"""
217+
Adds a sum over the nested query
218+
219+
:type property_ref: str
220+
:param property_ref: The property_ref for the sum
221+
"""
222+
sum_aggregation = SumAggregation(property_ref=property_ref, alias=alias)
223+
self._aggregations.append(sum_aggregation)
224+
return self
225+
226+
def avg(self, property_ref, alias=None):
227+
"""
228+
Adds an avg over the nested query
229+
230+
:type property_ref: str
231+
:param property_ref: The property_ref for the avg
232+
"""
233+
avg_aggregation = AvgAggregation(property_ref=property_ref, alias=alias)
234+
self._aggregations.append(avg_aggregation)
235+
return self
236+
157237
def add_aggregation(self, aggregation):
158238
"""
159239
Adds an aggregation operation to the nested query
@@ -327,8 +407,7 @@ def _build_protobuf(self):
327407
"""
328408
pb = self._aggregation_query._to_pb()
329409
if self._limit is not None and self._limit > 0:
330-
for aggregation in pb.aggregations:
331-
aggregation.count.up_to = self._limit
410+
pb.nested_query.limit = self._limit
332411
return pb
333412

334413
def _process_query_results(self, response_pb):

tests/system/test_aggregation_query.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,54 @@ def test_aggregation_query_with_alias(aggregation_query_client, nested_query):
9393
assert r.value > 0
9494

9595

96+
def test_sum_query_default(aggregation_query_client, nested_query):
97+
query = nested_query
98+
99+
aggregation_query = aggregation_query_client.aggregation_query(query)
100+
aggregation_query.sum("person")
101+
result = _do_fetch(aggregation_query)
102+
assert len(result) == 1
103+
for r in result[0]:
104+
assert r.alias == "property_1"
105+
assert r.value == 8
106+
107+
108+
def test_sum_query_with_alias(aggregation_query_client, nested_query):
109+
query = nested_query
110+
111+
aggregation_query = aggregation_query_client.aggregation_query(query)
112+
aggregation_query.sum("person", alias="sum_person")
113+
result = _do_fetch(aggregation_query)
114+
assert len(result) == 1
115+
for r in result[0]:
116+
assert r.alias == "sum_person"
117+
assert r.value > 0
118+
119+
120+
def test_avg_query_default(aggregation_query_client, nested_query):
121+
query = nested_query
122+
123+
aggregation_query = aggregation_query_client.aggregation_query(query)
124+
aggregation_query.avg("person")
125+
result = _do_fetch(aggregation_query)
126+
assert len(result) == 1
127+
for r in result[0]:
128+
assert r.alias == "property_1"
129+
assert r.value == 8
130+
131+
132+
def test_avg_query_with_alias(aggregation_query_client, nested_query):
133+
query = nested_query
134+
135+
aggregation_query = aggregation_query_client.aggregation_query(query)
136+
aggregation_query.avg("person", alias="avg_person")
137+
result = _do_fetch(aggregation_query)
138+
assert len(result) == 1
139+
for r in result[0]:
140+
assert r.alias == "avg_person"
141+
assert r.value > 0
142+
143+
96144
def test_aggregation_query_with_limit(aggregation_query_client, nested_query):
97145
query = nested_query
98146

@@ -121,41 +169,60 @@ def test_aggregation_query_multiple_aggregations(
121169
aggregation_query = aggregation_query_client.aggregation_query(query)
122170
aggregation_query.count(alias="total")
123171
aggregation_query.count(alias="all")
172+
aggregation_query.sum("person", alias="sum_person")
173+
aggregation_query.avg("person", alias="avg_person")
124174
result = _do_fetch(aggregation_query)
125175
assert len(result) == 1
126176
for r in result[0]:
127-
assert r.alias in ["all", "total"]
177+
assert r.alias in ["all", "total", "sum_person", "avg_person"]
128178
assert r.value > 0
129179

130180

131181
def test_aggregation_query_add_aggregation(aggregation_query_client, nested_query):
132182
from google.cloud.datastore.aggregation import CountAggregation
183+
from google.cloud.datastore.aggregation import SumAggregation
184+
from google.cloud.datastore.aggregation import AvgAggregation
133185

134186
query = nested_query
135187

136188
aggregation_query = aggregation_query_client.aggregation_query(query)
137189
count_aggregation = CountAggregation(alias="total")
138190
aggregation_query.add_aggregation(count_aggregation)
191+
192+
sum_aggregation = SumAggregation("person", alias="sum_person")
193+
aggregation_query.add_aggregation(sum_aggregation)
194+
195+
avg_aggregation = AvgAggregation("person", alias="avg_person")
196+
aggregation_query.add_aggregation(avg_aggregation)
197+
139198
result = _do_fetch(aggregation_query)
140199
assert len(result) == 1
141200
for r in result[0]:
142-
assert r.alias == "total"
201+
assert r.alias in ["total", "sum_person", "avg_person"]
143202
assert r.value > 0
144203

145204

146205
def test_aggregation_query_add_aggregations(aggregation_query_client, nested_query):
147-
from google.cloud.datastore.aggregation import CountAggregation
206+
from google.cloud.datastore.aggregation import (
207+
CountAggregation,
208+
SumAggregation,
209+
AvgAggregation,
210+
)
148211

149212
query = nested_query
150213

151214
aggregation_query = aggregation_query_client.aggregation_query(query)
152215
count_aggregation_1 = CountAggregation(alias="total")
153216
count_aggregation_2 = CountAggregation(alias="all")
154-
aggregation_query.add_aggregations([count_aggregation_1, count_aggregation_2])
217+
sum_aggregation = SumAggregation("person", alias="sum_person")
218+
avg_aggregation = AvgAggregation("person", alias="avg_person")
219+
aggregation_query.add_aggregations(
220+
[count_aggregation_1, count_aggregation_2, sum_aggregation, avg_aggregation]
221+
)
155222
result = _do_fetch(aggregation_query)
156223
assert len(result) == 1
157224
for r in result[0]:
158-
assert r.alias in ["total", "all"]
225+
assert r.alias in ["total", "all", "sum_person", "avg_person"]
159226
assert r.value > 0
160227

161228

@@ -202,11 +269,13 @@ def test_aggregation_query_with_nested_query_filtered(
202269

203270
aggregation_query = aggregation_query_client.aggregation_query(query)
204271
aggregation_query.count(alias="total")
272+
aggregation_query.sum("person", alias="sum_person")
273+
aggregation_query.avg("person", alias="avg_person")
205274
result = _do_fetch(aggregation_query)
206275
assert len(result) == 1
207276

208277
for r in result[0]:
209-
assert r.alias == "total"
278+
assert r.alias in ["total", "sum_person", "avg_person"]
210279
assert r.value == 6
211280

212281

@@ -226,9 +295,11 @@ def test_aggregation_query_with_nested_query_multiple_filters(
226295

227296
aggregation_query = aggregation_query_client.aggregation_query(query)
228297
aggregation_query.count(alias="total")
298+
aggregation_query.sum("person", alias="sum_person")
299+
aggregation_query.avg("person", alias="avg_person")
229300
result = _do_fetch(aggregation_query)
230301
assert len(result) == 1
231302

232303
for r in result[0]:
233-
assert r.alias == "total"
304+
assert r.alias in ["total", "sum_person", "avg_person"]
234305
assert r.value == 4

0 commit comments

Comments
 (0)