# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datafusion as dfn
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion import SessionContext, Table


# Note we take in `database` as a variable even though we don't use
# it because that will cause the fixture to set up the context with
# the tables we need.
def test_basic(ctx, database):
    """Check the default catalog, its ``public`` schema and the ``csv`` table."""
    with pytest.raises(KeyError):
        ctx.catalog("non-existent")

    default = ctx.catalog()
    assert default.names() == {"public"}

    # The schema is fetched both by explicit name and via the no-argument
    # form; the same checks are applied to each handle.
    for db in [default.schema("public"), default.schema()]:
        assert db.names() == {"csv1", "csv", "csv2"}

        table = db.table("csv")
        assert table.kind == "physical"
        assert table.schema == pa.schema(
            [
                pa.field("int", pa.int64(), nullable=True),
                pa.field("str", pa.string(), nullable=True),
                pa.field("float", pa.float64(), nullable=True),
            ]
        )


def create_dataset() -> Table:
    """Build a ``Table`` over a two-column dataset: a=[1, 2, 3], b=[4, 5, 6]."""
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
        names=["a", "b"],
    )
    dataset = ds.dataset([batch])
    return Table.from_dataset(dataset)


def _assert_dataset_batches(batches: list[pa.RecordBatch]) -> None:
    """Assert that *batches* is the single batch built by ``create_dataset``."""
    assert len(batches) == 1
    assert batches[0].column(0) == pa.array([1, 2, 3])
    assert batches[0].column(1) == pa.array([4, 5, 6])


class CustomSchemaProvider(dfn.catalog.SchemaProvider):
    """Dict-backed schema provider written in Python.

    Starts out holding a single table named ``table1``.
    """

    def __init__(self):
        self.tables = {"table1": create_dataset()}

    def table_names(self) -> set[str]:
        """Return the names of all tables currently held."""
        return set(self.tables.keys())

    def register_table(self, name: str, table: Table):
        """Store *table* under *name*, replacing any existing entry."""
        self.tables[name] = table

    def deregister_table(self, name, cascade: bool = True):
        """Remove the table stored under *name*.

        Raises:
            KeyError: if no table is stored under *name*.
        """
        del self.tables[name]

    def table(self, name: str) -> Table | None:
        """Return the table stored under *name*, or ``None`` if there is none."""
        # ``.get`` honours the declared ``Table | None`` return type; indexing
        # would raise ``KeyError`` for an unknown name instead.
        return self.tables.get(name)

    def table_exist(self, name: str) -> bool:
        """Return whether a table is stored under *name*."""
        return name in self.tables


class CustomCatalogProvider(dfn.catalog.CatalogProvider):
    """Dict-backed catalog provider written in Python.

    Starts out holding a single ``CustomSchemaProvider`` named ``my_schema``.
    """

    def __init__(self):
        self.schemas = {"my_schema": CustomSchemaProvider()}

    def schema_names(self) -> set[str]:
        """Return the names of all schemas currently held."""
        return set(self.schemas.keys())

    def schema(self, name: str):
        """Return the schema stored under *name*, or ``None`` if there is none."""
        # Mirrors ``CustomSchemaProvider.table``: an unknown name yields
        # ``None`` rather than raising ``KeyError``.
        return self.schemas.get(name)

    def register_schema(self, name: str, schema: dfn.catalog.Schema):
        """Store *schema* under *name*, replacing any existing entry."""
        self.schemas[name] = schema

    def deregister_schema(self, name, cascade: bool = True):
        """Remove the schema stored under *name*.

        ``cascade`` defaults to ``True`` to match
        ``CustomSchemaProvider.deregister_table``; it is not otherwise used.

        Raises:
            KeyError: if no schema is stored under *name*.
        """
        del self.schemas[name]


def test_python_catalog_provider(ctx: SessionContext):
    """Register a Python catalog provider and add/remove schemas through it."""
    ctx.register_catalog_provider("my_catalog", CustomCatalogProvider())

    # Check the default catalog provider
    assert ctx.catalog("datafusion").names() == {"public"}

    my_catalog = ctx.catalog("my_catalog")
    assert my_catalog.names() == {"my_schema"}

    my_catalog.register_schema("second_schema", CustomSchemaProvider())
    assert my_catalog.schema_names() == {"my_schema", "second_schema"}

    my_catalog.deregister_schema("my_schema")
    assert my_catalog.schema_names() == {"second_schema"}


def test_in_memory_providers(ctx: SessionContext):
    """Register in-memory catalog and schema providers and query a table in them."""
    catalog = dfn.catalog.Catalog.memory_catalog()
    ctx.register_catalog_provider("in_mem_catalog", catalog)

    assert ctx.catalog_names() == {"datafusion", "in_mem_catalog"}

    schema = dfn.catalog.Schema.memory_schema()
    catalog.register_schema("in_mem_schema", schema)

    schema.register_table("my_table", create_dataset())

    batches = ctx.sql("select * from in_mem_catalog.in_mem_schema.my_table").collect()

    _assert_dataset_batches(batches)


def test_python_schema_provider(ctx: SessionContext):
    """Swap the default catalog's schemas for Python schema providers."""
    catalog = ctx.catalog()

    catalog.deregister_schema("public")

    catalog.register_schema("test_schema1", CustomSchemaProvider())
    assert catalog.names() == {"test_schema1"}

    catalog.register_schema("test_schema2", CustomSchemaProvider())
    catalog.deregister_schema("test_schema1")
    assert catalog.names() == {"test_schema2"}


def test_python_table_provider(ctx: SessionContext):
    """Register and deregister tables on a Python schema and the default schema."""
    catalog = ctx.catalog()

    catalog.register_schema("custom_schema", CustomSchemaProvider())
    schema = catalog.schema("custom_schema")

    assert schema.table_names() == {"table1"}

    schema.deregister_table("table1")
    schema.register_table("table2", create_dataset())
    assert schema.table_names() == {"table2"}

    # Use the default schema instead of our custom schema
    schema = catalog.schema()

    schema.register_table("table3", create_dataset())
    assert schema.table_names() == {"table3"}

    schema.deregister_table("table3")
    schema.register_table("table4", create_dataset())
    assert schema.table_names() == {"table4"}


def test_in_end_to_end_python_providers(ctx: SessionContext):
    """Test registering all python providers and running a query against them."""

    all_catalog_names = [
        "datafusion",
        "custom_catalog",
        "in_mem_catalog",
    ]

    all_schema_names = [
        "custom_schema",
        "in_mem_schema",
    ]

    ctx.register_catalog_provider(all_catalog_names[1], CustomCatalogProvider())
    ctx.register_catalog_provider(
        all_catalog_names[2], dfn.catalog.Catalog.memory_catalog()
    )

    for catalog_name in all_catalog_names:
        catalog = ctx.catalog(catalog_name)

        # Clean out previous schemas if they exist so we can start clean
        for schema_name in catalog.schema_names():
            catalog.deregister_schema(schema_name, cascade=False)

        catalog.register_schema(all_schema_names[0], CustomSchemaProvider())
        catalog.register_schema(all_schema_names[1], dfn.catalog.Schema.memory_schema())

        for schema_name in all_schema_names:
            schema = catalog.schema(schema_name)

            # Drop whatever the schema already holds so that every schema
            # ends up with exactly one table, "test_table".
            for table_name in schema.table_names():
                schema.deregister_table(table_name)

            schema.register_table("test_table", create_dataset())

    for catalog_name in all_catalog_names:
        for schema_name in all_schema_names:
            table_full_name = f"{catalog_name}.{schema_name}.test_table"

            batches = ctx.sql(f"select * from {table_full_name}").collect()

            _assert_dataset_batches(batches)