Arrow Python UDFs#
Native Arrow UDFs operate directly on pyarrow.Array objects without converting to Pandas or
row-by-row Python objects. This preserves the columnar layout end-to-end, avoids unnecessary data
copies, and enables vectorized processing using Arrow’s native compute functions.
A Native Arrow UDF is defined using arrow_udf() as a decorator or wrapper, and no additional
configuration is required.
Note
Native Arrow UDFs can also be defined via udf() with pyarrow.Array type hints.
The type hints in the function signature determine which kind of Arrow UDF is created
(e.g., returning pa.Array creates an array-to-array UDF, while returning float
creates an aggregate UDF).
Native Arrow UDF Types#
Arrays to Array#
The type hint can be expressed as pyarrow.Array, … -> pyarrow.Array.
The function takes one or more pyarrow.Array and outputs one pyarrow.Array.
The output should always be of the same length as the input.
import pyarrow as pa
from pyspark.sql.functions import arrow_udf

@arrow_udf("string")
def to_upper(s: pa.Array) -> pa.Array:
    return pa.compute.ascii_upper(s)

df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(to_upper("name")).show()
# +--------------+
# |to_upper(name)|
# +--------------+
# | JOHN DOE|
# +--------------+
When the returnType is a struct type, the function returns a pa.StructArray:
import pyarrow as pa
from pyspark.sql.functions import arrow_udf

@arrow_udf("first string, last string")
def split_expand(v: pa.Array) -> pa.Array:
    b = pa.compute.ascii_split_whitespace(v)
    s0 = pa.array([t[0] for t in b])
    s1 = pa.array([t[1] for t in b])
    return pa.StructArray.from_arrays([s0, s1], names=["first", "last"])

df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(split_expand("name")).show()
# +------------------+
# |split_expand(name)|
# +------------------+
# | {John, Doe}|
# +------------------+
Arrow UDFs support keyword arguments:
import pyarrow as pa
from pyspark.sql import functions as sf
from pyspark.sql.functions import arrow_udf
from pyspark.sql.types import IntegerType

@arrow_udf(returnType=IntegerType())
def calc(a: pa.Array, b: pa.Array) -> pa.Array:
    return pa.compute.add(a, pa.compute.multiply(b, 10))

spark.range(2).select(calc(b=sf.col("id") * 10, a=sf.col("id"))).show()
# +-----------------------------+
# |calc(b => (id * 10), a => id)|
# +-----------------------------+
# | 0|
# | 101|
# +-----------------------------+
Iterator of Arrays to Iterator of Arrays#
The type hint can be expressed as Iterator[pyarrow.Array] -> Iterator[pyarrow.Array].
The function takes an iterator of pyarrow.Array and outputs an iterator of pyarrow.Array.
The total length of the output must be the same as the total length of the input.
This is useful when the UDF execution requires expensive initialization, since the setup can
run once before iterating over the batches rather than once per batch.
import pyarrow as pa
from pyspark.sql.functions import arrow_udf
from typing import Iterator

@arrow_udf("long")
def plus_one(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
    for v in iterator:
        yield pa.compute.add(v, 1)

df = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
df.select(plus_one(df.v)).show()
# +-----------+
# |plus_one(v)|
# +-----------+
# | 2|
# | 3|
# | 4|
# +-----------+
Iterator of Multiple Arrays to Iterator of Arrays#
The type hint can be expressed as Iterator[Tuple[pyarrow.Array, ...]] -> Iterator[pyarrow.Array].
The function takes an iterator of a tuple of multiple pyarrow.Array and outputs an
iterator of pyarrow.Array. Use this when the UDF requires multiple input columns.
import pyarrow as pa
from pyspark.sql import functions as sf
from pyspark.sql.functions import arrow_udf
from typing import Iterator, Tuple

@arrow_udf("long")
def multiply(
    iterator: Iterator[Tuple[pa.Array, pa.Array]]
) -> Iterator[pa.Array]:
    for v1, v2 in iterator:
        yield pa.compute.multiply(v1, v2.field("v"))

df = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
df.withColumn(
    'output', multiply(sf.col("v"), sf.struct(sf.col("v")))
).show()
# +---+------+
# | v|output|
# +---+------+
# | 1| 1|
# | 2| 4|
# | 3| 9|
# +---+------+
Arrays to Scalar#
The type hint can be expressed as pyarrow.Array, … -> Any.
The function takes one or more pyarrow.Array and returns a scalar value. The return type
annotation can be any type other than pa.Array, Iterator, or Tuple, since those annotations
would select the array-to-array or iterator patterns above. The returned scalar can be
a Python primitive type (e.g., int or float), a NumPy data type, or a pyarrow.Scalar
instance, which supports complex return types such as structs.
This type of UDF can be used with GroupedData.agg and
Window operations.
import pyarrow as pa
from pyspark.sql.functions import arrow_udf

@arrow_udf("double")
def mean_udf(v: pa.Array) -> float:
    return pa.compute.mean(v).as_py()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(mean_udf(df['v'])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# | 1| 1.5|
# | 2| 6.0|
# +---+-----------+
The return type can also be a complex type such as struct:
import pyarrow as pa
from pyspark.sql.functions import arrow_udf

@arrow_udf("struct<m1: double, m2: double>")
def min_max_udf(v: pa.Array) -> pa.Scalar:
    m1 = pa.compute.min(v)
    m2 = pa.compute.max(v)
    t = pa.struct([pa.field("m1", pa.float64()), pa.field("m2", pa.float64())])
    return pa.scalar(value={"m1": m1.as_py(), "m2": m2.as_py()}, type=t)

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(min_max_udf(df['v'])).show()
# +---+--------------+
# | id|min_max_udf(v)|
# +---+--------------+
# | 1| {1.0, 2.0}|
# | 2| {3.0, 10.0}|
# +---+--------------+
This UDF can also be used as window functions:
import pyarrow as pa
from pyspark.sql import Window
from pyspark.sql.functions import arrow_udf

@arrow_udf("double")
def mean_udf(v: pa.Array) -> float:
    return pa.compute.mean(v).as_py()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
df.withColumn('mean_v', mean_udf("v").over(w)).show()
# +---+----+------+
# | id| v|mean_v|
# +---+----+------+
# | 1| 1.0| 1.0|
# | 1| 2.0| 1.5|
# | 2| 3.0| 3.0|
# | 2| 5.0| 4.0|
# | 2|10.0| 7.5|
# +---+----+------+
Note
For performance reasons, the input arrays to window functions are not copied. Mutating the input arrays is not allowed and will cause incorrect results.
Iterator of Arrays to Scalar#
The type hint can be expressed as Iterator[pyarrow.Array] -> a scalar type.
The function takes an iterator of pyarrow.Array and returns a scalar value. This is useful for
grouped aggregations where the UDF can process all batches iteratively, which is more
memory-efficient than loading all data at once.
Note
Only a single UDF is supported per aggregation.
import pyarrow as pa
from pyspark.sql.functions import arrow_udf
from typing import Iterator

@arrow_udf("double")
def arrow_mean(it: Iterator[pa.Array]) -> float:
    sum_val = 0.0
    cnt = 0
    for v in it:
        sum_val += pa.compute.sum(v).as_py()
        cnt += len(v)
    return sum_val / cnt

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
df.groupby("id").agg(arrow_mean(df['v'])).show()
# +---+-------------+
# | id|arrow_mean(v)|
# +---+-------------+
# | 1| 1.5|
# | 2| 6.0|
# +---+-------------+
Iterator of Multiple Arrays to Scalar#
The type hint can be expressed as Iterator[Tuple[pyarrow.Array, ...]] -> a scalar type.
The function takes an iterator of a tuple of multiple pyarrow.Array and returns a scalar value.
This is useful for grouped aggregations with multiple input columns.
Note
Only a single UDF is supported per aggregation.
import pyarrow as pa
import numpy as np
from pyspark.sql.functions import arrow_udf
from typing import Iterator, Tuple

@arrow_udf("double")
def arrow_weighted_mean(
    it: Iterator[Tuple[pa.Array, pa.Array]]
) -> float:
    weighted_sum = 0.0
    weight = 0.0
    for v, w in it:
        weighted_sum += np.dot(v, w)
        weight += pa.compute.sum(w).as_py()
    return weighted_sum / weight

df = spark.createDataFrame(
    [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
    ("id", "v", "w"))
df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()
# +---+-------------------------+
# | id|arrow_weighted_mean(v, w)|
# +---+-------------------------+
# | 1| 1.6666666666666...|
# | 2| 7.166666666666...|
# +---+-------------------------+
Arrow Function APIs#
Arrow Function APIs apply Python native functions directly on Arrow data at the DataFrame level.
They work similarly to Pandas Function APIs but use pyarrow.RecordBatch and pyarrow.Table
instead of Pandas DataFrames.
Map#
DataFrame.mapInArrow() maps an iterator of pyarrow.RecordBatch to another iterator of
pyarrow.RecordBatch. The input and output can have different lengths.
import pyarrow as pa

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator):
    for batch in iterator:
        yield batch.filter(pa.compute.field("id") == 1)

df.mapInArrow(filter_func, df.schema).show()
# +---+---+
# | id|age|
# +---+---+
# | 1| 21|
# +---+---+
For detailed usage, please see DataFrame.mapInArrow().
Grouped Map#
DataFrame.groupBy().applyInArrow() maps each group using a function that takes a
pyarrow.Table and returns a pyarrow.Table.
import pyarrow as pa
import pyarrow.compute as pc

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

def normalize(table):
    v = table.column("v")
    norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
    return table.set_column(1, "v", norm)

df.groupby("id").applyInArrow(
    normalize, schema="id long, v double"
).sort("id", "v").show()
# +---+-------------------+
# | id| v|
# +---+-------------------+
# | 1|-0.7071067811865...|
# | 1| 0.7071067811865...|
# | 2|-0.8320502943378...|
# | 2|-0.2773500981126...|
# | 2| 1.1094003924504...|
# +---+-------------------+
The function can also accept grouping keys as the first argument:
def mean_func(key, table):
    mean = pc.mean(table.column("v"))
    return pa.Table.from_pydict({"id": [key[0].as_py()], "v": [mean.as_py()]})

df.groupby('id').applyInArrow(
    mean_func, schema="id long, v double"
).sort("id").show()
# +---+---+
# | id| v|
# +---+---+
# | 1|1.5|
# | 2|6.0|
# +---+---+
For detailed usage, please see GroupedData.applyInArrow().
Co-grouped Map#
DataFrame.groupBy().cogroup().applyInArrow() allows two DataFrames to be cogrouped by a
common key and then a Python function applied to each cogroup. The function takes two
pyarrow.Table and returns a pyarrow.Table.
import pyarrow as pa

df1 = spark.createDataFrame(
    [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)], ("id", "v1"))
df2 = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "v2"))

def summarize(l, r):
    return pa.Table.from_pydict({
        "left": [l.num_rows],
        "right": [r.num_rows]
    })

df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
    summarize, schema="left long, right long"
).show()
# +----+-----+
# |left|right|
# +----+-----+
# | 2| 1|
# | 2| 1|
# +----+-----+
The function can also accept grouping keys as the first argument:
def summarize(key, l, r):
    return pa.Table.from_pydict({
        "key": [key[0].as_py()],
        "left": [l.num_rows],
        "right": [r.num_rows]
    })

df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
    summarize, schema="key long, left long, right long"
).sort("key").show()
# +---+----+-----+
# |key|left|right|
# +---+----+-----+
# | 1| 2| 1|
# | 2| 2| 1|
# +---+----+-----+
For detailed usage, please see PandasCogroupedOps.applyInArrow().
Notes#
SQL boolean expressions do not short-circuit: in WHERE cond AND udf(x), the UDF may be
called on all rows regardless of cond. If the function can fail on certain input values
(e.g., division by zero), handle those cases inside the function itself.
The Arrow data type of the returned pyarrow.Array should match the declared returnType.
When there is a mismatch, Spark will attempt to convert the returned data to the expected type
using Arrow’s safe casting, which raises an error on overflow or precision loss.
Supported SQL types are the same as for Arrow-based conversion. See Supported SQL Types for details.
See Also#
udf() – Create a Python UDF (with optional Arrow optimization)
arrow_udtf() – Create a vectorized Arrow UDTF (see Arrow Python UDTFs)
pandas_udf() – Create a Pandas UDF