# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import datetime as dt
import gzip
import pathlib

import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import (
    CsvReadOptions,
    DataFrame,
    RuntimeEnvBuilder,
    SessionConfig,
    SessionContext,
    SQLOptions,
    Table,
    column,
    literal,
    udf,
)


def test_create_context_no_args():
    SessionContext()


def test_create_context_session_config_only():
    SessionContext(config=SessionConfig())


def test_create_context_runtime_config_only():
    SessionContext(runtime=RuntimeEnvBuilder())


@pytest.mark.parametrize("path_to_str", [True, False])
def test_runtime_configs(tmp_path, path_to_str):
    path1 = tmp_path / "dir1"
    path2 = tmp_path / "dir2"
    path1 = str(path1) if path_to_str else path1
    path2 = str(path2) if path_to_str else path2

    runtime = RuntimeEnvBuilder().with_disk_manager_specified(path1, path2)
    config = SessionConfig().with_default_catalog_and_schema("foo", "bar")
    ctx = SessionContext(config, runtime)
    assert ctx is not None

    db = ctx.catalog("foo").schema("bar")
    assert db is not None


@pytest.mark.parametrize("path_to_str", [True, False])
def test_temporary_files(tmp_path, path_to_str):
    path = str(tmp_path) if path_to_str else tmp_path

    runtime = RuntimeEnvBuilder().with_temp_file_path(path)
    config = SessionConfig().with_default_catalog_and_schema("foo", "bar")
    ctx = SessionContext(config, runtime)
    assert ctx is not None

    db = ctx.catalog("foo").schema("bar")
    assert db is not None


def test_create_context_with_all_valid_args():
    runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000)
    config = (
        SessionConfig()
        .with_create_default_catalog_and_schema(enabled=True)
        .with_default_catalog_and_schema("foo", "bar")
        .with_target_partitions(1)
        .with_information_schema(enabled=True)
        .with_repartition_joins(enabled=False)
        .with_repartition_aggregations(enabled=False)
        .with_repartition_windows(enabled=False)
        .with_parquet_pruning(enabled=False)
    )

    ctx = SessionContext(config, runtime)

    # verify that at least some of the arguments worked
    ctx.catalog("foo").schema("bar")
    with pytest.raises(KeyError):
        ctx.catalog("datafusion")


def test_register_record_batches(ctx):
    # create a RecordBatch and register it as memtable
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )

    ctx.register_record_batches("t", [[batch]])

    assert ctx.catalog().schema().names() == {"t"}

    result = ctx.sql("SELECT a+b, a-b FROM t").collect()

    assert result[0].column(0) == pa.array([5, 7, 9])
    assert result[0].column(1) == pa.array([-3, -3, -3])


def test_create_dataframe_registers_unique_table_name(ctx):
    # create a RecordBatch and register it as memtable
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]])
    tables = list(ctx.catalog().schema().names())
    assert df
    assert len(tables) == 1
    assert len(tables[0]) == 33
    assert tables[0].startswith("c")
    # ensure that the rest of the table name contains
    # only hexadecimal numbers
    for c in tables[0][1:]:
        assert c in "0123456789abcdef"


def test_create_dataframe_registers_with_defined_table_name(ctx):
    # create a RecordBatch and register it as memtable
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    df = ctx.create_dataframe([[batch]], name="tbl")
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert tables[0] == "tbl"


def test_from_arrow_table(ctx):
    # create a PyArrow table
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    table = pa.Table.from_pydict(data)

    # convert to DataFrame
    df = ctx.from_arrow(table)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert df.collect()[0].num_rows == 3


def record_batch_generator(num_batches: int):
    schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
    for _i in range(num_batches):
        yield pa.RecordBatch.from_arrays(
            [pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema
        )


@pytest.mark.parametrize(
    "source",
    [
        # __arrow_c_array__ sources
        pa.array([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]),
        # __arrow_c_stream__ sources
        pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pa.RecordBatchReader.from_batches(
            pa.schema([("a", pa.int64()), ("b", pa.int64())]), record_batch_generator(1)
        ),
        pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
    ],
)
def test_from_arrow_sources(ctx, source) -> None:
    df = ctx.from_arrow(source)
    assert df
    assert isinstance(df, DataFrame)
    assert df.schema().names == ["a", "b"]
    assert df.count() == 3


def test_from_arrow_table_with_name(ctx):
    # create a PyArrow table
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    table = pa.Table.from_pydict(data)

    # convert to DataFrame with optional name
    df = ctx.from_arrow(table, name="tbl")
    tables = list(ctx.catalog().schema().names())

    assert df
    assert tables[0] == "tbl"


def test_from_arrow_table_empty(ctx):
    data = {"a": [], "b": []}
    schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
    table = pa.Table.from_pydict(data, schema=schema)

    # convert to DataFrame
    df = ctx.from_arrow(table)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert len(df.collect()) == 0


def test_from_arrow_table_empty_no_schema(ctx):
    data = {"a": [], "b": []}
    table = pa.Table.from_pydict(data)

    # convert to DataFrame
    df = ctx.from_arrow(table)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert len(df.collect()) == 0


def test_from_pylist(ctx):
    # create a dataframe from Python list
    data = [
        {"a": 1, "b": 4},
        {"a": 2, "b": 5},
        {"a": 3, "b": 6},
    ]
    df = ctx.from_pylist(data)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert df.collect()[0].num_rows == 3


def test_from_pydict(ctx):
    # create a dataframe from Python dictionary
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    df = ctx.from_pydict(data)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert df.collect()[0].num_rows == 3


def test_from_pandas(ctx):
    # create a dataframe from a pandas dataframe
    pd = pytest.importorskip("pandas")
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    pandas_df = pd.DataFrame(data)

    df = ctx.from_pandas(pandas_df)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert df.collect()[0].num_rows == 3


def test_from_polars(ctx):
    # create a dataframe from a Polars dataframe
    pl = pytest.importorskip("polars")
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    polars_df = pl.DataFrame(data)

    df = ctx.from_polars(polars_df)
    tables = list(ctx.catalog().schema().names())

    assert df
    assert len(tables) == 1
    assert isinstance(df, DataFrame)
    assert set(df.schema().names) == {"a", "b"}
    assert df.collect()[0].num_rows == 3


def test_register_table(ctx, database):
    default = ctx.catalog()
    public = default.schema("public")
    assert public.names() == {"csv", "csv1", "csv2"}
    table = public.table("csv")

    ctx.register_table("csv3", table)
    assert public.names() == {"csv", "csv1", "csv2", "csv3"}


def test_read_table_from_catalog(ctx, database):
    default = ctx.catalog()
    public = default.schema("public")
    assert public.names() == {"csv", "csv1", "csv2"}
    table = public.table("csv")
    table_df = ctx.read_table(table)
    table_df.show()


def test_read_table_from_df(ctx):
    df = ctx.from_pydict({"a": [1, 2]})
    result = ctx.read_table(df).collect()
    assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]


def test_read_table_from_dataset(ctx):
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])
    result = ctx.read_table(dataset).collect()
    assert result[0].column(0) == pa.array([1, 2, 3])
    assert result[0].column(1) == pa.array([4, 5, 6])


def test_deregister_table(ctx, database):
    default = ctx.catalog()
    public = default.schema("public")
    assert public.names() == {"csv", "csv1", "csv2"}

    ctx.deregister_table("csv")
    assert public.names() == {"csv1", "csv2"}


def test_deregister_udf():
    ctx = SessionContext()
    is_null = udf(
        lambda x: x.is_null(),
        [pa.float64()],
        pa.bool_(),
        volatility="immutable",
        name="my_is_null",
    )
    ctx.register_udf(is_null)

    # Verify it works
    df = ctx.from_pydict({"a": [1.0, None]})
    ctx.register_table("t", df.into_view())
    result = ctx.sql("SELECT my_is_null(a) FROM t").collect()
    assert result[0].column(0) == pa.array([False, True])

    # Deregister and verify it's gone
    ctx.deregister_udf("my_is_null")
    with pytest.raises(ValueError):
        ctx.sql("SELECT my_is_null(a) FROM t").collect()


def test_deregister_udaf():
    import pyarrow.compute as pc

    ctx = SessionContext()
    from datafusion import Accumulator, udaf

    class MySum(Accumulator):
        def __init__(self):
            self._sum = 0.0

        def update(self, values: pa.Array) -> None:
            self._sum += pc.sum(values).as_py()

        def merge(self, states: list[pa.Array]) -> None:
            self._sum += pc.sum(states[0]).as_py()

        def state(self) -> list:
            return [self._sum]

        def evaluate(self) -> pa.Scalar:
            return self._sum

    my_sum = udaf(
        MySum,
        [pa.float64()],
        pa.float64(),
        [pa.float64()],
        volatility="immutable",
        name="my_sum",
    )
    ctx.register_udaf(my_sum)

    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
    ctx.register_table("t", df.into_view())
    result = ctx.sql("SELECT my_sum(a) FROM t").collect()
    assert result[0].column(0) == pa.array([6.0])

    ctx.deregister_udaf("my_sum")
    with pytest.raises(ValueError):
        ctx.sql("SELECT my_sum(a) FROM t").collect()


def test_deregister_udwf():
    ctx = SessionContext()
    from datafusion import udwf
    from datafusion.user_defined import WindowEvaluator
    class MyRowNumber(WindowEvaluator):
        def __init__(self):
            self._row = 0

        def evaluate_all(self, values, num_rows):
            return pa.array(list(range(1, num_rows + 1)), type=pa.uint64())

    my_row_number = udwf(
        MyRowNumber,
        [pa.float64()],
        pa.uint64(),
        volatility="immutable",
        name="my_row_number",
    )
    ctx.register_udwf(my_row_number)

    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
    ctx.register_table("t", df.into_view())
    result = ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect()
    assert result[0].column(0) == pa.array([1, 2, 3], type=pa.uint64())

    ctx.deregister_udwf("my_row_number")
    with pytest.raises(ValueError):
        ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect()


def test_deregister_udtf():
    import pyarrow.dataset as ds

    ctx = SessionContext()
    from datafusion import Table, udtf

    class MyTable:
        def __call__(self):
            batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})
            return Table(ds.dataset([batch]))

    my_table = udtf(MyTable(), "my_table")
    ctx.register_udtf(my_table)

    result = ctx.sql("SELECT * FROM my_table()").collect()
    assert result[0].column(0) == pa.array([1, 2, 3])

    ctx.deregister_udtf("my_table")
    with pytest.raises(ValueError):
        ctx.sql("SELECT * FROM my_table()").collect()


def test_register_table_from_dataframe(ctx):
    df = ctx.from_pydict({"a": [1, 2]})
    ctx.register_table("df_tbl", df)
    result = ctx.sql("SELECT * FROM df_tbl").collect()
    assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]


@pytest.mark.parametrize("temporary", [True, False])
def test_register_table_from_dataframe_into_view(ctx, temporary):
    df = ctx.from_pydict({"a": [1, 2]})
    table = df.into_view(temporary=temporary)
    assert isinstance(table, Table)
    if temporary:
        assert table.kind == "temporary"
    else:
        assert table.kind == "view"
    ctx.register_table("view_tbl", table)
    result = ctx.sql("SELECT * FROM view_tbl").collect()
    assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]


def test_table_from_dataframe(ctx):
    df = ctx.from_pydict({"a": [1, 2]})
    table = Table(df)
    assert isinstance(table, Table)
    ctx.register_table("from_dataframe_tbl", table)
    result = ctx.sql("SELECT * FROM from_dataframe_tbl").collect()
    assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]


def test_table_from_dataframe_internal(ctx):
    df = ctx.from_pydict({"a": [1, 2]})
    table = Table(df.df)
    assert isinstance(table, Table)
    ctx.register_table("from_internal_dataframe_tbl", table)
    result = ctx.sql("SELECT * FROM from_internal_dataframe_tbl").collect()
    assert [b.to_pydict() for b in result] == [{"a": [1, 2]}]


def test_register_dataset(ctx):
    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])

    ctx.register_dataset("t", dataset)

    assert ctx.catalog().schema().names() == {"t"}

    result = ctx.sql("SELECT a+b, a-b FROM t").collect()

    assert result[0].column(0) == pa.array([5, 7, 9])
    assert result[0].column(1) == pa.array([-3, -3, -3])


def test_dataset_filter(ctx, capfd):
    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])

    ctx.register_dataset("t", dataset)

    assert ctx.catalog().schema().names() == {"t"}
    df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5")

    # Make sure the filter was pushed down in Physical Plan
    df.explain()
    captured = capfd.readouterr()
    assert "filter_expr=(((a >= 2) and (a <= 3)) and (b > 5))" in captured.out
    result = df.collect()
    assert result[0].column(0) == pa.array([9])
    assert result[0].column(1) == pa.array([-3])


def test_dataset_count(ctx):
    # `datafusion-python` issue: https://github.com/apache/datafusion-python/issues/800
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])
    ctx.register_dataset("t", dataset)

    # Testing the dataframe API
    df = ctx.table("t")
    assert df.count() == 3

    # Testing the SQL API
    count = ctx.sql("SELECT COUNT(*) FROM t")
    count = count.collect()
    assert count[0].column(0) == pa.array([3])


def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
    """Ensure that pyarrow filter gets pushed down for `IsNull`."""
    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, None, 9])],
        names=["a", "b", "c"],
    )
    dataset = ds.dataset([batch])
    ctx.register_dataset("t", dataset)

    # Make sure the filter was pushed down in Physical Plan
    df = ctx.sql("SELECT a FROM t WHERE c is NULL")
    df.explain()
    captured = capfd.readouterr()
    assert "filter_expr=is_null(c, {nan_is_null=false})" in captured.out

    result = df.collect()
    assert result[0].column(0) == pa.array([2])


def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
    """Ensure that pyarrow filter gets pushed down for timestamps."""
    # Ref: https://github.com/apache/datafusion-python/issues/703
    # create pyarrow dataset with no actual files
    col_type = pa.timestamp("ns", "+00:00")
    nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type)
    pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
    pa_dataset_format = pa.dataset.ParquetFileFormat()
    pa_dataset_partition = pa.dataset.field("a") <= nyd_2000
    fragments = [
        # NOTE: we never actually make this file.
        # Working predicate pushdown means it never gets accessed
        pa_dataset_format.make_fragment(
            "1.parquet",
            filesystem=pa_dataset_fs,
            partition_expression=pa_dataset_partition,
        )
    ]
    pa_dataset = pa.dataset.FileSystemDataset(
        fragments,
        pa.schema([pa.field("a", col_type)]),
        pa_dataset_format,
        pa_dataset_fs,
    )
    ctx.register_dataset("t", pa_dataset)

    # the partition for our only fragment is for a <= 2000-01-01,
    # so querying for a > 2024-01-01 should not touch any files
    df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
    assert df.collect() == []


def test_dataset_filter_nested_data(ctx):
    # create Arrow StructArrays to test nested data types
    data = pa.StructArray.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    batch = pa.RecordBatch.from_arrays(
        [data],
        names=["nested_data"],
    )
    dataset = ds.dataset([batch])

    ctx.register_dataset("t", dataset)

    assert ctx.catalog().schema().names() == {"t"}

    df = ctx.table("t")

    # This filter will not be pushed down to DatasetExec since it
    # isn't supported
    df = df.filter(column("nested_data")["b"] > literal(5)).select(
        column("nested_data")["a"] + column("nested_data")["b"],
        column("nested_data")["a"] - column("nested_data")["b"],
    )

    result = df.collect()

    assert result[0].column(0) == pa.array([9])
    assert result[0].column(1) == pa.array([-3])


def test_table_exist(ctx):
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])
    ctx.register_dataset("t", dataset)

    assert ctx.table_exist("t") is True


def test_table_not_found(ctx):
    from uuid import uuid4

    with pytest.raises(KeyError):
        ctx.table(f"not-found-{uuid4()}")


def test_session_start_time(ctx):
    import datetime
    import re

    st = ctx.session_start_time()
    assert isinstance(st, str)
    # Truncate nanoseconds to microseconds for Python 3.10 compat
    st = re.sub(r"(\.\d{6})\d+", r"\1", st)
    dt = datetime.datetime.fromisoformat(st)
    assert dt.isoformat()


def test_enable_ident_normalization(ctx):
    assert ctx.enable_ident_normalization() is True
    ctx.sql("SET datafusion.sql_parser.enable_ident_normalization = false")
    assert ctx.enable_ident_normalization() is False


def test_parse_sql_expr(ctx):
    from datafusion.common import DFSchema

    schema = DFSchema.empty()
    expr = ctx.parse_sql_expr("1 + 2", schema)
    assert str(expr) == "Expr(Int64(1) + Int64(2))"


def test_execute_logical_plan(ctx):
    df = ctx.from_pydict({"a": [1, 2, 3]})
    plan = df.logical_plan()
    df2 = ctx.execute_logical_plan(plan)
    result = df2.collect()
    assert result[0].column(0) == pa.array([1, 2, 3])


def test_refresh_catalogs(ctx):
    ctx.refresh_catalogs()


def test_remove_optimizer_rule(ctx):
    assert ctx.remove_optimizer_rule("push_down_filter") is True
    assert ctx.remove_optimizer_rule("nonexistent_rule") is False


def test_table_provider(ctx):
    batch = pa.RecordBatch.from_pydict({"x": [10, 20, 30]})
    ctx.register_record_batches("provider_test", [[batch]])
    tbl = ctx.table_provider("provider_test")
    assert tbl.schema == pa.schema([("x", pa.int64())])


def test_table_provider_not_found(ctx):
    with pytest.raises(KeyError):
        ctx.table_provider("nonexistent_table")


def test_read_json(ctx):
    path = pathlib.Path(__file__).parent.resolve()

    # Default
    test_data_path = path / "data_test_context" / "data.json"
    df = ctx.read_json(test_data_path)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])

    # Schema
    schema = pa.schema(
        [
            pa.field("A", pa.string(), nullable=True),
        ]
    )
    df = ctx.read_json(test_data_path, schema=schema)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].schema == schema

    # File extension
    test_data_path = path / "data_test_context" / "data.json"
    df = ctx.read_json(test_data_path, file_extension=".json")
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_json_compressed(ctx, tmp_path):
    path = pathlib.Path(__file__).parent.resolve()
    test_data_path = path / "data_test_context" / "data.json"

    # File compression type
    gzip_path = tmp_path / "data.json.gz"

    with (
        pathlib.Path.open(test_data_path, "rb") as csv_file,
        gzip.open(gzip_path, "wb") as gzipped_file,
    ):
        gzipped_file.writelines(csv_file)

    df = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz")
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_csv(ctx):
    csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv")
    csv_df.select(column("c1")).show()


def test_read_csv_list(ctx):
    csv_df = ctx.read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
    expected = csv_df.count() * 2

    double_csv_df = ctx.read_csv(
        path=[
            "testing/data/csv/aggregate_test_100.csv",
            "testing/data/csv/aggregate_test_100.csv",
        ]
    )
    actual = double_csv_df.count()

    double_csv_df.select(column("c1")).show()
    assert actual == expected


def test_read_csv_compressed(ctx, tmp_path):
    test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv")
    expected = ctx.read_csv(test_data_path).collect()

    # File compression type
    gzip_path = tmp_path / "aggregate_test_100.csv.gz"

    with (
        pathlib.Path.open(test_data_path, "rb") as csv_file,
        gzip.open(gzip_path, "wb") as gzipped_file,
    ):
        gzipped_file.writelines(csv_file)

    csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz")
    assert csv_df.collect() == expected

    csv_df = ctx.read_csv(
        gzip_path,
        options=CsvReadOptions(file_extension=".gz", file_compression_type="gz"),
    )
    assert csv_df.collect() == expected


def test_read_parquet(ctx):
    parquet_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet")
    parquet_df.show()
    assert parquet_df is not None

    path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet"
    parquet_df = ctx.read_parquet(path=path)
    assert parquet_df is not None


def test_read_avro(ctx):
    avro_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro")
    avro_df.show()
    assert avro_df is not None

    path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro"
    avro_df = ctx.read_avro(path=path)
    assert avro_df is not None


def test_read_arrow(ctx, tmp_path):
    # Write an Arrow IPC file, then read it back
    table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    arrow_path = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
        writer.write_table(table)

    df = ctx.read_arrow(str(arrow_path))
    result = df.collect()
    assert result[0].column(0) == pa.array([1, 2, 3])
    assert result[0].column(1) == pa.array(["x", "y", "z"])

    # Also verify pathlib.Path works
    df = ctx.read_arrow(arrow_path)
    result = df.collect()
    assert result[0].column(0) == pa.array([1, 2, 3])


def test_read_empty(ctx):
    df = ctx.read_empty()
    result = df.collect()
    assert len(result) == 1
    assert result[0].num_columns == 0

    df = ctx.empty_table()
    result = df.collect()
    assert len(result) == 1
    assert result[0].num_columns == 0


def test_register_arrow(ctx, tmp_path):
    # Write an Arrow IPC file, then register and query it
    table = pa.table({"x": [10, 20, 30]})
    arrow_path = tmp_path / "test.arrow"
    with pa.ipc.new_file(str(arrow_path), table.schema) as writer:
        writer.write_table(table)

    ctx.register_arrow("arrow_tbl", str(arrow_path))
    result = ctx.sql("SELECT * FROM arrow_tbl").collect()
    assert result[0].column(0) == pa.array([10, 20, 30])

    # Also verify pathlib.Path works
    ctx.register_arrow("arrow_tbl_path", arrow_path)
    result = ctx.sql("SELECT * FROM arrow_tbl_path").collect()
ctx.sql("SELECT * FROM arrow_tbl_path").collect() assert result[0].column(0) == pa.array([10, 20, 30]) def test_register_batch(ctx): batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) ctx.register_batch("batch_tbl", batch) result = ctx.sql("SELECT * FROM batch_tbl").collect() assert result[0].column(0) == pa.array([1, 2, 3]) assert result[0].column(1) == pa.array([4, 5, 6]) def test_register_batch_empty(ctx): batch = pa.RecordBatch.from_pydict({"a": pa.array([], type=pa.int64())}) ctx.register_batch("empty_batch_tbl", batch) result = ctx.sql("SELECT * FROM empty_batch_tbl").collect() assert result[0].num_rows == 0 def test_create_sql_options(): SQLOptions() def test_sql_with_options_no_ddl(ctx): sql = "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')" ctx.sql(sql) options = SQLOptions().with_allow_ddl(allow=False) with pytest.raises(Exception, match="DDL"): ctx.sql_with_options(sql, options=options) def test_sql_with_options_no_dml(ctx): table_name = "t" batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"], ) dataset = ds.dataset([batch]) ctx.register_dataset(table_name, dataset) sql = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);' ctx.sql(sql) options = SQLOptions().with_allow_dml(allow=False) with pytest.raises(Exception, match="DML"): ctx.sql_with_options(sql, options=options) def test_sql_with_options_no_statements(ctx): sql = "SET time zone = 1;" ctx.sql(sql) options = SQLOptions().with_allow_statements(allow=False) with pytest.raises(Exception, match="SetVariable"): ctx.sql_with_options(sql, options=options) @pytest.fixture def batch(): return pa.RecordBatch.from_arrays( [pa.array([4, 5, 6])], names=["a"], ) def test_create_dataframe_with_global_ctx(batch): ctx = SessionContext.global_ctx() df = ctx.create_dataframe([[batch]]) result = df.collect()[0].column(0) assert result == pa.array([4, 5, 6]) def test_csv_read_options_builder_pattern(): """Test CsvReadOptions builder pattern.""" from datafusion import CsvReadOptions options = ( CsvReadOptions() .with_has_header(False) # noqa: FBT003 .with_delimiter("|") .with_quote("'") .with_schema_infer_max_records(2000) .with_truncated_rows(True) # noqa: FBT003 .with_newlines_in_values(True) # noqa: FBT003 .with_file_extension(".tsv") ) assert options.has_header is False assert options.delimiter == "|" assert options.quote == "'" assert options.schema_infer_max_records == 2000 assert options.truncated_rows is True assert options.newlines_in_values is True assert options.file_extension == ".tsv" def read_csv_with_options_inner( tmp_path: pathlib.Path, csv_content: str, options: CsvReadOptions, expected: pa.RecordBatch, as_read: bool, global_ctx: bool, ) -> None: from datafusion import SessionContext # Create a test CSV file group_dir = tmp_path / "group=a" group_dir.mkdir(exist_ok=True) csv_path = group_dir / "test.csv" csv_path.write_text(csv_content, newline="\n") ctx = SessionContext() if as_read: if global_ctx: from datafusion.io import read_csv df = read_csv(str(tmp_path), options=options) else: df = ctx.read_csv(str(tmp_path), options=options) else: ctx.register_csv("test_table", str(tmp_path), options=options) df = ctx.sql("SELECT * FROM test_table") df.show() # Verify the data result = df.collect() assert len(result) == 1 assert result[0] == expected @pytest.mark.parametrize( ("as_read", "global_ctx"), [ (True, True), (True, False), (False, False), ], ) def test_read_csv_with_options(tmp_path, as_read, global_ctx): """Test reading CSV with 
CsvReadOptions.""" csv_content = "Alice;30;|New York; NY|\nBob;25\n#Charlie;35;Paris\nPhil;75;Detroit' MI\nKarin;50;|Stockholm\nSweden|" # noqa: E501 # Some of the read options are difficult to test in combination # such as schema and schema_infer_max_records so run multiple tests # file_sort_order doesn't impact reading, but included here to ensure # all options parse correctly options = CsvReadOptions( has_header=False, delimiter=";", quote="|", terminator="\n", escape="\\", comment="#", newlines_in_values=True, schema_infer_max_records=1, null_regex="[pP]+aris", truncated_rows=True, file_sort_order=[[column("column_1").sort(), column("column_2")], ["column_3"]], ) expected = pa.RecordBatch.from_arrays( [ pa.array(["Alice", "Bob", "Phil", "Karin"]), pa.array([30, 25, 75, 50]), pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"]), ], names=["column_1", "column_2", "column_3"], ) read_csv_with_options_inner( tmp_path, csv_content, options, expected, as_read, global_ctx ) schema = pa.schema( [ pa.field("name", pa.string(), nullable=False), pa.field("age", pa.float32(), nullable=False), pa.field("location", pa.string(), nullable=True), ] ) options.with_schema(schema) expected = pa.RecordBatch.from_arrays( [ pa.array(["Alice", "Bob", "Phil", "Karin"]), pa.array([30.0, 25.0, 75.0, 50.0]), pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"]), ], schema=schema, ) read_csv_with_options_inner( tmp_path, csv_content, options, expected, as_read, global_ctx ) csv_content = "name,age\nAlice,30\nBob,25\nCharlie,35\nDiego,40\nEmily,15" expected = pa.RecordBatch.from_arrays( [ pa.array(["Alice", "Bob", "Charlie", "Diego", "Emily"]), pa.array([30, 25, 35, 40, 15]), pa.array(["a", "a", "a", "a", "a"]), ], schema=pa.schema( [ pa.field("name", pa.string(), nullable=True), pa.field("age", pa.int64(), nullable=True), pa.field("group", pa.string(), nullable=False), ] ), ) options = CsvReadOptions( table_partition_cols=[("group", pa.string())], ) read_csv_with_options_inner( tmp_path, csv_content, options, expected, as_read, global_ctx )