# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import pyarrow as pa import pyarrow.dataset as ds from datafusion import Expr, SessionContext, Table, udtf from datafusion.context import TableProviderExportable def python_table_function_inner( num_cols: int, num_rows: int, num_batches: int ) -> Table: rows = list(range(num_rows)) cols = [pa.array(rows) for _ in range(num_cols)] names = [f"row_{i}" for i in range(num_cols)] batch = pa.RecordBatch.from_arrays(cols, names=names) batches = [batch for _ in range(num_batches)] dataset = ds.dataset(batches) return Table(dataset) def test_python_table_function_with_args_class() -> None: """Test Python TableFunction using a class with arguments.""" ctx = SessionContext() class ParameterizedTableFunction: """Table function that takes parameters.""" def __call__(self, num_cols: Expr, num_rows: Expr, num_batches: Expr) -> Table: return python_table_function_inner( num_cols.to_variant().value_i64(), num_rows.to_variant().value_i64(), num_batches.to_variant().value_i64(), ) # Register the function table_func = ParameterizedTableFunction() table_udtf = udtf(table_func, "param_func") ctx.register_udtf(table_udtf) # Call with different parameters result = ctx.sql("SELECT * FROM param_func(5, 10, 2)").collect() assert len(result) == 2 assert result[0].num_columns == 5 assert result[0].num_rows == 10 def test_python_table_function_decorator() -> None: """Test Python TableFunction using decorator syntax.""" ctx = SessionContext() @udtf("decorated_func") def my_decorated_func(num_cols: Expr, num_rows: Expr, num_batches: Expr) -> Table: return python_table_function_inner( num_cols.to_variant().value_i64(), num_rows.to_variant().value_i64(), num_batches.to_variant().value_i64(), ) ctx.register_udtf(my_decorated_func) result = ctx.sql("SELECT * FROM decorated_func(3, 7, 2)").collect() assert len(result) == 2 assert result[0].num_columns == 3 assert result[0].num_rows == 7 def test_python_table_function_no_args() -> None: """Test Python TableFunction with no arguments.""" ctx = SessionContext() @udtf("static_func") def static_table_func() -> Table: return python_table_function_inner(2, 3, 1) ctx.register_udtf(static_table_func) result = ctx.sql("SELECT * FROM static_func()").collect() assert len(result) == 1 assert list(result[0].column(0).to_pylist()) == [0, 1, 2] assert list(result[0].column(1).to_pylist()) == [0, 1, 2] def test_python_table_function_single_arg() -> None: """Test Python TableFunction with a single argument.""" ctx = SessionContext() @udtf("single_arg_func") def single_arg_func(n: Expr) -> TableProviderExportable: return python_table_function_inner(2, n.to_variant().value_i64(), 1) ctx.register_udtf(single_arg_func) result = ctx.sql("SELECT * FROM single_arg_func(15)").collect() assert len(result) == 1 assert result[0].num_columns == 2 assert result[0].num_rows == 15 def test_python_table_function_with_string_args() -> None: """Test Python TableFunction with string arguments.""" ctx = SessionContext() @udtf("string_arg_func") def string_arg_func(prefix: Expr) -> TableProviderExportable: prefix_str = prefix.to_variant().value_string() # Create a table with the prefix in column names batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=[f"{prefix_str}_a", f"{prefix_str}_b"], ) return Table(ds.dataset([batch])) ctx.register_udtf(string_arg_func) result = ctx.sql("SELECT * FROM string_arg_func('test')").collect() assert len(result) == 1 assert result[0].schema.names == ["test_a", "test_b"]