Skip to content

Commit afd1cd5

Browse files
committed
pushdown scalar function
Signed-off-by: Mikhail Kot <mikhail@spiraldb.com>
1 parent a21bf76 commit afd1cd5

23 files changed

Lines changed: 727 additions & 44 deletions

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-duckdb/build.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ const DEFAULT_DUCKDB_VERSION: &str = "1.5.3";
2727

2828
const BUILD_ARTIFACTS: [&str; 3] = ["libduckdb.dylib", "libduckdb.so", "libduckdb_static.a"];
2929

30-
const SOURCE_FILES: [&str; 7] = [
30+
const SOURCE_FILES: [&str; 8] = [
3131
"cpp/vortex_duckdb.cpp",
3232
"cpp/copy_function.cpp",
3333
"cpp/expr.cpp",
3434
"cpp/scalar_fn_pushdown.cpp",
35+
"cpp/aggregate_fn_pushdown.cpp",
3536
"cpp/table_filter.cpp",
3637
"cpp/table_function.cpp",
3738
"cpp/vector.cpp",
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
#include "aggregate_fn_pushdown.hpp"
4+
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
5+
#include "duckdb/planner/operator/logical_aggregate.hpp"
6+
#include "scalar_fn_pushdown.hpp"
7+
#include "table_function.hpp"
8+
9+
using enum LogicalOperatorType;
10+
11+
// Optimizer entry point for aggregate pushdown.
//
// Collects GETs and projections from `plan` via FindGetsAndProjections() and,
// when at least one analysis was produced, hands the plan to
// RewriteAggregates(). Otherwise the plan is returned unchanged.
LogicalOperatorPtr TryPushdownAggregateFunctions(ClientContext &context, LogicalOperatorPtr plan) {
    Analyses analyses;
    Projections projections;
    FindGetsAndProjections(*plan, analyses, projections);

    if (!analyses.empty()) {
        return RewriteAggregates(context, std::move(plan), analyses, projections);
    }
    // Nothing was collected, so there is nothing to push into.
    return plan;
}
20+
21+
// Recursively rewrite the operator tree rooted at `op`.
//
// Children are rewritten first and replaced in place; afterwards, if `op`
// itself is a LOGICAL_AGGREGATE_AND_GROUP_BY node, it is offered to
// TryReplaceAggregate(). Any other operator is returned as-is.
LogicalOperatorPtr RewriteAggregates(ClientContext &context,
                                     LogicalOperatorPtr op,
                                     Analyses &analyses,
                                     const Projections &projections) {
    for (auto &subtree : op->children) {
        subtree = RewriteAggregates(context, std::move(subtree), analyses, projections);
    }

    if (op->type != LOGICAL_AGGREGATE_AND_GROUP_BY) {
        return op;
    }
    return TryReplaceAggregate(context, std::move(op), analyses, projections);
}
33+
34+
static bool IsUngrouped(const LogicalAggregate &agg) {
35+
return agg.groups.empty() && agg.grouping_sets.empty() && agg.grouping_functions.empty() &&
36+
!agg.expressions.empty();
37+
}
38+
39+
// Sentinel used in place of a column storage index for count_star()
// aggregates, which have no input column (see TryReplaceAggregate).
constexpr inline idx_t COUNT_STAR_PROJ_IDX = std::numeric_limits<TableColumnStorageIndex>::max();
40+
41+
// Try to fold an ungrouped aggregate into the Vortex GET beneath it.
//
// `op` is returned untouched unless all of the following hold:
//   * the aggregate is ungrouped and GetChildGet() finds a GET below it;
//   * every expression is a bound aggregate without DISTINCT, FILTER or
//     ORDER BY;
//   * every expression is either count_star() or takes exactly one argument
//     that is a column reference resolving to that same GET;
//   * aggregate_pushdown() accepts the whole set.
// On success the GET is rewritten to expose one output column per aggregate
// and is returned in place of the aggregate operator.
LogicalOperatorPtr TryReplaceAggregate(ClientContext &context,
                                       LogicalOperatorPtr op,
                                       Analyses &analyses,
                                       const Projections &projections) {
    LogicalAggregate &aggregate = op->Cast<LogicalAggregate>();
    if (!IsUngrouped(aggregate)) {
        return op;
    }

    LogicalGet *const get = GetChildGet(aggregate);
    if (get == nullptr) {
        return op;
    }

    // (storage index, aggregate expression) for each aggregate, in order.
    vector<std::pair<TableColumnStorageIndex, const Expression &>> pushed;
    const idx_t aggregate_count = aggregate.expressions.size();
    pushed.reserve(aggregate_count);

    for (const auto &aggregate_expr : aggregate.expressions) {
        if (aggregate_expr->GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE) {
            return op;
        }
        const auto &bound = aggregate_expr->Cast<BoundAggregateExpression>();
        const bool has_modifiers =
            bound.IsDistinct() || bound.filter != nullptr || bound.order_bys != nullptr;
        if (has_modifiers) {
            return op;
        }

        // count_star() has no input column; mark it with the sentinel index.
        if (bound.function.name == "count_star") {
            pushed.emplace_back(COUNT_STAR_PROJ_IDX, *aggregate_expr);
            continue;
        }

        const bool single_column_argument =
            bound.children.size() == 1 &&
            bound.children[0]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF;
        if (!single_column_argument) {
            return op;
        }
        const auto &column_ref = bound.children[0]->Cast<BoundColumnRefExpression>();
        const auto resolved = Resolve(column_ref.binding, analyses, projections);
        // The column must come from the very GET we are about to rewrite.
        if (!resolved || &resolved->analysis.get != get) {
            return op;
        }
        const TableColumnStorageIndex storage_index =
            resolved->analysis.StorageIndex(resolved->column_index);
        pushed.emplace_back(storage_index, *aggregate_expr);
    }

    if (!aggregate_pushdown(context, {*get, pushed})) {
        return op;
    }

    // From here on the GET produces one column per aggregate. A source column
    // used by several aggregates therefore appears once per aggregate.
    auto &column_ids = get->GetMutableColumnIds();
    get->types.resize(aggregate_count);
    get->returned_types.resize(aggregate_count);
    column_ids.resize(aggregate_count);

    // Built separately because the loop still reads the original get->names.
    vector<string> output_names(aggregate_count);

    for (idx_t i = 0; i < aggregate_count; i++) {
        const auto &entry = pushed[i];
        const auto &storage_index = entry.first;
        const Expression &aggregate_expr = entry.second;

        if (storage_index == COUNT_STAR_PROJ_IDX) {
            output_names[i] = "count_star()";
        } else {
            output_names[i] = get->names[storage_index];
        }
        get->types[i] = aggregate_expr.return_type;
        get->returned_types[i] = aggregate_expr.return_type;
        column_ids[i] = ColumnIndex {i};
    }
    get->names = std::move(output_names);
    get->projection_ids.clear();
    // The GET takes over the aggregate's table index.
    // NOTE(review): presumably so that parent references to the aggregate's
    // outputs keep resolving after the replacement — confirm.
    get->table_index = aggregate.aggregate_index;

    unique_ptr<LogicalOperator> &child = aggregate.children[0];
    if (child->type != LOGICAL_GET) {
        // GetChildGet() only accepts GET or PROJECTION -> GET below the aggregate.
        D_ASSERT(child->type == LOGICAL_PROJECTION);
        D_ASSERT(child->children.size() == 1);
        D_ASSERT(child->children[0]->type == LOGICAL_GET);
        return std::move(child->children[0]);
    }
    return std::move(child);
}
119+
120+
// Locate the Vortex GET directly below `agg`, optionally behind a single
// PROJECTION. Returns nullptr when the aggregate does not have exactly one
// child, when the shape is anything else, or when the GET's table function was
// not bound by duckdb_vx_table_function_bind.
LogicalGet *GetChildGet(const LogicalAggregate &agg) {
    if (agg.children.size() != 1) {
        return nullptr;
    }

    LogicalOperator *candidate = agg.children[0].get();
    // Step through one single-child projection, if present.
    if (candidate->type == LOGICAL_PROJECTION && candidate->children.size() == 1) {
        candidate = candidate->children[0].get();
    }
    if (candidate->type != LOGICAL_GET) {
        return nullptr;
    }

    LogicalGet &get = candidate->Cast<LogicalGet>();
    if (get.function.bind != duckdb_vx_table_function_bind) {
        return nullptr;
    }
    return &get;
}

vortex-duckdb/cpp/expr.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33

44
#include "expr.h"
55
#include "duckdb/function/scalar_function.hpp"
6+
#include "duckdb/function/aggregate_function.hpp"
67
#include "duckdb/planner/expression/bound_between_expression.hpp"
78
#include "duckdb/planner/expression/bound_columnref_expression.hpp"
89
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
910
#include "duckdb/planner/expression/bound_constant_expression.hpp"
11+
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
1012
#include "duckdb/planner/expression/bound_function_expression.hpp"
1113
#include "duckdb/planner/expression/bound_operator_expression.hpp"
1214
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
@@ -21,6 +23,11 @@ extern "C" const char *duckdb_vx_sfunc_name(duckdb_vx_sfunc ffi_func) {
2123
return func->name.c_str();
2224
}
2325

26+
// Return the name of the aggregate function behind the opaque handle `ffi`.
// The pointer refers to the function object's own `name` string (no copy is
// made), so the caller must not free it.
extern "C" const char *duckdb_vx_agg_func_name(duckdb_vx_agg_func ffi) {
    D_ASSERT(ffi);
    const auto *function = reinterpret_cast<AggregateFunction *>(ffi);
    return function->name.c_str();
}
30+
2431
extern "C" const char *duckdb_vx_expr_to_string(duckdb_vx_expr ffi_expr) {
2532
if (!ffi_expr) {
2633
return nullptr;
@@ -129,3 +136,9 @@ extern "C" void duckdb_vx_expr_get_bound_function(duckdb_vx_expr ffi_expr,
129136
out->scalar_function = reinterpret_cast<duckdb_vx_sfunc>(&expr.function);
130137
out->bind_info = expr.bind_info.get();
131138
}
139+
140+
// Return an opaque handle to the aggregate function of a bound aggregate
// expression. `ffi_expr` must be non-null and is cast to
// BoundAggregateExpression; the handle points at the expression's `function`
// member rather than at a copy.
extern "C" duckdb_vx_agg_func duckdb_vx_expr_get_bound_aggregate_function(duckdb_vx_expr ffi_expr) {
    D_ASSERT(ffi_expr);
    auto *expression = reinterpret_cast<Expression *>(ffi_expr);
    auto &aggregate = expression->Cast<BoundAggregateExpression>();
    return reinterpret_cast<duckdb_vx_agg_func>(&aggregate.function);
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
#pragma once
4+
#include "scalar_fn_pushdown.hpp"
5+
#include "duckdb/optimizer/optimizer_extension.hpp"
6+
7+
using namespace duckdb;
8+
9+
// Push UNGROUPED_AGGREGATEs of the form agg(T) and count_star() into GET.
10+
LogicalOperatorPtr TryPushdownAggregateFunctions(ClientContext &context, LogicalOperatorPtr plan);
11+
12+
LogicalOperatorPtr RewriteAggregates(ClientContext &context,
13+
LogicalOperatorPtr op,
14+
Analyses &analyses,
15+
const Projections &projections);
16+
17+
LogicalOperatorPtr TryReplaceAggregate(ClientContext &context,
18+
LogicalOperatorPtr op,
19+
Analyses &analyses,
20+
const Projections &projections);
21+
22+
// return GET for UNGROUPED_AGGREGATE -> [GET] or for UNGROUPED_AGGREGATE ->
23+
// PROJECTION -> [GET], nullptr if not found.
24+
LogicalGet *GetChildGet(const LogicalAggregate &agg);

vortex-duckdb/cpp/include/expr.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ extern "C" {
1010
#endif
1111

1212
typedef struct duckdb_vx_sfunc_ *duckdb_vx_sfunc;
13+
typedef struct duckdb_vx_agg_func_ *duckdb_vx_agg_func;
1314

1415
const char *duckdb_vx_sfunc_name(duckdb_vx_sfunc ffi_func);
1516

1617
typedef struct duckdb_vx_expr_ *duckdb_vx_expr;
1718

19+
const char *duckdb_vx_agg_func_name(duckdb_vx_agg_func func);
20+
1821
/// Return the string representation of the expression. Must be freed with `duckdb_free`.
1922
const char *duckdb_vx_expr_to_string(duckdb_vx_expr expr);
2023

@@ -264,6 +267,8 @@ typedef struct {
264267

265268
void duckdb_vx_expr_get_bound_function(duckdb_vx_expr expr, duckdb_vx_expr_bound_function *out);
266269

270+
duckdb_vx_agg_func duckdb_vx_expr_get_bound_aggregate_function(duckdb_vx_expr expr);
271+
267272
#ifdef __cplusplus /* End C ABI */
268273
}
269274
#endif

vortex-duckdb/cpp/include/table_function.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ typedef struct {
8888

8989
duckdb_state duckdb_vx_register_table_functions(duckdb_database ffi_db);
9090

91+
typedef struct duckdb_vx_agg_input_ *duckdb_vx_agg_input;
92+
idx_t duckdb_vx_aggregate_len(duckdb_vx_agg_input ffi);
93+
duckdb_vx_expr duckdb_vx_aggregate_i(duckdb_vx_agg_input ffi, idx_t i, idx_t *proj_idx);
94+
9195
#ifdef __cplusplus
9296
}
9397
#endif

vortex-duckdb/cpp/include/table_function.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,12 @@ struct TableFunctionProjectionExpressionInput {
2525
// true if we can push down the expression, false otherwise
2626
bool projection_expression_pushdown(duckdb::ClientContext &context,
2727
const TableFunctionProjectionExpressionInput &input);
28+
29+
// API is subject to change. Aggregates are pushed down all-or-nothing, since
30+
// pushing them changes the output chunk cardinality.
31+
struct TableFunctionUngroupedAggregateInput {
32+
const duckdb::LogicalGet &get;
33+
const duckdb::vector<std::pair<idx_t, const duckdb::Expression &>> &projections;
34+
};
35+
36+
bool aggregate_pushdown(duckdb::ClientContext &context, const TableFunctionUngroupedAggregateInput &input);

vortex-duckdb/cpp/table_function.cpp

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "data.hpp"
55
#include "error.hpp"
66
#include "table_function.hpp"
7+
#include "expr.h"
78
#include "vortex_duckdb.h"
89
#include "table_function.h"
910
#include "vortex.h"
@@ -171,12 +172,16 @@ struct CTableBindResult {
171172
vector<string> &names;
172173
};
173174

175+
// This is a flaw of Duckdb API which doesn't allow passing non-const
176+
// expressions. We never modify the value on Rust side.
177+
static duckdb_vx_expr get_ffi_expr(const Expression &expr) {
178+
return reinterpret_cast<duckdb_vx_expr>(const_cast<Expression *>(&expr));
179+
}
180+
174181
bool projection_expression_pushdown(ClientContext &, const TableFunctionProjectionExpressionInput &input) {
175182
const auto &bind = input.get.bind_data->Cast<CTableBindData>();
176183

177-
// This is a flaw of Duckdb API which doesn't allow passing non-const
178-
// expressions. We never modify the value on Rust side.
179-
auto ffi_expr = reinterpret_cast<duckdb_vx_expr>(const_cast<Expression *>(&input.expression));
184+
duckdb_vx_expr ffi_expr = get_ffi_expr(input.expression);
180185
void *const ffi_bind = bind.ffi_data->DataPtr();
181186
duckdb_vx_error error_out = nullptr;
182187

@@ -191,6 +196,33 @@ bool projection_expression_pushdown(ClientContext &, const TableFunctionProjecti
191196
return ret;
192197
}
193198

199+
using Projections = vector<std::pair<idx_t, const Expression &>>;
200+
201+
extern "C" {
202+
// Number of (projection index, aggregate expression) entries behind the
// opaque handle `ffi`. `ffi` must be non-null; it is the address of the
// Projections vector passed by aggregate_pushdown().
idx_t duckdb_vx_aggregate_len(duckdb_vx_agg_input ffi) {
    // Same precondition check as the other FFI accessors use for their handles.
    D_ASSERT(ffi);
    return reinterpret_cast<const Projections *>(ffi)->size();
}
205+
206+
// Fetch entry `i` behind the opaque handle `ffi`.
//
// Writes the entry's projection index to `*proj_idx` and returns its
// aggregate expression as an FFI handle. Preconditions: `ffi` and `proj_idx`
// are non-null and `i` is less than duckdb_vx_aggregate_len(ffi).
duckdb_vx_expr duckdb_vx_aggregate_i(duckdb_vx_agg_input ffi, idx_t i, idx_t *proj_idx) {
    // Catch caller mistakes at the FFI boundary before dereferencing anything.
    D_ASSERT(ffi);
    D_ASSERT(proj_idx);
    const Projections &projections = *reinterpret_cast<const Projections *>(ffi);
    D_ASSERT(i < projections.size());

    const auto &entry = projections[i];
    *proj_idx = entry.first;
    return get_ffi_expr(entry.second);
}
211+
}
212+
213+
// Offer the full set of ungrouped aggregates in `input` to the table
// function's bind data through the FFI.
//
// Returns the boolean reported by
// duckdb_table_function_pushdown_projection_aggregates(); throws
// BinderException when that call sets an error.
bool aggregate_pushdown(ClientContext &, const TableFunctionUngroupedAggregateInput &input) {
    const auto &bind_data = input.get.bind_data->Cast<CTableBindData>();
    void *const bind_handle = bind_data.ffi_data->DataPtr();

    // The FFI handle type is a non-const pointer, so constness is cast away.
    auto *const mutable_projections = const_cast<Projections *>(&input.projections);
    const auto agg_handle = reinterpret_cast<duckdb_vx_agg_input>(mutable_projections);

    duckdb_vx_error error_out = nullptr;
    const bool pushed =
        duckdb_table_function_pushdown_projection_aggregates(bind_handle, agg_handle, &error_out);
    if (error_out != nullptr) {
        throw BinderException(IntoErrString(error_out));
    }
    return pushed;
}
225+
194226
/**
195227
* Called for every new query. For example, if there is a VIEW over *.vortex,
196228
* and after a query another file is added matching the glob, for second query
@@ -238,10 +270,11 @@ unique_ptr<GlobalTableFunctionState> c_init_global(ClientContext &context, Table
238270
}
239271

240272
unique_ptr<LocalTableFunctionState>
241-
init_local(ExecutionContext &, TableFunctionInitInput &, GlobalTableFunctionState *global_state) {
273+
init_local(ExecutionContext &, TableFunctionInitInput &input, GlobalTableFunctionState *global_state) {
274+
const void *const ffi_bind = input.bind_data->Cast<CTableBindData>().ffi_data->DataPtr();
242275
void *const ffi_global = global_state->Cast<CTableGlobalData>().ffi_data->DataPtr();
243276

244-
duckdb_vx_data ffi_local_data = duckdb_table_function_init_local(ffi_global);
277+
duckdb_vx_data ffi_local_data = duckdb_table_function_init_local(ffi_bind, ffi_global);
245278
auto cdata = unique_ptr<CData>(reinterpret_cast<CData *>(ffi_local_data));
246279
return make_uniq<CTableLocalData>(std::move(cdata));
247280
}

vortex-duckdb/cpp/vortex_duckdb.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
#include "aggregate_fn_pushdown.hpp"
45
#include "data.hpp"
56
#include "error.hpp"
67
#include "scalar_fn_pushdown.hpp"
@@ -269,6 +270,7 @@ extern "C" duckdb_blob duckdb_vx_value_get_geometry(duckdb_value value) {
269270

270271
// Optimizer hook: run scalar-function pushdown first, then aggregate pushdown
// on its result, and store the final plan back into `plan`.
static void VortexOptimizeFunction(OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
    auto after_scalar = TryPushdownScalarFunctions(input.context, std::move(plan));
    plan = TryPushdownAggregateFunctions(input.context, std::move(after_scalar));
}
273275

274276
struct VortexOptimizerExtension final : OptimizerExtension {

0 commit comments

Comments
 (0)