Skip to content

Commit 9b2c3da

Browse files
authored
Extension bindings (apache#266)
* Introduce `to_variant` trait function to `LogicalNode` and create Explain `LogicalNode` bindings
* Cargo fmt
* Bindings for Extension `LogicalNode`
* Add missing classes to list of exports so `test_imports` will pass
* Update to point to proper repo
* Update pytest to adhere to aggregate calls being wrapped in projections
* Address linter change which causes a pytest to fail
1 parent 2172d3f commit 9b2c3da

9 files changed

Lines changed: 228 additions & 134 deletions

File tree

Cargo.lock

Lines changed: 151 additions & 117 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -34,12 +34,12 @@ default = ["mimalloc"]
3434
tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync"] }
3535
rand = "0.8"
3636
pyo3 = { version = "0.18.0", features = ["extension-module", "abi3", "abi3-py37"] }
37-
datafusion = { version = "19.0.0", features = ["pyarrow", "avro"] }
38-
datafusion-expr = "19.0.0"
39-
datafusion-optimizer = "19.0.0"
40-
datafusion-common = { version = "19.0.0", features = ["pyarrow"] }
41-
datafusion-sql = "19.0.0"
42-
datafusion-substrait = "19.0.0"
37+
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab", features = ["pyarrow", "avro"]}
38+
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
39+
datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
40+
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab", features = ["pyarrow"]}
41+
datafusion-sql = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
42+
datafusion-substrait = { git = "https://github.com/apache/arrow-datafusion.git", rev = "dd98aab" }
4343
uuid = { version = "1.2", features = ["v4"] }
4444
mimalloc = { version = "*", optional = true, default-features = false }
4545
async-trait = "0.1"

datafusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,7 @@
7777
TryCast,
7878
Between,
7979
Explain,
80+
Extension,
8081
)
8182

8283
__version__ = importlib_metadata.version(__name__)
@@ -129,6 +130,7 @@
129130
"TryCast",
130131
"Between",
131132
"Explain",
133+
"Extension",
132134
]
133135

134136

datafusion/tests/test_dataframe.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -350,14 +350,13 @@ def test_logical_plan(aggregate_df):
350350
def test_optimized_logical_plan(aggregate_df):
351351
plan = aggregate_df.optimized_logical_plan()
352352

353-
expected = "Projection: test.c1, SUM(test.c2)"
353+
expected = "Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]"
354354

355355
assert expected == plan.display()
356356

357357
expected = (
358-
"Projection: test.c1, SUM(test.c2)\n"
359-
" Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n"
360-
" TableScan: test projection=[c1, c2]"
358+
"Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n"
359+
" TableScan: test projection=[c1, c2]"
361360
)
362361

363362
assert expected == plan.display_indent()
@@ -366,9 +365,7 @@ def test_optimized_logical_plan(aggregate_df):
366365
def test_execution_plan(aggregate_df):
367366
plan = aggregate_df.execution_plan()
368367

369-
expected = (
370-
"ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n"
371-
)
368+
expected = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[SUM(test.c2)]\n" # noqa: E501
372369

373370
assert expected == plan.display()
374371

@@ -382,7 +379,6 @@ def test_execution_plan(aggregate_df):
382379

383380
# indent plan will be different for everyone due to absolute path
384381
# to filename, so we just check for some expected content
385-
assert "ProjectionExec:" in indent
386382
assert "AggregateExec:" in indent
387383
assert "CoalesceBatchesExec:" in indent
388384
assert "RepartitionExec:" in indent

datafusion/tests/test_imports.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,8 @@
7777
Cast,
7878
TryCast,
7979
Between,
80+
Explain,
81+
Extension,
8082
)
8183

8284

@@ -143,6 +145,8 @@ def test_class_module_is_datafusion():
143145
Cast,
144146
TryCast,
145147
Between,
148+
Explain,
149+
Extension,
146150
]:
147151
assert klass.__module__ == "datafusion.expr"
148152

src/context.rs

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ use datafusion::physical_plan::SendableRecordBatchStream;
4949
use datafusion::prelude::{
5050
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
5151
};
52+
use datafusion_common::config::Extensions;
5253
use datafusion_common::ScalarValue;
5354
use pyo3::types::PyTuple;
5455
use tokio::runtime::Runtime;
@@ -698,19 +699,20 @@ impl PySessionContext {
698699
part: usize,
699700
py: Python,
700701
) -> PyResult<PyRecordBatchStream> {
701-
let ctx = Arc::new(TaskContext::new(
702+
let ctx = TaskContext::try_new(
702703
"task_id".to_string(),
703704
"session_id".to_string(),
704705
HashMap::new(),
705706
HashMap::new(),
706707
HashMap::new(),
707708
Arc::new(RuntimeEnv::default()),
708-
));
709+
Extensions::default(),
710+
);
709711
// create a Tokio runtime to run the async code
710712
let rt = Runtime::new().unwrap();
711713
let plan = plan.plan.clone();
712714
let fut: JoinHandle<datafusion_common::Result<SendableRecordBatchStream>> =
713-
rt.spawn(async move { plan.execute(part, ctx) });
715+
rt.spawn(async move { plan.execute(part, Arc::new(ctx?)) });
714716
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
715717
Ok(PyRecordBatchStream::new(stream?))
716718
}

src/expr.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@ pub mod cross_join;
5151
pub mod empty_relation;
5252
pub mod exists;
5353
pub mod explain;
54+
pub mod extension;
5455
pub mod filter;
5556
pub mod grouping_set;
5657
pub mod in_list;
@@ -272,6 +273,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
272273
m.add_class::<join::PyJoinConstraint>()?;
273274
m.add_class::<cross_join::PyCrossJoin>()?;
274275
m.add_class::<union::PyUnion>()?;
276+
m.add_class::<extension::PyExtension>()?;
275277
m.add_class::<filter::PyFilter>()?;
276278
m.add_class::<projection::PyProjection>()?;
277279
m.add_class::<table_scan::PyTableScan>()?;

src/expr/extension.rs

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,52 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_expr::Extension;
19+
use pyo3::prelude::*;
20+
21+
use crate::sql::logical::PyLogicalPlan;
22+
23+
use super::logical_node::LogicalNode;
24+
25+
#[pyclass(name = "Extension", module = "datafusion.expr", subclass)]
26+
#[derive(Clone)]
27+
pub struct PyExtension {
28+
pub node: Extension,
29+
}
30+
31+
impl From<Extension> for PyExtension {
32+
fn from(node: Extension) -> PyExtension {
33+
PyExtension { node }
34+
}
35+
}
36+
37+
#[pymethods]
38+
impl PyExtension {
39+
fn name(&self) -> PyResult<String> {
40+
Ok(self.node.node.name().to_string())
41+
}
42+
}
43+
44+
impl LogicalNode for PyExtension {
45+
fn inputs(&self) -> Vec<PyLogicalPlan> {
46+
vec![]
47+
}
48+
49+
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
50+
Ok(self.clone().into_py(py))
51+
}
52+
}

src/sql/logical.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ use crate::expr::aggregate::PyAggregate;
2222
use crate::expr::analyze::PyAnalyze;
2323
use crate::expr::empty_relation::PyEmptyRelation;
2424
use crate::expr::explain::PyExplain;
25+
use crate::expr::extension::PyExtension;
2526
use crate::expr::filter::PyFilter;
2627
use crate::expr::limit::PyLimit;
2728
use crate::expr::projection::PyProjection;
@@ -60,6 +61,7 @@ impl PyLogicalPlan {
6061
LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
6162
LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
6263
LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
64+
LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
6365
LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
6466
LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
6567
LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),

0 commit comments

Comments
 (0)