From 061f3ab74ad14edd9554280cb04109faeab2c5a2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 May 2026 14:17:59 -0400 Subject: [PATCH 1/2] feat: pickle support for Expr via inline scalar UDF encoding Adds Python-aware encoding to PythonLogicalCodec/PythonPhysicalCodec so a ScalarUDF defined in Python travels inside the serialized expression (cloudpickled into fun_definition) instead of needing a matching registration on the receiver. With that in place, Expr gains __reduce__ + classmethod from_bytes(buf, ctx=None) so pickle.dumps / pickle.loads work end-to-end on expressions built from col, lit, built-in functions, and Python scalar UDFs. Wire format is framed as `<family magic><version byte><cloudpickle payload>`; the version byte lets a too-new/too-old payload surface a clean Execution error instead of an opaque cloudpickle unpack failure. Schema serde is via arrow-rs's native IPC (no pyarrow round-trip). Cloudpickle module handle is cached per-interpreter through PyOnceLock. Worker-side context resolution lives in a new datafusion.ipc module: set_worker_ctx / get_worker_ctx / clear_worker_ctx plus a private _resolve_ctx helper consulted by Expr.from_bytes. Priority is explicit ctx > worker ctx > global SessionContext. FFI UDFs still travel by name and require the matching registration on the receiver's context. Aggregate and window UDF inline encoding, the per-session with_python_udf_inlining toggle, sender-side context, and the user-guide docs land in follow-on PRs. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/codec.rs | 428 +++++++++++++++++++++++++++--- crates/core/src/udf.rs | 82 +++++- pyproject.toml | 7 + python/datafusion/__init__.py | 3 +- python/datafusion/expr.py | 58 +++- python/datafusion/ipc.py | 113 ++++++++ python/datafusion/user_defined.py | 10 + python/tests/test_expr.py | 4 +- python/tests/test_pickle_expr.py | 127 +++++++++ uv.lock | 13 +- 10 files changed, 788 insertions(+), 57 deletions(-) create mode 100644 python/datafusion/ipc.py create mode 100644 python/tests/test_pickle_expr.py diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index 088532df2..c95d8cb19 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -19,11 +19,11 @@ //! //! Datafusion-python plans can carry references to Python-defined //! objects that the upstream protobuf codecs do not know how to -//! serialize: pure-Python scalar / aggregate / window UDFs, Python -//! query-planning extensions, and so on. Their state lives inside -//! `Py` callables and closures rather than being recoverable -//! from a name in the receiver's function registry. To ship a plan -//! across a process boundary (pickle, `multiprocessing`, Ray actor, +//! serialize: pure-Python scalar UDFs, Python query-planning +//! extensions, and so on. Their state lives inside `Py` +//! callables and closures rather than being recoverable from a name +//! in the receiver's function registry. To ship a plan across a +//! process boundary (pickle, `multiprocessing`, Ray actor, //! `datafusion-distributed`, etc.) those payloads have to be encoded //! into the proto wire format itself. //! @@ -48,52 +48,121 @@ //! plans to survive a serialization round-trip. Both codecs share //! the same payload framing for that reason. //! -//! Payloads emitted by these codecs are tagged with an 8-byte magic -//! prefix so the decoder can distinguish them from arbitrary bytes -//! 
(empty `fun_definition` from the default codec, user FFI payloads -//! that picked a non-colliding prefix). Dispatch precedence on -//! decode: **Python-inline payload (magic prefix match) → `inner` -//! codec → caller's `FunctionRegistry` fallback.** +//! Payloads emitted by these codecs are framed as +//! ` `. The +//! family magic identifies the UDF flavor; the version byte lets the +//! decoder reject too-new or too-old payloads with a clean error +//! instead of falling into an opaque `cloudpickle` tuple-unpack +//! failure when the tuple shape changes. Dispatch precedence on +//! decode: **family match + supported version → `inner` codec → +//! caller's `FunctionRegistry` fallback.** //! -//! ## Wire-format magic prefix registry +//! ## Wire-format family registry //! -//! | Layer + kind | Magic prefix | -//! | ----------------------------- | ------------ | -//! | `PythonLogicalCodec` scalar | `DFPYUDF1` | -//! | `PythonLogicalCodec` agg | `DFPYUDA1` | -//! | `PythonLogicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` scalar | `DFPYUDF1` | -//! | `PythonPhysicalCodec` agg | `DFPYUDA1` | -//! | `PythonPhysicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` expr | `DFPYPE1` | -//! | User FFI extension codec | user-chosen | -//! | Default codec | (none) | +//! | Layer + kind | Family prefix | +//! | ----------------------------- | ------------- | +//! | `PythonLogicalCodec` scalar | `DFPYUDF` | +//! | `PythonPhysicalCodec` scalar | `DFPYUDF` | +//! | User FFI extension codec | user-chosen | +//! | Default codec | (none) | //! -//! Downstream FFI codecs should pick non-colliding prefixes (use a -//! `DF` namespace plus a crate-specific suffix). The codec +//! Aggregate and window UDF families are reserved for follow-on work. +//! +//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported +//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. +//! 
Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape +//! changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when dropping support +//! for an older shape. +//! +//! Downstream FFI codecs should pick non-colliding family prefixes +//! (use a `DF` namespace plus a crate-specific suffix). The codec //! implementations in this module currently delegate every method to //! `inner`; the encoder/decoder hooks for each kind are added as the //! corresponding Python-side type becomes serializable. use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::ipc::reader::StreamReader; +use arrow::ipc::writer::StreamWriter; use datafusion::common::{Result, TableReference}; use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::FileFormatFactory; use datafusion::execution::TaskContext; -use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion::logical_expr::{ + AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, WindowUDF, +}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; +use pyo3::prelude::*; +use pyo3::sync::PyOnceLock; +use pyo3::types::{PyBytes, PyTuple}; -/// Wire-format prefix that tags a `fun_definition` payload as an -/// inlined Python scalar UDF (cloudpickled tuple of name, callable, -/// input schema, return field, volatility). Defined once here so -/// the encoder and decoder cannot drift. -#[allow(dead_code)] -pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1"; +use crate::udf::PythonFunctionScalarUDF; + +// Wire-format framing for inlined Python UDF payloads. +// +// Layout: ` `. 
+// The family magic identifies the UDF flavor; the version byte lets +// the decoder reject too-new or too-old payloads with a clean error +// instead of falling into an opaque `cloudpickle` tuple-unpack failure +// when the tuple shape changes. Bump [`WIRE_VERSION_CURRENT`] whenever +// the tuple shape changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when +// dropping support for an older shape. + +/// Family prefix for an inlined Python scalar UDF +/// (cloudpickled tuple of name, callable, input schema, return field, +/// volatility). +pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF"; + +/// Wire-format version this build emits. +pub(crate) const WIRE_VERSION_CURRENT: u8 = 1; + +/// Oldest wire-format version this build still decodes. Bump when +/// retiring support for an older payload shape. +pub(crate) const WIRE_VERSION_MIN_SUPPORTED: u8 = 1; + +/// Tag `buf` with the framing header for `family` at the current +/// wire-format version. Append-only — the caller writes the +/// cloudpickle payload after. +fn write_wire_header(buf: &mut Vec, family: &[u8]) { + buf.extend_from_slice(family); + buf.push(WIRE_VERSION_CURRENT); +} + +/// Inspect the framing on `buf`. +/// +/// * `Ok(None)` — `buf` does not carry `family`. The caller should +/// delegate to its `inner` codec. +/// * `Ok(Some(payload))` — `buf` carries `family` at a version this +/// build accepts; `payload` is the cloudpickle blob. +/// * `Err(_)` — `buf` carries `family` but at a version outside +/// `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. The error +/// names the version and the supported range so an operator can +/// diagnose sender/receiver version drift instead of seeing an +/// opaque cloudpickle tuple-unpack failure. 
+fn strip_wire_header<'a>(buf: &'a [u8], family: &[u8], kind: &str) -> Result> { + if !buf.starts_with(family) { + return Ok(None); + } + let version_idx = family.len(); + let Some(&version) = buf.get(version_idx) else { + return Err(datafusion::error::DataFusionError::Execution(format!( + "Truncated inline Python {kind} payload: missing wire-format version byte" + ))); + }; + if !(WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT).contains(&version) { + return Err(datafusion::error::DataFusionError::Execution(format!( + "Inline Python {kind} payload wire-format version v{version}; \ + this build supports v{WIRE_VERSION_MIN_SUPPORTED}..=v{WIRE_VERSION_CURRENT}. \ + Align datafusion-python versions on sender and receiver." + ))); + } + Ok(Some(&buf[version_idx + 1..])) +} /// `LogicalExtensionCodec` parked on every `SessionContext`. Holds /// the Python-aware encoding hooks for logical-layer types @@ -177,10 +246,16 @@ impl LogicalExtensionCodec for PythonLogicalCodec { } fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_scalar_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udf(node, buf) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udf) = try_decode_python_scalar_udf(buf)? { + return Ok(udf); + } self.inner.try_decode_udf(name, buf) } @@ -212,7 +287,7 @@ impl LogicalExtensionCodec for PythonLogicalCodec { /// encoding on this layer too — otherwise a plan with a Python UDF /// would round-trip at the logical level but break at the physical /// level. Both layers reuse the shared payload framing -/// ([`PY_SCALAR_UDF_MAGIC`] et al.) so the wire format is identical. +/// ([`PY_SCALAR_UDF_FAMILY`]) so the wire format is identical. 
#[derive(Debug)] pub struct PythonPhysicalCodec { inner: Arc, @@ -249,10 +324,16 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { } fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_scalar_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udf(node, buf) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udf) = try_decode_python_scalar_udf(buf)? { + return Ok(udf); + } self.inner.try_decode_udf(name, buf) } @@ -284,3 +365,282 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { self.inner.try_decode_udwf(name, buf) } } + +// ============================================================================= +// Shared Python scalar UDF encode / decode helpers +// +// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on +// every `try_encode_udf` / `try_decode_udf` call. Same wire format on +// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan` +// or an `ExecutionPlan` round-trips identically. +// ============================================================================= + +/// Encode a Python scalar UDF inline if `node` is one. Returns +/// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte, +/// cloudpickled tuple) was written and the caller should skip its +/// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling +/// the caller to delegate to its `inner`. +pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_scalar_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +/// Decode an inline Python scalar UDF payload. 
Returns `Ok(None)` +/// when `buf` does not carry the `DFPYUDF` family prefix, signalling +/// the caller to delegate to its `inner` codec (and eventually the +/// `FunctionRegistry`). +pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_scalar_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) + }) +} + +/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`. +/// +/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes, +/// return_schema_bytes, volatility_str))`. Schema blobs are produced +/// by arrow-rs's native IPC stream writer (no pyarrow round-trip) and +/// decoded with the matching stream reader on the receiver. See +/// [`build_input_schema_bytes`] for what the input blob carries. +fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult> { + let signature = udf.signature(); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionScalarUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_schema_bytes = build_single_field_schema_bytes(udf.return_field().as_ref())?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + udf.name().into_pyobject(py)?.into_any(), + udf.func().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +/// Inverse of [`encode_python_scalar_udf`]. +fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? 
+ .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let func: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let volatility_str: String = tuple.get_item(4)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_field = read_single_return_field(&return_schema_bytes, "PythonFunctionScalarUDF")?; + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionScalarUDF::from_parts( + name, + func, + input_types, + return_field, + volatility, + )) +} + +/// Serialize a `Schema` to a self-contained IPC stream containing +/// only the schema message (no record batches). Inverse: +/// [`schema_from_ipc_bytes`]. +fn schema_to_ipc_bytes(schema: &Schema) -> arrow::error::Result> { + let mut buf: Vec = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buf, schema)?; + writer.finish()?; + } + Ok(buf) +} + +/// Decode an IPC stream containing only a schema message back into a +/// `Schema`. Inverse: [`schema_to_ipc_bytes`]. +fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result { + let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?; + Ok(reader.schema().as_ref().clone()) +} + +/// Extract the per-arg `DataType`s from a `Signature` known to be +/// `TypeSignature::Exact` (all Python-defined UDFs are constructed +/// with `Signature::exact`). Any other variant indicates the impl was +/// not built by this crate's UDF/UDAF/UDWF constructors. 
+fn signature_input_dtypes(signature: &Signature, kind: &str) -> PyResult> { + match &signature.type_signature { + TypeSignature::Exact(types) => Ok(types.clone()), + other => Err(pyo3::exceptions::PyValueError::new_err(format!( + "{kind} expected Signature::Exact, got {other:?}" + ))), + } +} + +/// Wrap per-arg `DataType`s in synthetic `arg_{i}` fields and emit +/// the IPC schema blob the encoder writes into the cloudpickle tuple. +/// +/// The names and `nullable: true` are arbitrary: the underlying +/// `TypeSignature::Exact` carries no per-input nullability or +/// metadata, and the receiver collapses these fields back to +/// `Vec` via [`read_input_dtypes`], so anything set here +/// beyond the data type is discarded on decode. +fn build_input_schema_bytes(dtypes: &[DataType]) -> PyResult> { + let fields: Vec = dtypes + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("arg_{i}"), dt.clone(), true)) + .collect(); + schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) +} + +/// Emit a single-field IPC schema blob. Used for return-type and +/// state-field payloads where the receiver needs to recover field +/// metadata (names, nullability, key/value attributes) verbatim. +fn build_single_field_schema_bytes(field: &Field) -> PyResult> { + schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err) +} + +/// Decode the per-arg `DataType`s the encoder wrote via +/// [`build_input_schema_bytes`]. +fn read_input_dtypes(bytes: &[u8]) -> PyResult> { + let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; + Ok(schema + .fields() + .iter() + .map(|f| f.data_type().clone()) + .collect()) +} + +/// Decode a single-field IPC schema blob and return that field by +/// value. `kind` names the UDF flavor in the error message produced +/// when the blob is empty (should be unreachable for sender-side +/// payloads built via [`build_single_field_schema_bytes`]). 
+fn read_single_return_field(bytes: &[u8], kind: &str) -> PyResult { + let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; + let field = schema.fields().first().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err(format!( + "{kind} return schema must contain exactly one field" + )) + })?; + Ok(field.as_ref().clone()) +} + +fn arrow_to_py_err(e: arrow::error::ArrowError) -> PyErr { + pyo3::exceptions::PyValueError::new_err(format!("{e}")) +} + +fn parse_volatility_str(s: &str) -> PyResult { + datafusion_python_util::parse_volatility(s) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}"))) +} + +/// Stable wire-format string for a `Volatility`. Pinned to the three +/// tokens [`datafusion_python_util::parse_volatility`] accepts, so an +/// upstream change to `Volatility`'s `Debug` repr cannot silently +/// produce bytes the decoder rejects. +fn volatility_wire_str(v: Volatility) -> &'static str { + match v { + Volatility::Immutable => "immutable", + Volatility::Stable => "stable", + Volatility::Volatile => "volatile", + } +} + +/// Cached handle to the `cloudpickle` module. +/// +/// The encode/decode helpers above would otherwise re-resolve the +/// module on every call. `py.import` is backed by `sys.modules` and +/// therefore cheap, but each call still walks a dict and re-binds the +/// result; a plan with many Python UDFs pays that cost per UDF. +/// +/// `PyOnceLock` scopes the cached `Py` to the current +/// interpreter, so the slot drops cleanly on interpreter teardown +/// (relevant under CPython subinterpreters, PEP 684) instead of +/// resurrecting a `Py` rooted in a dead interpreter on the next call. 
+fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { + static CLOUDPICKLE: PyOnceLock> = PyOnceLock::new(); + CLOUDPICKLE + .get_or_try_init(py, || Ok(py.import("cloudpickle")?.unbind().into_any())) + .map(|cached| cached.bind(py).clone()) +} + +#[cfg(test)] +mod wire_header_tests { + use super::*; + + #[test] + fn strip_returns_none_when_family_absent() { + let buf = b"OTHER_PAYLOAD"; + assert!(matches!( + strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF"), + Ok(None) + )); + } + + #[test] + fn strip_errors_on_truncated_version_byte() { + let buf = PY_SCALAR_UDF_FAMILY; + let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + assert!(format!("{err}").contains("missing wire-format version byte")); + } + + #[test] + fn strip_errors_on_too_new_version() { + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_CURRENT.saturating_add(1)); + buf.extend_from_slice(b"payload"); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("wire-format version v")); + assert!(msg.contains("supports")); + assert!(msg.contains("Align datafusion-python versions")); + } + + #[test] + fn strip_errors_on_too_old_version() { + if WIRE_VERSION_MIN_SUPPORTED == 0 { + return; + } + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_MIN_SUPPORTED - 1); + buf.extend_from_slice(b"payload"); + assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").is_err()); + } + + #[test] + fn write_then_strip_round_trips_payload() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(b"scalar-payload"); + + let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF") + .unwrap() + .unwrap(); + assert_eq!(payload, b"scalar-payload"); + } +} diff --git a/crates/core/src/udf.rs b/crates/core/src/udf.rs index c0a39cb47..72cdddba1 100644 --- a/crates/core/src/udf.rs +++ 
b/crates/core/src/udf.rs @@ -43,7 +43,7 @@ use crate::expr::PyExpr; /// This struct holds the Python written function that is a /// ScalarUDF. #[derive(Debug)] -struct PythonFunctionScalarUDF { +pub(crate) struct PythonFunctionScalarUDF { name: String, func: Py, signature: Signature, @@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF { return_field: Arc::new(return_field), } } + + /// Stored Python callable. Consumed by the codec to cloudpickle + /// the function body across process boundaries. + pub(crate) fn func(&self) -> &Py { + &self.func + } + + pub(crate) fn return_field(&self) -> &FieldRef { + &self.return_field + } + + /// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted + /// by the codec. Inputs collapse to `Vec` because + /// `Signature::exact` cannot carry per-input nullability or + /// metadata — the encoder is free to discard that side of the + /// schema. `return_field` is kept as a `Field` so the post-decode + /// nullability and metadata match the sender's instance. + pub(crate) fn from_parts( + name: String, + func: Py, + input_types: Vec, + return_field: Field, + volatility: Volatility, + ) -> Self { + Self { + name, + func, + signature: Signature::exact(input_types, volatility), + return_field: Arc::new(return_field), + } + } } impl Eq for PythonFunctionScalarUDF {} @@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF { self.name == other.name && self.signature == other.signature && self.return_field == other.return_field - && Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false)) + // Identical pointers ⇒ same Python object. Most equality + // checks compare `Arc`-shared clones of the same UDF + // (e.g. expression rewriting), so the pointer match short- + // circuits before touching the GIL. + && (self.func.as_ptr() == other.func.as_ptr() + || Python::attach(|py| { + // Rust's `PartialEq` cannot return `Result`, so we + // have to pick a side when Python `__eq__` raises. 
+ // `false` is the conservative choice — better to + // report two UDFs as distinct than to wrongly + // merge them — but the silent miss can still + // surface as expression-dedup or cache-lookup + // anomalies. Log at `debug` so the failure is + // observable without flooding production logs. + // FIXME: revisit if upstream `ScalarUDFImpl` + // exposes a fallible `PartialEq`. + self.func + .bind(py) + .eq(other.func.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udf", + "PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) } } impl Hash for PythonFunctionScalarUDF { fn hash(&self, state: &mut H) { + // Hash only the identifying header (name + signature + return + // field). Skipping `func` is intentional: the Rust `Hash` + // contract requires `a == b ⇒ hash(a) == hash(b)`, not the + // converse, so a coarser hash is sound — `PartialEq` still + // disambiguates two UDFs with the same header but distinct + // callables. Falling back to a sentinel on `py_hash` failure + // (as a prior revision did) silently mapped every unhashable + // closure to the same bucket; that is the worst case for a + // hashmap and is what this rewrite avoids. 
self.name.hash(state); self.signature.hash(state); self.return_field.hash(state); - - Python::attach(|py| { - let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects - - state.write_isize(py_hash); - }); } } @@ -220,4 +281,9 @@ impl PyScalarUDF { fn __repr__(&self) -> PyResult { Ok(format!("ScalarUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/pyproject.toml b/pyproject.toml index 951f7adc3..a02f4608a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,13 @@ classifiers = [ "Programming Language :: Rust", ] dependencies = [ + # cloudpickle is invoked by the Rust-side PythonLogicalCodec / + # PythonPhysicalCodec via pyo3 to serialize Python UDF callables — + # scalar, aggregate, and window — into the proto wire format. + # Lazy-imported on the encode / decode hot paths (and cached after + # the first import), so users who never serialize a plan or + # expression incur no runtime cost beyond the install footprint. + "cloudpickle>=2.0", "pyarrow>=16.0.0;python_version<'3.14'", "pyarrow>=22.0.0;python_version>='3.14'", "typing-extensions;python_version<'3.13'", diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index f08b464bb..dfdeef07e 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -65,7 +65,7 @@ import importlib_metadata # type: ignore[import] # Public submodules -from . import functions, object_store, substrait, unparser +from . import functions, ipc, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. 
from ._internal import Config @@ -142,6 +142,7 @@ "configure_formatter", "expr", "functions", + "ipc", "lit", "literal", "object_store", diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index e0135e3ed..10b011ffb 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -434,23 +434,59 @@ def variant_name(self) -> str: return self.expr.variant_name() def to_bytes(self, ctx: SessionContext | None = None) -> bytes: - """Serialize this expression to protobuf bytes. + """Serialize this expression to bytes for shipping to another process. - When ``ctx`` is supplied, encoding routes through the session's - installed :class:`LogicalExtensionCodec`. Without ``ctx`` a - default codec is used. + Use this — or :func:`pickle.dumps` — to send an expression to a + worker process for distributed evaluation. + + When ``ctx`` is supplied, encoding routes through that session's + installed :class:`LogicalExtensionCodec`. When ``ctx`` is + ``None``, the default codec is used. + + Built-in functions and Python scalar UDFs travel inside the + returned bytes; the worker does not need to pre-register them. + UDFs imported via the FFI capsule protocol travel by name only + and must be registered on the worker. """ ctx_arg = ctx.ctx if ctx is not None else None return self.expr.to_bytes(ctx_arg) - @staticmethod - def from_bytes(ctx: SessionContext, data: bytes) -> Expr: - """Decode an expression from serialized protobuf bytes. - - ``ctx`` provides the function registry for resolving UDF - references and the logical codec for in-band Python payloads. + @classmethod + def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: + """Reconstruct an expression from serialized bytes. + + Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`. + ``ctx`` is the :class:`SessionContext` used to resolve any + function references that travel by name (e.g. FFI UDFs). 
When
+        ``ctx`` is ``None`` the worker context installed via
+        :func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker
+        context is installed, the global :class:`SessionContext` is used
+        (sufficient for built-ins and Python scalar UDFs, plus any UDFs
+        registered on the global context).
+        """
+        from datafusion.ipc import _resolve_ctx
+
+        resolved = _resolve_ctx(ctx)
+        return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))
+
+    def __reduce__(self) -> tuple:
+        """Pickle protocol hook.
+
+        Lets expressions be shipped to worker processes via
+        :func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions
+        and Python scalar UDFs travel inside the pickle bytes; only
+        FFI-capsule UDFs require pre-registration on the worker. The
+        worker's :class:`SessionContext` for resolving those references
+        is the context installed via :func:`datafusion.ipc.set_worker_ctx`,
+        falling back to the global :class:`SessionContext` if none has been
+        installed on the worker.
         """
-        return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data))
+        return (Expr._reconstruct, (self.to_bytes(),))
+
+    @classmethod
+    def _reconstruct(cls, proto_bytes: bytes) -> Expr:
+        """Internal entry point used by :meth:`__reduce__` on unpickle."""
+        return cls.from_bytes(proto_bytes)
 
     def __richcmp__(self, other: Expr, op: int) -> Expr:
         """Comparison operator."""
diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py
new file mode 100644
index 000000000..d1867a917
--- /dev/null
+++ b/python/datafusion/ipc.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Worker-side setup for distributing DataFusion expressions.
+
+When an :class:`Expr` is shipped to a worker process (e.g. through
+:func:`multiprocessing.Pool` or a Ray actor), the worker reconstructs the
+expression against a :class:`SessionContext`. If the expression references
+UDFs imported via the FFI capsule protocol — or any UDF the worker would
+otherwise resolve from its registered functions rather than from inside
+the shipped expression — install a configured :class:`SessionContext`
+once per worker:
+
+.. code-block:: python
+
+    from datafusion import SessionContext
+    from datafusion.ipc import set_worker_ctx
+
+    def init_worker():
+        ctx = SessionContext()
+        ctx.register_udaf(my_ffi_aggregate)
+        set_worker_ctx(ctx)
+
+Built-in functions and Python scalar UDFs travel inside the shipped
+expression itself and do not need pre-registration on the worker.
+"""
+
+from __future__ import annotations
+
+import threading
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from datafusion.context import SessionContext
+
+
+__all__ = [
+    "clear_worker_ctx",
+    "get_worker_ctx",
+    "set_worker_ctx",
+]
+
+
+_local = threading.local()
+
+
+def set_worker_ctx(ctx: SessionContext) -> None:
+    """Install this worker's :class:`SessionContext` for shipped expressions.
+
+    Call once per worker — typically from a ``multiprocessing.Pool``
+    initializer or a Ray actor ``__init__``. Idempotent: overwrites any
+    previous value. Stored in a thread-local slot, so each thread within a
+    worker may install its own context independently.
+ """ + _local.ctx = ctx + + +def clear_worker_ctx() -> None: + """Remove this worker's installed :class:`SessionContext`. + + After clearing, expressions reconstructed in this worker fall back to + the global :class:`SessionContext` — adequate for built-ins and Python + scalar UDFs, but anything imported via the FFI capsule protocol must + be registered on the global context to resolve. + """ + if hasattr(_local, "ctx"): + del _local.ctx + + +def get_worker_ctx() -> SessionContext | None: + """Return this worker's installed :class:`SessionContext`, or ``None``.""" + return getattr(_local, "ctx", None) + + +def _resolve_ctx( + explicit_ctx: SessionContext | None = None, +) -> SessionContext: + """Resolve a context for Expr reconstruction. + + Priority: explicit argument > worker context > global context. + Falling back to the global :class:`SessionContext` (instead of a + freshly constructed one) preserves any registrations the user has + installed on it. + """ + if explicit_ctx is not None: + return explicit_ctx + worker = get_worker_ctx() + if worker is not None: + return worker + # Lazy import: `datafusion/__init__.py` imports `datafusion.ipc` + # before `datafusion.context`, so a module-top import would force + # `datafusion.context` to load mid-init of `datafusion.ipc`. The + # cycle is benign today (context.py only pulls expr.py at module + # scope, neither pulls ipc.py back), but a single new import in + # context.py's transitive deps could turn it into a real cycle. + # Deferring keeps `datafusion.ipc` import-order-independent. 
+ from datafusion.context import SessionContext # noqa: PLC0415 + + return SessionContext.global_ctx() diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 848ab4cee..f80b613a2 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -141,6 +141,16 @@ def __init__( name, func, input_fields, return_field, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDF. + + For UDFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udf.name + def __repr__(self) -> str: """Print a string representation of the Scalar UDF.""" return self._udf.__repr__() diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 6a466f6f2..e1fdeab44 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -1186,7 +1186,7 @@ def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None: original = col("a") + lit(1) blob = original.to_bytes(ctx) - restored = Expr.from_bytes(ctx, blob) + restored = Expr.from_bytes(blob, ctx=ctx) # Canonical name preserves the structure of the expression even # though the underlying PyExpr instances are different. @@ -1201,6 +1201,6 @@ def test_expr_to_bytes_no_ctx_default_codec() -> None: fresh = SessionContext() original = col("a") * lit(2) blob = original.to_bytes() # encode side: default codec - restored = Expr.from_bytes(fresh, blob) + restored = Expr.from_bytes(blob, ctx=fresh) assert restored.canonical_name() == original.canonical_name() diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py new file mode 100644 index 000000000..c0d749271 --- /dev/null +++ b/python/tests/test_pickle_expr.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""In-process pickle round-trip tests for :class:`Expr`. + +Built-in functions and Python scalar UDFs travel with the pickled +expression and do not need worker-side pre-registration. The worker +context (:mod:`datafusion.ipc`) is only consulted for UDFs imported +via the FFI capsule protocol. +""" + +from __future__ import annotations + +import pickle + +import pyarrow as pa +import pytest +from datafusion import Expr, SessionContext, col, lit, udf +from datafusion.ipc import ( + clear_worker_ctx, + set_worker_ctx, +) + + +@pytest.fixture(autouse=True) +def _reset_worker_ctx(): + """Ensure every test starts with no worker context installed.""" + clear_worker_ctx() + yield + clear_worker_ctx() + + +def _double_udf(): + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + +class TestProtoRoundTrip: + def test_builtin_round_trip(self): + e = col("a") + lit(1) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + + def test_to_bytes_from_bytes(self): + e = col("x") * lit(7) + blob = e.to_bytes() + assert isinstance(blob, bytes) + decoded = Expr.from_bytes(blob) + assert decoded.canonical_name() == e.canonical_name() + 
+ def test_explicit_ctx_used(self, ctx): + e = col("a") + lit(1) + decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx) + assert decoded.canonical_name() == e.canonical_name() + + +class TestUDFCodec: + """Python scalar UDFs ride inside the proto blob via the Rust codec. + + No worker context needed on the receiver — the cloudpickled callable is + embedded in ``fun_definition`` and reconstructed automatically. + """ + + def test_udf_self_contained_blob(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + # The codec inlines the callable, so the blob is much bigger than a + # pure built-in blob but doesn't depend on receiver-side registration. + assert len(blob) > 200 + + def test_udf_decodes_into_fresh_ctx(self): + e = _double_udf()(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_no_worker_ctx(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_worker_ctx(self): + set_worker_ctx(SessionContext()) + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_closure_capturing_udf_names_match(self): + captured_multiplier = 7 + + def fn(arr): + return pa.array([(v.as_py() or 0) * captured_multiplier for v in arr]) + + u = udf( + fn, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="times_seven", + ) + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() diff --git a/uv.lock b/uv.lock index 3b7135e32..3fd3eec4b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ 
-257,6 +257,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767, upload-time = "2024-12-24T18:12:32.852Z" }, ] +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] + [[package]] name = "codespell" version = "2.4.1" @@ -316,6 +325,7 @@ wheels = [ name = "datafusion" source = { editable = "." 
} dependencies = [ + { name = "cloudpickle" }, { name = "pyarrow" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] @@ -351,6 +361,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", specifier = ">=2.0" }, { name = "pyarrow", marker = "python_full_version < '3.14'", specifier = ">=16.0.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14'", specifier = ">=22.0.0" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, From d0baeb6c04d8e751ca1b3b8ea6b3ab9d2262c101 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 May 2026 14:23:53 -0400 Subject: [PATCH 2/2] feat: inline encoding for Python aggregate and window UDFs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the PythonLogicalCodec / PythonPhysicalCodec inline encoding introduced for scalar UDFs to also cover Python-defined aggregate and window UDFs. The cloudpickle tuple shape per family is: DFPYUDA (agg) (name, accumulator_factory, input_schema_bytes, return_schema_bytes, state_schema_bytes, volatility_str) DFPYUDW (window) (name, evaluator_factory, input_schema_bytes, return_schema_bytes, volatility_str) Same wire-framing as scalar (family magic + version byte + cloudpickle blob), same schema serde (arrow-rs native IPC), same cached cloudpickle handle. The agg state schema is encoded as a full IPC schema so the post-decode UDF reports the same names + nullability + metadata as the sender — relevant for accumulators whose StateFieldsArgs consumers key off names rather than positional DataType. Required restructuring two existing UDF impls so the codec can grab the Python callable directly: * udaf.rs: replaces create_udaf + AccumulatorFactoryFunction closure with a named PythonFunctionAggregateUDF that stores the Py accumulator factory. 
Synthesizes state_{i} field names when the Python constructor passes only Vec; from_parts preserves the full state schema on the decode side. * udwf.rs: renames MultiColumnWindowUDF -> PythonFunctionWindowUDF, drops the PartitionEvaluatorFactory PtrEq wrapper, stores the Py evaluator directly. PartialEq and Hash get the same pointer-identity fast path + debug-log exception handling already on PythonFunctionScalarUDF. User-facing surface: * AggregateUDF.name and WindowUDF.name properties (parallel to the ScalarUDF.name shipped in PR1). * Existing UDAF/UDWF construction paths are unchanged. The per-session with_python_udf_inlining toggle, sender-side context, strict refusal, and user-guide docs land in PRs 3-4 of this series. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/codec.rs | 264 +++++++++++++++++++++++++++++- crates/core/src/udaf.rs | 186 +++++++++++++++++++-- crates/core/src/udwf.rs | 113 +++++++++---- python/datafusion/expr.py | 23 +-- python/datafusion/ipc.py | 9 +- python/datafusion/user_defined.py | 20 +++ python/tests/test_pickle_expr.py | 116 ++++++++++++- 7 files changed, 653 insertions(+), 78 deletions(-) diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index c95d8cb19..272a1f9b9 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -62,12 +62,14 @@ //! | Layer + kind | Family prefix | //! | ----------------------------- | ------------- | //! | `PythonLogicalCodec` scalar | `DFPYUDF` | +//! | `PythonLogicalCodec` agg | `DFPYUDA` | +//! | `PythonLogicalCodec` window | `DFPYUDW` | //! | `PythonPhysicalCodec` scalar | `DFPYUDF` | +//! | `PythonPhysicalCodec` agg | `DFPYUDA` | +//! | `PythonPhysicalCodec` window | `DFPYUDW` | //! | User FFI extension codec | user-chosen | //! | Default codec | (none) | //! -//! Aggregate and window UDF families are reserved for follow-on work. -//! //! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported //! 
receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. //! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape @@ -90,8 +92,8 @@ use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::FileFormatFactory; use datafusion::execution::TaskContext; use datafusion::logical_expr::{ - AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, - Volatility, WindowUDF, + AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; @@ -101,7 +103,9 @@ use pyo3::prelude::*; use pyo3::sync::PyOnceLock; use pyo3::types::{PyBytes, PyTuple}; +use crate::udaf::PythonFunctionAggregateUDF; use crate::udf::PythonFunctionScalarUDF; +use crate::udwf::PythonFunctionWindowUDF; // Wire-format framing for inlined Python UDF payloads. // @@ -118,6 +122,16 @@ use crate::udf::PythonFunctionScalarUDF; /// volatility). pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF"; +/// Family prefix for an inlined Python aggregate UDF +/// (cloudpickled tuple of name, accumulator factory, input schema, +/// return type, state types schema, volatility). +pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA"; + +/// Family prefix for an inlined Python window UDF +/// (cloudpickled tuple of name, evaluator factory, input schema, +/// return type, volatility). +pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW"; + /// Wire-format version this build emits. pub(crate) const WIRE_VERSION_CURRENT: u8 = 1; @@ -260,18 +274,30 @@ impl LogicalExtensionCodec for PythonLogicalCodec { } fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_agg_udf(node, buf)? 
{ + return Ok(()); + } self.inner.try_encode_udaf(node, buf) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udaf) = try_decode_python_agg_udf(buf)? { + return Ok(udaf); + } self.inner.try_decode_udaf(name, buf) } fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_window_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udwf(node, buf) } fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udwf) = try_decode_python_window_udf(buf)? { + return Ok(udwf); + } self.inner.try_decode_udwf(name, buf) } } @@ -350,18 +376,30 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { } fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_agg_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udaf(node, buf) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udaf) = try_decode_python_agg_udf(buf)? { + return Ok(udaf); + } self.inner.try_decode_udaf(name, buf) } fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_window_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udwf(node, buf) } fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udwf) = try_decode_python_window_udf(buf)? { + return Ok(udwf); + } self.inner.try_decode_udwf(name, buf) } } @@ -525,6 +563,11 @@ fn build_single_field_schema_bytes(field: &Field) -> PyResult> { schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err) } +/// Emit a multi-field IPC schema blob. +fn build_schema_bytes(fields: Vec) -> PyResult> { + schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) +} + /// Decode the per-arg `DataType`s the encoder wrote via /// [`build_input_schema_bytes`]. 
fn read_input_dtypes(bytes: &[u8]) -> PyResult> { @@ -589,6 +632,200 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { .map(|cached| cached.bind(py).clone()) } +// ============================================================================= +// Shared Python window UDF encode / decode helpers +// +// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes, +// return_schema_bytes, volatility_str)`. The evaluator factory is the +// Python callable that produces a new evaluator instance per partition. +// ============================================================================= + +pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_window_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_WINDOW_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_WINDOW_UDF_FAMILY, "window UDF")? 
else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_window_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(WindowUDF::new_from_impl(udf)))) + }) +} + +fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult> { + let signature = WindowUDFImpl::signature(udf); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_field = Field::new("result", udf.return_type().clone(), true); + let return_schema_bytes = build_single_field_schema_bytes(&return_field)?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(), + udf.evaluator().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let evaluator: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let volatility_str: String = tuple.get_item(4)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionWindowUDF")? 
+ .data_type() + .clone(); + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionWindowUDF::new( + name, + evaluator, + input_types, + return_type, + volatility, + )) +} + +// ============================================================================= +// Shared Python aggregate UDF encode / decode helpers +// +// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes, +// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator +// factory is the Python callable that produces a new accumulator instance +// per partition. +// ============================================================================= + +pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_agg_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_AGG_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_AGG_UDF_FAMILY, "aggregate UDF")? 
else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_agg_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf)))) + }) +} + +fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult> { + let signature = AggregateUDFImpl::signature(udf); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionAggregateUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_field = Field::new("result", udf.return_type().clone(), true); + let return_schema_bytes = build_single_field_schema_bytes(&return_field)?; + let state_fields: Vec = udf + .state_fields_ref() + .iter() + .map(|f| f.as_ref().clone()) + .collect(); + let state_schema_bytes = build_schema_bytes(state_fields)?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + AggregateUDFImpl::name(udf).into_pyobject(py)?.into_any(), + udf.accumulator().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + PyBytes::new(py, &state_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? 
+ .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let accumulator: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let state_schema_bytes: Vec = tuple.get_item(4)?.extract()?; + let volatility_str: String = tuple.get_item(5)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionAggregateUDF")? + .data_type() + .clone(); + // Preserve the encoded state field metadata (names, nullability, + // arbitrary key/value attributes) so the post-decode UDF reports + // the same state schema as the sender's instance — important for + // accumulators whose `StateFieldsArgs` consumers key off names or + // nullability rather than positional `DataType`. + let state_schema = schema_from_ipc_bytes(&state_schema_bytes).map_err(arrow_to_py_err)?; + let state_fields: Vec = + state_schema.fields().iter().cloned().collect(); + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionAggregateUDF::from_parts( + name, + accumulator, + input_types, + return_type, + state_fields, + volatility, + )) +} + #[cfg(test)] mod wire_header_tests { use super::*; @@ -635,12 +872,23 @@ mod wire_header_tests { #[test] fn write_then_strip_round_trips_payload() { let mut buf = Vec::new(); - write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY); - buf.extend_from_slice(b"scalar-payload"); + write_wire_header(&mut buf, PY_AGG_UDF_FAMILY); + buf.extend_from_slice(b"agg-payload"); - let payload = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF") + let payload = strip_wire_header(&buf, PY_AGG_UDF_FAMILY, "aggregate UDF") .unwrap() .unwrap(); - assert_eq!(payload, b"scalar-payload"); + assert_eq!(payload, b"agg-payload"); + } + + #[test] + fn strip_does_not_match_a_different_family() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, 
PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(b"payload"); + assert!(matches!( + strip_wire_header(&buf, PY_WINDOW_UDF_FAMILY, "window UDF"), + Ok(None) + )); } } diff --git a/crates/core/src/udaf.rs b/crates/core/src/udaf.rs index 80ef51716..cb84fa375 100644 --- a/crates/core/src/udaf.rs +++ b/crates/core/src/udaf.rs @@ -15,16 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::ptr::NonNull; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::common::ScalarValue; use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf, + Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility, }; use datafusion_ffi::udaf::FFI_AggregateUDF; use datafusion_python_util::parse_volatility; @@ -144,15 +146,161 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { - Arc::new(move |_args| -> Result> { - let accum = Python::attach(|py| { - accum - .call0(py) - .map_err(|e| DataFusionError::Execution(format!("{e}"))) - })?; - Ok(Box::new(RustAccumulator::new(accum))) - }) +fn instantiate_accumulator(accum: &Py) -> Result> { + let instance = Python::attach(|py| { + accum + .call0(py) + .map_err(|e| DataFusionError::Execution(format!("{e}"))) + })?; + Ok(Box::new(RustAccumulator::new(instance))) +} + +/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs. +/// Holds the Python accumulator factory directly so the codec can +/// downcast and cloudpickle it across process boundaries. 
+#[derive(Debug)] +pub(crate) struct PythonFunctionAggregateUDF { + name: String, + accumulator: Py, + signature: Signature, + return_type: DataType, + state_fields: Vec, +} + +impl PythonFunctionAggregateUDF { + fn new( + name: String, + accumulator: Py, + input_types: Vec, + return_type: DataType, + state_types: Vec, + volatility: Volatility, + ) -> Self { + let signature = Signature::exact(input_types, volatility); + let state_fields = state_types + .into_iter() + .enumerate() + .map(|(i, t)| Arc::new(Field::new(format!("state_{i}"), t, true))) + .collect(); + Self { + name, + accumulator, + signature, + return_type, + state_fields, + } + } + + /// Stored Python callable that returns a fresh accumulator instance + /// per partition. Consumed by the codec to cloudpickle the factory + /// across process boundaries. + pub(crate) fn accumulator(&self) -> &Py { + &self.accumulator + } + + pub(crate) fn return_type(&self) -> &DataType { + &self.return_type + } + + pub(crate) fn state_fields_ref(&self) -> &[FieldRef] { + &self.state_fields + } + + /// Reconstruct a `PythonFunctionAggregateUDF` from the parts emitted + /// by the codec. `state_fields` carries the full state schema + /// (names, data types, nullability, metadata) — the codec extracts + /// it from the IPC payload, so the post-decode state schema is + /// identical to the pre-encode one. Use [`Self::new`] when only + /// `Vec` is available (e.g. the Python constructor path, + /// where field names are synthesized). 
+ pub(crate) fn from_parts( + name: String, + accumulator: Py, + input_types: Vec, + return_type: DataType, + state_fields: Vec, + volatility: Volatility, + ) -> Self { + Self { + name, + accumulator, + signature: Signature::exact(input_types, volatility), + return_type, + state_fields, + } + } +} + +impl Eq for PythonFunctionAggregateUDF {} +impl PartialEq for PythonFunctionAggregateUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + && self.state_fields == other.state_fields + // Pointer-identity fast path: `Arc`-shared clones of the + // same UDF skip the GIL roundtrip. Falls through to Python + // `__eq__` only for two distinct callables. + && (self.accumulator.as_ptr() == other.accumulator.as_ptr() + || Python::attach(|py| { + // See `PythonFunctionScalarUDF::eq` for the + // rationale on swallowing the exception as `false` + // and logging at `debug`. FIXME: revisit if + // upstream `AggregateUDFImpl` exposes a fallible + // `PartialEq`. + self.accumulator + .bind(py) + .eq(other.accumulator.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udaf", + "PythonFunctionAggregateUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) + } +} + +impl std::hash::Hash for PythonFunctionAggregateUDF { + fn hash(&self, state: &mut H) { + // See `PythonFunctionScalarUDF`'s `Hash` impl for the + // rationale: hash the identifying header only and let + // `PartialEq` disambiguate callables. 
+ self.name.hash(state); + self.signature.hash(state); + self.return_type.hash(state); + for f in &self.state_fields { + f.hash(state); + } + } +} + +impl AggregateUDFImpl for PythonFunctionAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + instantiate_accumulator(&self.accumulator) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(self.state_fields.clone()) + } } fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { @@ -190,14 +338,15 @@ impl PyAggregateUDF { state_type: PyArrowType>, volatility: &str, ) -> PyResult { - let function = create_udaf( - name, + let py_udf = PythonFunctionAggregateUDF::new( + name.to_string(), + accumulator, input_type.0, - Arc::new(return_type.0), + return_type.0, + state_type.0, parse_volatility(volatility)?, - to_rust_accumulator(accumulator), - Arc::new(state_type.0), ); + let function = AggregateUDF::new_from_impl(py_udf); Ok(Self { function }) } @@ -231,4 +380,9 @@ impl PyAggregateUDF { fn __repr__(&self) -> PyResult { Ok(format!("AggregateUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/crates/core/src/udwf.rs b/crates/core/src/udwf.rs index 1d3608ada..5ce09e6d2 100644 --- a/crates/core/src/udwf.rs +++ b/crates/core/src/udwf.rs @@ -25,10 +25,9 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; -use datafusion::logical_expr::ptr_eq::PtrEq; use datafusion::logical_expr::window_state::WindowAggState; use datafusion::logical_expr::{ - 
PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion::scalar::ScalarValue; use datafusion_ffi::udwf::FFI_WindowUDF; @@ -198,15 +197,13 @@ impl PartitionEvaluator for RustPartitionEvaluator { } } -pub fn to_rust_partition_evaluator(evaluator: Py) -> PartitionEvaluatorFactory { - Arc::new(move || -> Result> { - let evaluator = Python::attach(|py| { - evaluator - .call0(py) - .map_err(|e| DataFusionError::Execution(e.to_string())) - })?; - Ok(Box::new(RustPartitionEvaluator::new(evaluator))) - }) +fn instantiate_partition_evaluator(evaluator: &Py) -> Result> { + let instance = Python::attach(|py| { + evaluator + .call0(py) + .map_err(|e| DataFusionError::Execution(e.to_string())) + })?; + Ok(Box::new(RustPartitionEvaluator::new(instance))) } /// Represents an WindowUDF @@ -234,14 +231,14 @@ impl PyWindowUDF { volatility: &str, ) -> PyResult { let return_type = return_type.0; - let input_types = input_types.into_iter().map(|t| t.0).collect(); + let input_types: Vec = input_types.into_iter().map(|t| t.0).collect(); - let function = WindowUDF::from(MultiColumnWindowUDF::new( + let function = WindowUDF::from(PythonFunctionWindowUDF::new( name, + evaluator, input_types, return_type, parse_volatility(volatility)?, - to_rust_partition_evaluator(evaluator), )); Ok(Self { function }) } @@ -276,47 +273,94 @@ impl PyWindowUDF { fn __repr__(&self) -> PyResult { Ok(format!("WindowUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } -#[derive(Hash, Eq, PartialEq)] -pub struct MultiColumnWindowUDF { +#[derive(Debug)] +pub(crate) struct PythonFunctionWindowUDF { name: String, + evaluator: Py, signature: Signature, return_type: DataType, - partition_evaluator_factory: PtrEq, } -impl std::fmt::Debug for MultiColumnWindowUDF { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - 
f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"") - .field("partition_evaluator_factory", &"") - .finish() - } -} - -impl MultiColumnWindowUDF { - pub fn new( +impl PythonFunctionWindowUDF { + pub(crate) fn new( name: impl Into, + evaluator: Py, input_types: Vec, return_type: DataType, volatility: Volatility, - partition_evaluator_factory: PartitionEvaluatorFactory, ) -> Self { let name = name.into(); let signature = Signature::exact(input_types, volatility); Self { name, + evaluator, signature, return_type, - partition_evaluator_factory: partition_evaluator_factory.into(), } } + + /// Stored Python callable that produces a fresh partition + /// evaluator instance per partition. Consumed by the codec to + /// cloudpickle the evaluator factory across process boundaries. + pub(crate) fn evaluator(&self) -> &Py { + &self.evaluator + } + + pub(crate) fn return_type(&self) -> &DataType { + &self.return_type + } +} + +impl Eq for PythonFunctionWindowUDF {} +impl PartialEq for PythonFunctionWindowUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + // Pointer-identity fast path: `Arc`-shared clones of the + // same UDF skip the GIL roundtrip. Falls through to Python + // `__eq__` only for two distinct callables. + && (self.evaluator.as_ptr() == other.evaluator.as_ptr() + || Python::attach(|py| { + // See `PythonFunctionScalarUDF::eq` for the + // rationale on swallowing the exception as `false` + // and logging at `debug`. FIXME: revisit if + // upstream `WindowUDFImpl` exposes a fallible + // `PartialEq`. 
+ self.evaluator + .bind(py) + .eq(other.evaluator.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udwf", + "PythonFunctionWindowUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) + } +} + +impl std::hash::Hash for PythonFunctionWindowUDF { + fn hash(&self, state: &mut H) { + // See `PythonFunctionScalarUDF`'s `Hash` impl for the + // rationale: hash the identifying header only and let + // `PartialEq` disambiguate evaluators. + self.name.hash(state); + self.signature.hash(state); + self.return_type.hash(state); + } } -impl WindowUDFImpl for MultiColumnWindowUDF { +impl WindowUDFImpl for PythonFunctionWindowUDF { fn as_any(&self) -> &dyn Any { self } @@ -339,7 +383,6 @@ impl WindowUDFImpl for MultiColumnWindowUDF { &self, _partition_evaluator_args: PartitionEvaluatorArgs, ) -> Result> { - let _ = _partition_evaluator_args; - (self.partition_evaluator_factory)() + instantiate_partition_evaluator(&self.evaluator) } } diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 10b011ffb..b120d7f7b 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -443,10 +443,10 @@ def to_bytes(self, ctx: SessionContext | None = None) -> bytes: installed :class:`LogicalExtensionCodec`. When ``ctx`` is ``None``, the default codec is used. - Built-in functions and Python scalar UDFs travel inside the - returned bytes; the worker does not need to pre-register them. - UDFs imported via the FFI capsule protocol travel by name only - and must be registered on the worker. + Built-in functions and Python UDFs (scalar, aggregate, window) + travel inside the returned bytes; the worker does not need to + pre-register them. UDFs imported via the FFI capsule protocol + travel by name only and must be registered on the worker. 
""" ctx_arg = ctx.ctx if ctx is not None else None return self.expr.to_bytes(ctx_arg) @@ -461,7 +461,7 @@ def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: ``ctx`` is ``None`` the worker context installed via :func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker context is installed, the global :class:`SessionContext` is used - (sufficient for built-ins and Python scalar UDFs, plus any UDFs + (sufficient for built-ins and Python UDFs, plus any UDFs registered on the global context). """ from datafusion.ipc import _resolve_ctx @@ -474,12 +474,13 @@ def __reduce__(self) -> tuple: Lets expressions be shipped to worker processes via :func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions - and Python scalar UDFs travel inside the pickle bytes; only - FFI-capsule UDFs require pre-registration on the worker. The - worker's :class:`SessionContext` for resolving those references - is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling - back to the global :class:`SessionContext` if none has been - installed on the worker. + and Python UDFs (scalar, aggregate, window) travel inside the + pickle bytes; only FFI-capsule UDFs require pre-registration on + the worker. The worker's :class:`SessionContext` for resolving + those references is looked up via + :func:`datafusion.ipc.set_worker_ctx`, falling back to the + global :class:`SessionContext` if none has been installed on + the worker. """ return (Expr._reconstruct, (self.to_bytes(),)) diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py index d1867a917..1076fa4af 100644 --- a/python/datafusion/ipc.py +++ b/python/datafusion/ipc.py @@ -35,8 +35,9 @@ def init_worker(): ctx.register_udaf(my_ffi_aggregate) set_worker_ctx(ctx) -Built-in functions and Python scalar UDFs travel inside the shipped -expression itself and do not need pre-registration on the worker. 
+Built-in functions and Python UDFs (scalar, aggregate, window) travel +inside the shipped expression itself and do not need pre-registration +on the worker. """ from __future__ import annotations @@ -74,8 +75,8 @@ def clear_worker_ctx() -> None: After clearing, expressions reconstructed in this worker fall back to the global :class:`SessionContext` — adequate for built-ins and Python - scalar UDFs, but anything imported via the FFI capsule protocol must - be registered on the global context to resolve. + UDFs (scalar, aggregate, window), but anything imported via the FFI + capsule protocol must be registered on the global context to resolve. """ if hasattr(_local, "ctx"): del _local.ctx diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index f80b613a2..da756473a 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -428,6 +428,16 @@ def __init__( str(volatility), ) + @property + def name(self) -> str: + """Return the registered name of this UDAF. + + For UDAFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udaf.name + def __repr__(self) -> str: """Print a string representation of the Aggregate UDF.""" return self._udaf.__repr__() @@ -838,6 +848,16 @@ def __init__( name, func, input_types, return_type, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDWF. + + For UDWFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). 
+ """ + return self._udwf.name + def __repr__(self) -> str: """Print a string representation of the Window UDF.""" return self._udwf.__repr__() diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py index c0d749271..cb0a5b065 100644 --- a/python/tests/test_pickle_expr.py +++ b/python/tests/test_pickle_expr.py @@ -17,10 +17,10 @@ """In-process pickle round-trip tests for :class:`Expr`. -Built-in functions and Python scalar UDFs travel with the pickled -expression and do not need worker-side pre-registration. The worker -context (:mod:`datafusion.ipc`) is only consulted for UDFs imported -via the FFI capsule protocol. +Built-in functions and Python UDFs (scalar, aggregate, window) travel +with the pickled expression and do not need worker-side pre-registration. +The worker context (:mod:`datafusion.ipc`) is only consulted for UDFs +imported via the FFI capsule protocol. """ from __future__ import annotations @@ -125,3 +125,111 @@ def fn(arr): blob = pickle.dumps(e) decoded = pickle.loads(blob) # noqa: S301 assert decoded.canonical_name() == e.canonical_name() + + +class TestAggregateUDFCodec: + """Python aggregate UDFs travel inline like scalar UDFs.""" + + def _build_aggregate_udf(self): + from datafusion import udaf + from datafusion.user_defined import Accumulator + + class CountAcc(Accumulator): + def __init__(self): + self._count = 0 + + def state(self): + return [pa.scalar(self._count, type=pa.int64())] + + def update(self, values): + self._count += len(values) + + def merge(self, states): + for s in states: + self._count += s[0].as_py() + + def evaluate(self): + return pa.scalar(self._count, type=pa.int64()) + + return udaf( + CountAcc, + [pa.int64()], + pa.int64(), + [pa.int64()], + "immutable", + name="count_all", + ) + + def test_agg_udf_self_contained_blob(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = pickle.dumps(e) + assert len(blob) > 200 + + def test_agg_udf_decodes_into_fresh_ctx(self): + u = 
self._build_aggregate_udf() + e = u(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "count_all" in decoded.canonical_name() + + def test_agg_udf_decodes_via_pickle_with_no_worker_ctx(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "count_all" in decoded.canonical_name() + + def test_agg_udf_evaluates_after_roundtrip(self): + """End-to-end: the decoded aggregate UDF runs and merges across + partitions, exercising the round-tripped state-field schema.""" + u = self._build_aggregate_udf() + e = u(col("a")) + decoded = pickle.loads(pickle.dumps(e)) # noqa: S301 + + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]}) + out = df.aggregate([], [decoded.alias("n")]).to_pydict() + assert out["n"] == [5] + + +class TestWindowUDFCodec: + """Python window UDFs travel inline like scalar UDFs.""" + + def _build_window_udf(self): + from datafusion import udwf + from datafusion.user_defined import WindowEvaluator + + class CountUpEvaluator(WindowEvaluator): + def evaluate_all(self, values, num_rows): + return pa.array(list(range(num_rows))) + + return udwf( + CountUpEvaluator, + [pa.int64()], + pa.int64(), + "immutable", + name="count_up", + ) + + def test_window_udf_self_contained_blob(self): + u = self._build_window_udf() + e = u(col("a")) + blob = pickle.dumps(e) + assert len(blob) > 200 + + def test_window_udf_decodes_into_fresh_ctx(self): + u = self._build_window_udf() + e = u(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "count_up" in decoded.canonical_name() + + def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self): + u = self._build_window_udf() + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "count_up" in decoded.canonical_name()