Skip to content

Commit 2617080

Browse files
authored
Improve substrait support to include more types and expressions (#6484)
## Does this PR closes an open issue or discussion? <!-- This helps us keep track of fixed issues and changes. --> - Closes #. ## What changes are included in this PR? 1. Support for more types - including new time related ones and decimal. 2. Support for basic arithmetic operations (+,-,/,*). 3. Support for the `is_null` and `is_not_null` substrate expression, which also allows for querying vortex-backed datasets **with filters** through duckdb! 4. Move for `expr.evalute(..)` API to `array.apply(...)`, to align the python API better with the Rust one. ## What is the rationale for this change? Both the arrow dataset API and substrait open up more ways to use Vortex through existing tools without maintaining dedicated integrations. ## How is this change tested? Adds a few additional tests covering new code paths, and it also unlocks an existing test that was expected to fail. ## Are there any user-facing changes? Extends the public API surface in Python, doesn't break any existing code. --------- Signed-off-by: Adam Gutglick <adam@spiraldb.com>
1 parent 83acff1 commit 2617080

12 files changed

Lines changed: 210 additions & 152 deletions

File tree

docs/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ inherits some of its doc strings from Rust docstrings:
99
cd ../vortex-python && uv run maturin develop
1010
```
1111

12+
The docs also require the [`doxygen`](https://www.doxygen.nl/) tool, which can be installed with:
13+
14+
```
15+
brew install doxygen
16+
```
17+
1218
Build the Vortex docs:
1319

1420
```

docs/user-guide/vortex-python.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ Available types: {func}`~vortex.null`, {func}`~vortex.bool_`,
119119

120120
The `vortex.expr` module provides expressions for filtering and projecting. These
121121
are primarily used with {meth}`.VortexFile.scan` and {meth}`.VortexFile.to_arrow` but can also be
122-
evaluated directly:
122+
applied directly:
123123

124124
```{doctest} pycon
125125
>>> import vortex.expr as ve
@@ -129,7 +129,7 @@ evaluated directly:
129129
... {'name': 'Carol', 'age': 35},
130130
... ])
131131
>>> expr = ve.column('age') > 28
132-
>>> expr.evaluate(arr).to_arrow_array().to_pylist()
132+
>>> arr.apply(expr).to_arrow_array().to_pylist()
133133
[True, False, True]
134134
```
135135

vortex-python/python/vortex/_lib/arrays.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import polars as pl
99
import pyarrow as pa
1010

1111
from .dtype import DType, PType
12+
from .expr import Expr
1213
from .scalar import Scalar, ScalarPyType
1314
from .serde import ArrayContext
1415

@@ -43,6 +44,7 @@ class Array:
4344
def to_polars_series(self) -> pl.Series: ...
4445
def to_pylist(self) -> list[ScalarPyType]: ...
4546
def serialize(self, ctx: ArrayContext) -> bytes: ...
47+
def apply(self, expr: Expr) -> Array: ...
4648

4749
class NativeArray(Array): ...
4850

@@ -60,7 +62,12 @@ class PrimitiveArray(Array):
6062
@property
6163
def ptype(self) -> PType: ...
6264

63-
# TODO(connor): Is this missing a `DecimalArray`?
65+
@final
66+
class DecimalArray(Array):
67+
@property
68+
def precision(self) -> int: ...
69+
@property
70+
def scale(self) -> int: ...
6471

6572
@final
6673
class VarBinArray(Array): ...

vortex-python/python/vortex/_lib/expr.pyi

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ from typing import TypeAlias, final
66

77
from typing_extensions import override
88

9-
from vortex.type_aliases import IntoArray
10-
11-
from .arrays import Array
129
from .dtype import DType
1310
from .scalar import ScalarPyType
1411

@@ -26,11 +23,15 @@ class Expr:
2623
def __ge__(self, other: IntoExpr) -> Expr: ...
2724
def __and__(self, other: IntoExpr) -> Expr: ...
2825
def __or__(self, other: IntoExpr) -> Expr: ...
29-
def evaluate(self, array: IntoArray) -> Array: ...
26+
def __add__(self, other: IntoExpr) -> Expr: ...
27+
def __sub__(self, other: IntoExpr) -> Expr: ...
28+
def __mul__(self, other: IntoExpr) -> Expr: ...
29+
def __truediv__(self, other: IntoExpr) -> Expr: ...
3030

3131
def column(name: str) -> Expr: ...
3232
def root() -> Expr: ...
3333
def literal(dtype: DType, value: ScalarPyType) -> Expr: ...
3434
def not_(child: Expr) -> Expr: ...
3535
def and_(left: Expr, right: Expr) -> Expr: ...
3636
def cast(child: Expr, dtype: DType) -> Expr: ...
37+
def is_null(child: Expr) -> Expr: ...

vortex-python/python/vortex/substrait.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import operator
55
from collections.abc import Callable
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Literal
77

88
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
99

@@ -50,16 +50,34 @@ def literal(substrait_object: Expression.Literal) -> _expr.Expr:
5050
return _expr.literal(_dtype.float_(32, nullable=False), substrait_object.fp32)
5151
case "fp64":
5252
return _expr.literal(_dtype.float_(64, nullable=False), substrait_object.fp64)
53+
case "decimal":
54+
substrait_decimal = substrait_object.decimal
55+
return _expr.literal(
56+
_dtype.decimal(precision=substrait_decimal.precision, scale=substrait_decimal.scale, nullable=False),
57+
int.from_bytes(substrait_decimal.value, byteorder="little", signed=True),
58+
)
5359
case "string":
5460
return _expr.literal(_dtype.utf8(nullable=False), substrait_object.string)
5561
case "binary":
5662
return _expr.literal(_dtype.binary(nullable=False), substrait_object.binary)
5763
case "timestamp":
58-
raise NotImplementedError
64+
# The unit here is from the substrait definition
65+
return _expr.literal(_dtype.timestamp(unit="us", nullable=False), substrait_object.timestamp)
66+
case "precision_timestamp":
67+
unit = _precision_to_unit("precision_timestamp", substrait_object.precision_timestamp.precision)
68+
69+
return _expr.literal(
70+
_dtype.timestamp(unit=unit, nullable=False), substrait_object.precision_timestamp.value
71+
)
5972
case "date":
60-
raise NotImplementedError
73+
# The unit here is from the substrait definition
74+
return _expr.literal(_dtype.date(unit="days", nullable=False), substrait_object.date)
6175
case "time":
62-
raise NotImplementedError
76+
# The unit here is from the substrait definition
77+
return _expr.literal(_dtype.time(unit="us", nullable=False), substrait_object.time)
78+
case "precision_time":
79+
unit = _precision_to_unit("precision_time", substrait_object.precision_time.precision)
80+
return _expr.literal(_dtype.time(unit=unit, nullable=False), substrait_object.precision_time.value)
6381
case "interval_year_to_month":
6482
raise NotImplementedError
6583
case "interval_day_to_second":
@@ -72,10 +90,6 @@ def literal(substrait_object: Expression.Literal) -> _expr.Expr:
7290
raise NotImplementedError
7391
case "fixed_binary":
7492
raise NotImplementedError
75-
case "decimal":
76-
raise NotImplementedError
77-
case "precision_timestamp":
78-
raise NotImplementedError
7993
case "precision_timestamp_tz":
8094
raise NotImplementedError
8195
case "struct":
@@ -103,6 +117,20 @@ def literal(substrait_object: Expression.Literal) -> _expr.Expr:
103117
raise ValueError(f"unknown literal_type {literal_type}")
104118

105119

120+
def _precision_to_unit(type_: str, p: int) -> Literal["s", "ms", "us", "ns"]:
121+
match p:
122+
case 0:
123+
return "s"
124+
case 3:
125+
return "ms"
126+
case 6:
127+
return "us"
128+
case 9:
129+
return "ns"
130+
case other:
131+
raise ValueError(f"{type_} with a precision of {other} is not supported with Vortex")
132+
133+
106134
def field_reference(substrait_object: Expression.FieldReference, schema: NamedStruct) -> _expr.Expr:
107135
# https://github.com/substrait-io/substrait/blob/main/proto/substrait/algebra.proto#L1415
108136
match substrait_object.WhichOneof("reference_type"):
@@ -146,8 +174,6 @@ def scalar_function(
146174
) -> _expr.Expr:
147175
# https://github.com/substrait-io/substrait/blob/main/proto/substrait/extensions/extensions.proto#L57
148176
function = functions[substrait_object.function_reference]
149-
if len(substrait_object.options) != 0:
150-
raise NotImplementedError(substrait_object.options)
151177
arguments = [function_argument(argument, functions, schema) for argument in substrait_object.arguments]
152178
return function(*arguments)
153179

@@ -200,15 +226,34 @@ def extension_function(
200226
case "gte":
201227
return operator.__ge__
202228
case "is_null":
203-
raise NotImplementedError
229+
return _expr.is_null
204230
case "is_not_null":
205-
raise NotImplementedError
231+
return _is_not_null
206232
case name:
207233
raise NotImplementedError(f"Function name {name} not supported")
234+
case "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml":
235+
match substrait_object.name:
236+
case "add":
237+
return operator.__add__
238+
case "subtract":
239+
return operator.__sub__
240+
case "multiply":
241+
return operator.__mul__
242+
case "divide":
243+
return operator.__truediv__
244+
case name:
245+
raise NotImplementedError(f"Arithmetic function {name} not supported")
208246
case uri:
209247
raise NotImplementedError(f"Extension URI {uri} not supported")
210248

211249

250+
def _is_not_null(e: _expr.Expr) -> _expr.Expr:
251+
"""
252+
Helper function to have a well-typed callable to return
253+
"""
254+
return _expr.not_(_expr.is_null(e))
255+
256+
212257
def expression(
213258
substrait_object: Expression, functions: list[Callable[..., _expr.Expr]], schema: NamedStruct
214259
) -> _expr.Expr:

vortex-python/python/vortex/type_aliases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3-
from typing import TypeAlias, Union # pyright: ignore[reportDeprecated]
3+
from typing import TypeAlias
44

55
import pyarrow as pa
66

@@ -11,7 +11,6 @@
1111
# TypeAliases do not support __doc__.
1212
IntoProjection: TypeAlias = Expr | list[str] | None
1313
IntoArrayIterator: TypeAlias = Array | ArrayIterator | pa.Table | pa.RecordBatchReader
14-
IntoArray: TypeAlias = Union[Array, "pa.Array[pa.Scalar[pa.DataType]]", pa.Table] # pyright: ignore[reportDeprecated]
1514

1615
# If you make an intersphinx reference to pyarrow.RecordBatchReader in the return type of a function
1716
# *and also* use the IntoProjection type alias in a parameter type, Sphinx thinks the type alias

vortex-python/src/arrays/into_array.rs

Lines changed: 0 additions & 79 deletions
This file was deleted.

vortex-python/src/arrays/mod.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ pub(crate) mod builtins;
55
pub(crate) mod compressed;
66
pub(crate) mod fastlanes;
77
pub(crate) mod from_arrow;
8-
pub mod into_array;
98
mod native;
109
pub(crate) mod py;
1110
mod range_to_sequence;
@@ -44,6 +43,7 @@ use crate::arrow::ToPyArrow;
4443
use crate::dtype::PyDType;
4544
use crate::error::PyVortexError;
4645
use crate::error::PyVortexResult;
46+
use crate::expr::PyExpr;
4747
use crate::install_module;
4848
use crate::python_repr::PythonRepr;
4949
use crate::scalar::PyScalar;
@@ -421,6 +421,36 @@ impl PyArray {
421421
)
422422
}
423423

424+
/// Apply an expression on this array
425+
///
426+
/// Examples
427+
/// --------
428+
///
429+
/// Extract one column from a Vortex array:
430+
///
431+
/// ```python
432+
/// >>> import vortex.expr as ve
433+
/// >>> import vortex as vx
434+
/// >>> array = vx.array([{"a": 0, "b": "hello"}, {"a": 1, "b": "goodbye"}])
435+
/// >>> expr = ve.column("a")
436+
/// >>> array = array.apply(expr)
437+
/// >>> array.to_arrow_array().to_pylist()
438+
/// [0, 1]
439+
/// ```
440+
///
441+
/// See also
442+
/// --------
443+
/// vortex.open : Open an on-disk Vortex array for scanning with an expression.
444+
/// vortex.VortexFile : An on-disk Vortex array ready to scan with an expression.
445+
/// vortex.VortexFile.scan : Scan an on-disk Vortex array with an expression.
446+
pub fn apply(slf: Bound<Self>, expr: PyExpr) -> PyVortexResult<PyArrayRef> {
447+
let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner();
448+
449+
let inner = slf.apply(&expr)?;
450+
451+
Ok(PyArrayRef::from(inner))
452+
}
453+
424454
///Rust docs are *not* copied into Python for __lt__: https://github.com/PyO3/pyo3/issues/4326
425455
fn __lt__(slf: Bound<Self>, other: PyArrayRef) -> PyVortexResult<PyArrayRef> {
426456
let slf = PyArrayRef::extract(slf.as_any().as_borrowed())?.into_inner();

0 commit comments

Comments
 (0)