Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/datafusion/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def _warn_expr_for_literal_arg(function_name: str, arg_name: str) -> None:
)


def _warn_if_expr_for_literal_arg(
value: Any, function_name: str, arg_name: str
) -> None:
if isinstance(value, Expr):
_warn_expr_for_literal_arg(function_name, arg_name)


__all__ = [
"abs",
"acos",
Expand Down Expand Up @@ -437,6 +444,7 @@ def encode(expr: Expr, encoding: Expr | str) -> Expr:
>>> result.collect_column("enc")[0].as_py()
'aGVsbG8'
"""
_warn_if_expr_for_literal_arg(encoding, "encode", "encoding")
encoding = coerce_to_expr(encoding)
return Expr(f.encode(expr.expr, encoding.expr))

Expand All @@ -452,6 +460,7 @@ def decode(expr: Expr, encoding: Expr | str) -> Expr:
>>> result.collect_column("dec")[0].as_py()
b'hello'
"""
_warn_if_expr_for_literal_arg(encoding, "decode", "encoding")
encoding = coerce_to_expr(encoding)
return Expr(f.decode(expr.expr, encoding.expr))

Expand Down Expand Up @@ -742,6 +751,7 @@ def digest(value: Expr, method: Expr | str) -> Expr:
>>> len(result.collect_column("d")[0].as_py()) > 0
True
"""
_warn_if_expr_for_literal_arg(method, "digest", "method")
method = coerce_to_expr(method)
return Expr(f.digest(value.expr, method.expr))

Expand Down Expand Up @@ -3096,6 +3106,7 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
>>> result.collect_column("c")[0].as_py()
1.0
"""
_warn_if_expr_for_literal_arg(data_type, "arrow_cast", "data_type")
if isinstance(data_type, pa.DataType):
return expr.cast(data_type)
if isinstance(data_type, str):
Expand Down Expand Up @@ -3128,6 +3139,7 @@ def arrow_try_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr:
>>> result.collect_column("c")[0].as_py() is None
True
"""
_warn_if_expr_for_literal_arg(data_type, "arrow_try_cast", "data_type")
if isinstance(data_type, pa.DataType):
return expr.try_cast(data_type)
if isinstance(data_type, str):
Expand Down Expand Up @@ -3235,6 +3247,7 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr:
"""
if key is None:
return Expr(f.arrow_metadata(expr.expr))
_warn_if_expr_for_literal_arg(key, "arrow_metadata", "key")
if isinstance(key, str):
key = Expr.string_literal(key)
return Expr(f.arrow_metadata(expr.expr, key.expr))
Expand Down
62 changes: 62 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2355,6 +2355,68 @@ def test_regexp_replace_native(self):
).collect()
assert result[0].column(0)[0].as_py() == "aX bX cX"

@pytest.mark.parametrize(
("func", "arg_name", "expr"),
[
pytest.param(
f.encode,
"encoding",
lambda: f.encode(column("a"), literal("base64")),
id="encode-encoding",
),
pytest.param(
f.decode,
"encoding",
lambda: f.decode(column("a"), literal("base64")),
id="decode-encoding",
),
pytest.param(
f.digest,
"method",
lambda: f.digest(column("a"), literal("sha256")),
id="digest-method",
),
pytest.param(
f.arrow_cast,
"data_type",
lambda: f.arrow_cast(column("a"), literal("Float64")),
id="arrow-cast-data-type",
),
pytest.param(
f.arrow_try_cast,
"data_type",
lambda: f.arrow_try_cast(column("a"), literal("Float64")),
id="arrow-try-cast-data-type",
),
pytest.param(
f.arrow_metadata,
"key",
lambda: f.arrow_metadata(column("a"), literal("k")),
id="arrow-metadata-key",
),
],
)
def test_literal_only_expr_args_warn_deprecated(self, func, arg_name, expr):
with pytest.warns(
DeprecationWarning,
match=(
rf"Passing Expr for {func.__name__}\(\) argument "
rf"'{arg_name}' is deprecated"
),
):
result = expr()
assert result is not None

def test_literal_only_native_args_do_not_warn(self):
with warnings.catch_warnings():
warnings.simplefilter("error", DeprecationWarning)
assert f.encode(column("a"), "base64") is not None
assert f.decode(column("a"), "base64") is not None
assert f.digest(column("a"), "sha256") is not None
assert f.arrow_cast(column("a"), "Float64") is not None
assert f.arrow_try_cast(column("a"), pa.float64()) is not None
assert f.arrow_metadata(column("a"), "k") is not None

def test_backward_compat_with_lit(self):
"""Verify that existing code using lit() still works."""
ctx = SessionContext()
Expand Down
Loading