Skip to content
Next Next commit
Add sql registry async refresh
Signed-off-by: Stanley Opara <a-sopara@expediagroup.com>
  • Loading branch information
Stanley Opara committed Jun 3, 2024
commit c05b0dbe4771fa8182f35d55049e1502b929773b
10 changes: 7 additions & 3 deletions sdk/python/feast/infra/registry/caching_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import threading
from abc import abstractmethod
from datetime import datetime, timedelta
from threading import Lock
Expand All @@ -21,9 +22,7 @@

class CachingRegistry(BaseRegistry):
def __init__(
self,
project: str,
cache_ttl_seconds: int,
self, project: str, cache_ttl_seconds: int, allow_async_cache: bool = False
):
self.cached_registry_proto = self.proto()
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
Expand All @@ -32,6 +31,9 @@ def __init__(
self.cached_registry_proto_ttl = timedelta(
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
)
self.allow_async_cache = allow_async_cache
if allow_async_cache:
threading.Timer(cache_ttl_seconds, self.refresh).start()

@abstractmethod
def _get_data_source(self, name: str, project: str) -> DataSource:
Expand Down Expand Up @@ -289,6 +291,8 @@ def refresh(self, project: Optional[str] = None):
self.cached_registry_proto_created = datetime.utcnow()

def _refresh_cached_registry_if_necessary(self):
if self.allow_async_cache:
return
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def __init__(
)
metadata.create_all(self.engine)
super().__init__(
project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds
project=project,
cache_ttl_seconds=registry_config.cache_ttl_seconds,
allow_async_cache=registry_config.allow_async_cache,
)

def teardown(self):
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
BaseModel,
ConfigDict,
Field,
StrictBool,
StrictInt,
StrictStr,
ValidationError,
Expand Down Expand Up @@ -130,6 +131,8 @@ class RegistryConfig(FeastBaseModel):
sqlalchemy_config_kwargs: Dict[str, Any] = {}
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

allow_async_cache: StrictBool = False


class RepoConfig(FeastBaseModel):
"""Repo config. Typically loaded from `feature_store.yaml`"""
Expand Down
111 changes: 103 additions & 8 deletions sdk/python/tests/integration/registration/test_universal_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,35 @@ def pg_registry():

container.start()

registry_config = _given_registry_config_for_pg_sql(container)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="session")
def pg_registry_async():
    """Yield a SqlRegistry on a throwaway Postgres container with async cache on.

    The registry config is built with ``cache_ttl_seconds=2`` and
    ``allow_async_cache=True``. The container is started once per test
    session and stopped when the fixture is torn down.
    """
    container = (
        DockerContainer("postgres:latest")
        .with_exposed_ports(5432)
        .with_env("POSTGRES_USER", POSTGRES_USER)
        .with_env("POSTGRES_PASSWORD", POSTGRES_PASSWORD)
        .with_env("POSTGRES_DB", POSTGRES_DB)
    )

    container.start()
    # pytest skips post-yield teardown when a fixture raises before yielding,
    # so guard the setup with try/finally to avoid leaking a running container
    # if building the config or constructing the registry fails.
    try:
        registry_config = _given_registry_config_for_pg_sql(
            container, cache_ttl_seconds=2, allow_async_cache=True
        )

        yield SqlRegistry(registry_config, "project", None)
    finally:
        container.stop()


def _given_registry_config_for_pg_sql(
container, cache_ttl_seconds=2, allow_async_cache=False
):
log_string_to_wait_for = "database system is ready to accept connections"
waited = wait_for_logs(
container=container,
Expand All @@ -146,40 +175,57 @@ def pg_registry():
container_port = container.get_exposed_port(5432)
container_host = container.get_container_host_ip()

registry_config = RegistryConfig(
return RegistryConfig(
registry_type="sql",
cache_ttl_seconds=cache_ttl_seconds,
allow_async_cache=allow_async_cache,
path=f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)


@pytest.fixture(scope="session")
def mysql_registry():
    """Yield a SqlRegistry backed by a throwaway MySQL container.

    The registry config comes from ``_given_registry_config_for_mysql`` with
    its default cache settings. The container is started once per test
    session and stopped when the fixture is torn down.
    """
    container = MySqlContainer("mysql:latest")
    container.start()
    # pytest skips post-yield teardown when a fixture raises before yielding,
    # so guard the setup with try/finally to avoid leaking a running container
    # if building the config or constructing the registry fails.
    try:
        registry_config = _given_registry_config_for_mysql(container)

        yield SqlRegistry(registry_config, "project", None)
    finally:
        container.stop()


@pytest.fixture(scope="session")
def mysql_registry():
def mysql_registry_async():
container = MySqlContainer("mysql:latest")
container.start()

# testing for the database to exist and ready to connect and start testing.
registry_config = _given_registry_config_for_mysql(container, 2, True)

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_mysql(
container, cache_ttl_seconds=2, allow_async_cache=False
):
import sqlalchemy

engine = sqlalchemy.create_engine(
container.get_connection_url(), pool_pre_ping=True
)
engine.connect()

registry_config = RegistryConfig(
return RegistryConfig(
registry_type="sql",
path=container.get_connection_url(),
cache_ttl_seconds=cache_ttl_seconds,
allow_async_cache=allow_async_cache,
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)

yield SqlRegistry(registry_config, "project", None)

container.stop()


@pytest.fixture(scope="session")
def sqlite_registry():
Expand Down Expand Up @@ -265,6 +311,17 @@ def mock_remote_registry():
lazy_fixture("sqlite_registry"),
]

# SQL registry fixtures created with async cache refresh enabled. Each param
# is tagged with an xdist group named after its fixture.
async_sql_fixtures = [
    pytest.param(
        lazy_fixture(fixture_name),
        marks=pytest.mark.xdist_group(name=fixture_name),
    )
    for fixture_name in ("pg_registry_async", "mysql_registry_async")
]


@pytest.mark.integration
@pytest.mark.parametrize("test_registry", all_fixtures)
Expand Down Expand Up @@ -773,6 +830,44 @@ def test_registry_cache(test_registry):
test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
    "test_registry",
    async_sql_fixtures,
)
def test_registry_cache_async(test_registry):
    """Check that an async-cache registry serves a newly applied data source
    from its cache only after the background refresh has run."""
    project = "project"

    # Data source to register
    batch_source = FileSource(
        name="test_source",
        file_format=ParquetFormat(),
        path="file://feast/*",
        timestamp_field="ts_col",
        created_timestamp_column="timestamp",
    )
    test_registry.apply_data_source(batch_source, project)

    # The async refresh has not fired yet, so the cached view is still empty
    sources_before_refresh = test_registry.list_data_sources(
        project, allow_cache=True
    )
    assert len(sources_before_refresh) == 0

    # Give the background refresh time to repopulate the cache
    time.sleep(4)

    # The cached view now contains the applied data source
    sources_after_refresh = test_registry.list_data_sources(
        project, allow_cache=True
    )
    assert len(sources_after_refresh) == 1
    assert sources_after_refresh[0] == batch_source

    test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
Expand Down