Aggregation

An aggregate (or aggregation) is a function that processes the values of multiple rows together to form a single summary value. To perform an aggregation, DataFusion provides aggregate(), which takes a list of expressions to group by and a list of aggregate expressions to evaluate.

In [1]: from datafusion import SessionContext, col, lit, functions as f

In [2]: ctx = SessionContext()

In [3]: df = ctx.read_csv("pokemon.csv")

In [4]: col_type_1 = col('"Type 1"')

In [5]: col_type_2 = col('"Type 2"')

In [6]: col_speed = col('"Speed"')

In [7]: col_attack = col('"Attack"')

In [8]: df.aggregate([col_type_1], [
   ...:     f.approx_distinct(col_speed).alias("Count"),
   ...:     f.approx_median(col_speed).alias("Median Speed"),
   ...:     f.approx_percentile_cont(col_speed, 0.9).alias("90% Speed")])
   ...: 
Out[8]: 
DataFrame()
+----------+-------+--------------+-----------+
| Type 1   | Count | Median Speed | 90% Speed |
+----------+-------+--------------+-----------+
| Bug      | 11    | 63           | 107       |
| Poison   | 12    | 55           | 85        |
| Electric | 8     | 100          | 136       |
| Fairy    | 2     | 47           | 60        |
| Normal   | 20    | 71           | 110       |
| Ice      | 2     | 90           | 95        |
| Grass    | 8     | 55           | 80        |
| Fire     | 8     | 91           | 100       |
| Water    | 21    | 70           | 90        |
| Ground   | 7     | 40           | 112       |
+----------+-------+--------------+-----------+
Data truncated.

When the group_by list is empty, the aggregation is performed over the whole DataFrame and produces a single row. To group, the group_by list must contain at least one column.

In [9]: df.aggregate([col_type_1], [
   ...:     f.max(col_speed).alias("Max Speed"),
   ...:     f.avg(col_speed).alias("Avg Speed"),
   ...:     f.min(col_speed).alias("Min Speed")])
   ...: 
Out[9]: 
DataFrame()
+----------+-----------+--------------------+-----------+
| Type 1   | Max Speed | Avg Speed          | Min Speed |
+----------+-----------+--------------------+-----------+
| Bug      | 145       | 66.78571428571429  | 25        |
| Poison   | 90        | 58.785714285714285 | 25        |
| Electric | 140       | 98.88888888888889  | 45        |
| Fairy    | 60        | 47.5               | 35        |
| Normal   | 121       | 72.75              | 20        |
| Ice      | 95        | 90.0               | 85        |
| Grass    | 80        | 54.23076923076923  | 30        |
| Fire     | 105       | 86.28571428571429  | 60        |
| Water    | 115       | 67.25806451612904  | 15        |
| Ground   | 120       | 58.125             | 25        |
+----------+-----------+--------------------+-----------+
Data truncated.

More than one column can be used for grouping.

In [10]: df.aggregate([col_type_1, col_type_2], [
   ....:     f.max(col_speed).alias("Max Speed"),
   ....:     f.avg(col_speed).alias("Avg Speed"),
   ....:     f.min(col_speed).alias("Min Speed")])
   ....: 
Out[10]: 
DataFrame()
+----------+---------+-----------+--------------------+-----------+
| Type 1   | Type 2  | Max Speed | Avg Speed          | Min Speed |
+----------+---------+-----------+--------------------+-----------+
| Bug      |         | 85        | 53.333333333333336 | 30        |
| Normal   | Flying  | 121       | 83.77777777777777  | 56        |
| Poison   |         | 80        | 51.7               | 25        |
| Electric |         | 140       | 112.5              | 90        |
| Fairy    |         | 60        | 47.5               | 35        |
| Water    | Ice     | 70        | 66.66666666666667  | 60        |
| Ice      | Psychic | 95        | 95.0               | 95        |
| Ice      | Flying  | 85        | 85.0               | 85        |
| Fire     | Flying  | 100       | 96.66666666666667  | 90        |
| Fire     | Dragon  | 100       | 100.0              | 100       |
+----------+---------+-----------+--------------------+-----------+
Data truncated.

Setting Parameters

Each of the built-in aggregate functions provides arguments for the parameters that affect its operation. These parameters can also be set using the builder approach for any of the parameters described in the following sections. When you use the builder, you must call build() to finish the expression. For example, the following two expressions are equivalent.

In [11]: first_1 = f.first_value(col("a"), order_by=[col("a")])

In [12]: first_2 = f.first_value(col("a")).order_by(col("a")).build()

Ordering

You can control the order in which rows are processed by aggregate functions by providing a list of sort expressions for the order_by parameter. In the following example, we sort the Pokemon by their attack in ascending order and take the first value, which gives us the Pokemon with the smallest attack value within each Type 1.

In [13]: df.aggregate(
   ....:     [col('"Type 1"')],
   ....:     [f.first_value(
   ....:         col('"Name"'),
   ....:         order_by=[col('"Attack"').sort(ascending=True)]
   ....:         ).alias("Smallest Attack")
   ....:     ])
   ....: 
Out[13]: 
DataFrame()
+----------+-----------------+
| Type 1   | Smallest Attack |
+----------+-----------------+
| Bug      | Metapod         |
| Poison   | Zubat           |
| Electric | Voltorb         |
| Fairy    | Clefairy        |
| Normal   | Chansey         |
| Ice      | Jynx            |
| Grass    | Exeggcute       |
| Fire     | Vulpix          |
| Water    | Magikarp        |
| Ground   | Cubone          |
+----------+-----------------+
Data truncated.

Distinct

When you set the distinct parameter to True, each unique value is evaluated only once. Suppose we want to create an array of all of the Type 2 values for each Type 1 in our Pokemon set. Since there will be many repeated Type 2 entries, we only want each distinct value once.

In [14]: df.aggregate([col_type_1], [f.array_agg(col_type_2, distinct=True).alias("Type 2 List")])
Out[14]: 
DataFrame()
+----------+--------------------------------------------------+
| Type 1   | Type 2 List                                      |
+----------+--------------------------------------------------+
| Bug      | [, Poison, Grass, Flying]                        |
| Poison   | [, Flying, Ground]                               |
| Electric | [, Steel, Flying]                                |
| Fairy    | []                                               |
| Normal   | [Flying, , Fairy]                                |
| Ice      | [Psychic, Flying]                                |
| Grass    | [, Poison, Psychic]                              |
| Fire     | [Flying, Dragon, ]                               |
| Water    | [Fighting, Poison, Flying, Psychic, Ice, , Dark] |
| Ground   | [Rock, ]                                         |
+----------+--------------------------------------------------+
Data truncated.

In the output above we can see that there are some Type 1 groups for which the Type 2 entry is null. In reality, we probably want to filter those out. We can do this in two ways. First, we can filter out the DataFrame rows that have no Type 2; if we do this, some Type 1 entries may be removed entirely. Second, we can use the filter argument described below.

In [15]: df.filter(col_type_2.is_not_null()).aggregate([col_type_1], [f.array_agg(col_type_2, distinct=True).alias("Type 2 List")])
Out[15]: 
DataFrame()
+----------+------------------------------------------------+
| Type 1   | Type 2 List                                    |
+----------+------------------------------------------------+
| Bug      | [Grass, Flying, Poison]                        |
| Poison   | [Flying, Ground]                               |
| Electric | [Flying, Steel]                                |
| Normal   | [Flying, Fairy]                                |
| Ice      | [Flying, Psychic]                              |
| Grass    | [Psychic, Poison]                              |
| Fire     | [Flying, Dragon]                               |
| Water    | [Flying, Dark, Poison, Fighting, Psychic, Ice] |
| Rock     | [Flying, Water, Ground]                        |
| Ghost    | [Poison]                                       |
+----------+------------------------------------------------+
Data truncated.

In [16]: df.aggregate([col_type_1], [f.array_agg(col_type_2, distinct=True, filter=col_type_2.is_not_null()).alias("Type 2 List")])
Out[16]: 
DataFrame()
+----------+------------------------------------------------+
| Type 1   | Type 2 List                                    |
+----------+------------------------------------------------+
| Bug      | [Grass, Flying, Poison]                        |
| Poison   | [Ground, Flying]                               |
| Electric | [Steel, Flying]                                |
| Fairy    |                                                |
| Normal   | [Fairy, Flying]                                |
| Ice      | [Psychic, Flying]                              |
| Grass    | [Poison, Psychic]                              |
| Fire     | [Dragon, Flying]                               |
| Water    | [Dark, Psychic, Poison, Ice, Flying, Fighting] |
| Ground   | [Rock]                                         |
+----------+------------------------------------------------+
Data truncated.

Which approach you take should depend on your use case.

Null Treatment

This option allows you to either respect or ignore null values.

One common usage for handling nulls is the case where you want to find the first value within a group. By setting the null treatment to ignore nulls, we can find the first non-null value in each group.

In [17]: from datafusion.common import NullTreatment

In [18]: df.aggregate([col_type_1], [
   ....:     f.first_value(
   ....:         col_type_2,
   ....:         order_by=[col_attack],
   ....:         null_treatment=NullTreatment.RESPECT_NULLS
   ....:     ).alias("Lowest Attack Type 2")])
   ....: 
Out[18]: 
DataFrame()
+----------+----------------------+
| Type 1   | Lowest Attack Type 2 |
+----------+----------------------+
| Bug      |                      |
| Poison   | Flying               |
| Electric |                      |
| Fairy    |                      |
| Normal   |                      |
| Ice      | Psychic              |
| Grass    | Psychic              |
| Fire     |                      |
| Water    |                      |
| Ground   |                      |
+----------+----------------------+
Data truncated.

In [19]: df.aggregate([col_type_1], [
   ....:     f.first_value(
   ....:         col_type_2,
   ....:         order_by=[col_attack],
   ....:         null_treatment=NullTreatment.IGNORE_NULLS
   ....:     ).alias("Lowest Attack Type 2")])
   ....: 
Out[19]: 
DataFrame()
+----------+----------------------+
| Type 1   | Lowest Attack Type 2 |
+----------+----------------------+
| Bug      | Poison               |
| Poison   | Flying               |
| Electric | Steel                |
| Fairy    |                      |
| Normal   | Flying               |
| Ice      | Psychic              |
| Grass    | Psychic              |
| Fire     | Flying               |
| Water    | Poison               |
| Ground   | Rock                 |
+----------+----------------------+
Data truncated.

Filter

The filter option restricts which rows are evaluated by the aggregate function without filtering rows from the DataFrame as a whole, as the array_agg example above demonstrates.

Filter takes a single boolean expression.

Suppose we want to compute the average speed of only those Pokemon that have low Attack values.

In [20]: df.aggregate([col_type_1], [
   ....:     f.avg(col_speed).alias("Avg Speed All"),
   ....:     f.avg(col_speed, filter=col_attack < lit(50)).alias("Avg Speed Low Attack")])
   ....: 
Out[20]: 
DataFrame()
+----------+--------------------+----------------------+
| Type 1   | Avg Speed All      | Avg Speed Low Attack |
+----------+--------------------+----------------------+
| Bug      | 66.78571428571429  | 46.0                 |
| Poison   | 58.785714285714285 | 48.0                 |
| Electric | 98.88888888888889  | 72.5                 |
| Fairy    | 47.5               | 35.0                 |
| Normal   | 72.75              | 52.8                 |
| Ice      | 90.0               |                      |
| Grass    | 54.23076923076923  | 42.5                 |
| Fire     | 86.28571428571429  | 65.0                 |
| Water    | 67.25806451612904  | 63.833333333333336   |
| Ground   | 58.125             |                      |
+----------+--------------------+----------------------+
Data truncated.

Grouping Sets

The default style of aggregation produces one row per group. Sometimes you want a single query to produce rows at multiple levels of detail — for example, totals per type and an overall grand total, or subtotals for every combination of two columns plus the individual column totals. Writing separate queries and concatenating them is tedious and runs the data multiple times. Grouping sets solve this by letting you specify several grouping levels in one pass.

DataFusion supports three grouping set styles through the GroupingSet class:

  • rollup() — hierarchical subtotals, like a drill-down report

  • cube() — every possible subtotal combination, like a pivot table

  • grouping_sets() — explicitly list exactly which grouping levels you want

Because result rows come from different grouping levels, a column that is not part of a particular level will be null in that row. Use grouping() to distinguish a real null in the data from one that means “this column was aggregated across.” It returns 0 when the column is a grouping key for that row, and 1 when it is not.

Rollup

rollup() creates a hierarchy. rollup(a, b) produces grouping sets (a, b), (a), and () — like nested subtotals in a report. This is useful when your columns have a natural hierarchy, such as region → city or type → subtype.

Suppose we want to summarize Pokemon stats by Type 1 with subtotals and a grand total. With the default aggregation style we would need two separate queries. With rollup we get it all at once:

In [21]: from datafusion.expr import GroupingSet

In [22]: df.aggregate(
   ....:     [GroupingSet.rollup(col_type_1)],
   ....:     [f.count(col_speed).alias("Count"),
   ....:      f.avg(col_speed).alias("Avg Speed"),
   ....:      f.max(col_speed).alias("Max Speed")]
   ....: ).sort(col_type_1.sort(ascending=True, nulls_first=True))
   ....: 
Out[22]: 
DataFrame()
+----------+-------+-------------------+-----------+
| Type 1   | Count | Avg Speed         | Max Speed |
+----------+-------+-------------------+-----------+
|          | 163   | 71.65030674846626 | 150       |
| Bug      | 14    | 66.78571428571429 | 145       |
| Dragon   | 3     | 66.66666666666667 | 80        |
| Electric | 9     | 98.88888888888889 | 140       |
| Fairy    | 2     | 47.5              | 60        |
| Fighting | 7     | 66.14285714285714 | 95        |
| Fire     | 14    | 86.28571428571429 | 105       |
| Ghost    | 4     | 103.75            | 130       |
| Grass    | 13    | 54.23076923076923 | 80        |
| Ground   | 8     | 58.125            | 120       |
+----------+-------+-------------------+-----------+
Data truncated.

The first row — where Type 1 is null — is the grand total across all types. But how do you tell a grand-total null apart from a Pokemon that genuinely has no type? The grouping() function returns 0 when the column is a grouping key for that row and 1 when it is aggregated across.

Note

Due to an upstream DataFusion limitation (apache/datafusion#21411), .alias() cannot be applied directly to a grouping() expression — it will raise an error at execution time. Instead, use with_column_renamed() on the result DataFrame to give the column a readable name. Once the upstream issue is resolved, you will be able to use .alias() directly and the workaround below will no longer be necessary.

The raw column name generated by grouping() contains internal identifiers, so we use with_column_renamed() to clean it up:

In [23]: result = df.aggregate(
   ....:     [GroupingSet.rollup(col_type_1)],
   ....:     [f.count(col_speed).alias("Count"),
   ....:      f.avg(col_speed).alias("Avg Speed"),
   ....:      f.grouping(col_type_1)]
   ....: )
   ....: 

In [24]: for field in result.schema():
   ....:     if field.name.startswith("grouping("):
   ....:         result = result.with_column_renamed(field.name, "Is Total")
   ....: 

In [25]: result.sort(col_type_1.sort(ascending=True, nulls_first=True))
Out[25]: 
DataFrame()
+----------+-------+-------------------+----------+
| Type 1   | Count | Avg Speed         | Is Total |
+----------+-------+-------------------+----------+
|          | 163   | 71.65030674846626 | 1        |
| Bug      | 14    | 66.78571428571429 | 0        |
| Dragon   | 3     | 66.66666666666667 | 0        |
| Electric | 9     | 98.88888888888889 | 0        |
| Fairy    | 2     | 47.5              | 0        |
| Fighting | 7     | 66.14285714285714 | 0        |
| Fire     | 14    | 86.28571428571429 | 0        |
| Ghost    | 4     | 103.75            | 0        |
| Grass    | 13    | 54.23076923076923 | 0        |
| Ground   | 8     | 58.125            | 0        |
+----------+-------+-------------------+----------+
Data truncated.

With two columns the hierarchy becomes more apparent. rollup(Type 1, Type 2) produces:

  • one row per (Type 1, Type 2) pair — the most detailed level

  • one row per Type 1 — subtotals

  • one grand total row

In [26]: df.aggregate(
   ....:     [GroupingSet.rollup(col_type_1, col_type_2)],
   ....:     [f.count(col_speed).alias("Count"),
   ....:      f.avg(col_speed).alias("Avg Speed")]
   ....: ).sort(
   ....:     col_type_1.sort(ascending=True, nulls_first=True),
   ....:     col_type_2.sort(ascending=True, nulls_first=True)
   ....: )
   ....: 
Out[26]: 
DataFrame()
+----------+--------+-------+--------------------+
| Type 1   | Type 2 | Count | Avg Speed          |
+----------+--------+-------+--------------------+
|          |        | 163   | 71.65030674846626  |
| Bug      |        | 14    | 66.78571428571429  |
| Bug      |        | 3     | 53.333333333333336 |
| Bug      | Flying | 3     | 93.33333333333333  |
| Bug      | Grass  | 2     | 27.5               |
| Bug      | Poison | 6     | 73.33333333333333  |
| Dragon   |        | 2     | 60.0               |
| Dragon   |        | 3     | 66.66666666666667  |
| Dragon   | Flying | 1     | 80.0               |
| Electric |        | 9     | 98.88888888888889  |
+----------+--------+-------+--------------------+
Data truncated.

Cube

cube() produces every possible subset. cube(a, b) produces grouping sets (a, b), (a), (b), and () — one more than rollup because it also includes (b) alone. This is useful when neither column is “above” the other in a hierarchy and you want all cross-tabulations.

For our Pokemon data, cube(Type 1, Type 2) gives us stats broken down by the type pair, by Type 1 alone, by Type 2 alone, and a grand total — all in one query:

In [27]: df.aggregate(
   ....:     [GroupingSet.cube(col_type_1, col_type_2)],
   ....:     [f.count(col_speed).alias("Count"),
   ....:      f.avg(col_speed).alias("Avg Speed")]
   ....: ).sort(
   ....:     col_type_1.sort(ascending=True, nulls_first=True),
   ....:     col_type_2.sort(ascending=True, nulls_first=True)
   ....: )
   ....: 
Out[27]: 
DataFrame()
+--------+----------+-------+--------------------+
| Type 1 | Type 2   | Count | Avg Speed          |
+--------+----------+-------+--------------------+
|        |          | 163   | 71.65030674846626  |
|        |          | 86    | 72.46511627906976  |
|        | Dark     | 1     | 81.0               |
|        | Dragon   | 1     | 100.0              |
|        | Fairy    | 3     | 51.666666666666664 |
|        | Fighting | 1     | 70.0               |
|        | Flying   | 23    | 91.08695652173913  |
|        | Grass    | 2     | 27.5               |
|        | Ground   | 6     | 55.166666666666664 |
|        | Ice      | 3     | 66.66666666666667  |
+--------+----------+-------+--------------------+
Data truncated.

Compared to the rollup example above, notice the extra rows where Type 1 is null but Type 2 has a value — those are the per-Type 2 subtotals that rollup does not include.

Explicit Grouping Sets

grouping_sets() lets you list exactly which grouping levels you need when rollup or cube would produce too many or too few. Each argument is a list of columns forming one grouping set.

For example, if we want only the per-Type 1 totals and per-Type 2 totals — but not the full (Type 1, Type 2) detail rows or the grand total — we can ask for exactly that:

In [28]: df.aggregate(
   ....:     [GroupingSet.grouping_sets([col_type_1], [col_type_2])],
   ....:     [f.count(col_speed).alias("Count"),
   ....:      f.avg(col_speed).alias("Avg Speed")]
   ....: ).sort(
   ....:     col_type_1.sort(ascending=True, nulls_first=True),
   ....:     col_type_2.sort(ascending=True, nulls_first=True)
   ....: )
   ....: 
Out[28]: 
DataFrame()
+--------+----------+-------+--------------------+
| Type 1 | Type 2   | Count | Avg Speed          |
+--------+----------+-------+--------------------+
|        |          | 86    | 72.46511627906976  |
|        | Dark     | 1     | 81.0               |
|        | Dragon   | 1     | 100.0              |
|        | Fairy    | 3     | 51.666666666666664 |
|        | Fighting | 1     | 70.0               |
|        | Flying   | 23    | 91.08695652173913  |
|        | Grass    | 2     | 27.5               |
|        | Ground   | 6     | 55.166666666666664 |
|        | Ice      | 3     | 66.66666666666667  |
|        | Poison   | 22    | 71.5909090909091   |
+--------+----------+-------+--------------------+
Data truncated.

Each row belongs to exactly one grouping level. The grouping() function tells you which level each row comes from:

In [29]: result = df.aggregate(
   ....:     [GroupingSet.grouping_sets([col_type_1], [col_type_2])],
   ....:     [f.count(col_speed).alias("Count"),
   ....:      f.avg(col_speed).alias("Avg Speed"),
   ....:      f.grouping(col_type_1),
   ....:      f.grouping(col_type_2)]
   ....: )
   ....: 

In [30]: for field in result.schema():
   ....:     if field.name.startswith("grouping("):
   ....:         clean = field.name.split(".")[-1].rstrip(")")
   ....:         result = result.with_column_renamed(field.name, f"grouping({clean})")
   ....: 

In [31]: result.sort(
   ....:     col_type_1.sort(ascending=True, nulls_first=True),
   ....:     col_type_2.sort(ascending=True, nulls_first=True)
   ....: )
   ....: 
Out[31]: 
DataFrame()
+--------+----------+-------+--------------------+------------------+------------------+
| Type 1 | Type 2   | Count | Avg Speed          | grouping(Type 1) | grouping(Type 2) |
+--------+----------+-------+--------------------+------------------+------------------+
|        |          | 86    | 72.46511627906976  | 1                | 0                |
|        | Dark     | 1     | 81.0               | 1                | 0                |
|        | Dragon   | 1     | 100.0              | 1                | 0                |
|        | Fairy    | 3     | 51.666666666666664 | 1                | 0                |
|        | Fighting | 1     | 70.0               | 1                | 0                |
|        | Flying   | 23    | 91.08695652173913  | 1                | 0                |
|        | Grass    | 2     | 27.5               | 1                | 0                |
|        | Ground   | 6     | 55.166666666666664 | 1                | 0                |
|        | Ice      | 3     | 66.66666666666667  | 1                | 0                |
|        | Poison   | 22    | 71.5909090909091   | 1                | 0                |
+--------+----------+-------+--------------------+------------------+------------------+
Data truncated.

Where grouping(Type 1) is 0 the row is a per-Type 1 total (and Type 2 is null). Where grouping(Type 2) is 0 the row is a per-Type 2 total (and Type 1 is null).

Aggregate Functions

The available aggregate functions are:

  1. Comparison Functions
  2. Math Functions
  3. Array Functions
  4. Logical Functions
  5. Statistical Functions
  6. Linear Regression Functions
  7. Positional Functions
  8. String Functions
  9. Percentile Functions
  10. Grouping Set Functions
      • datafusion.functions.grouping()
      • datafusion.expr.GroupingSet.rollup()
      • datafusion.expr.GroupingSet.cube()
      • datafusion.expr.GroupingSet.grouping_sets()