Skip to content
Prev Previous commit
Next Next commit
Change RegistryConfig to use cache_mode (replaces allow_async_cache)
Signed-off-by: Stanley Opara <a-sopara@expediagroup.com>
  • Loading branch information
Stanley Opara committed Jun 11, 2024
commit 956b7596a2c4e21d663b15dc253844fb0f21cea1
47 changes: 22 additions & 25 deletions sdk/python/feast/infra/registry/caching_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,17 @@


class CachingRegistry(BaseRegistry):
# Constructor of CachingRegistry, shown as a diff hunk: pre-change and
# post-change lines appear together without +/- markers. Lines that mention
# allow_async_cache belong to the pre-change side; lines that mention
# cache_mode belong to the post-change side.
#
# Pre-change signature (boolean allow_async_cache flag, default False):
def __init__(
self, project: str, cache_ttl_seconds: int, allow_async_cache: bool = False
):
# Post-change signature (cache_mode string, no default):
def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
# Seed the cache with the result of self.proto() and hand it, together with
# the project name, to proto_registry_utils.init_project_metadata.
self.cached_registry_proto = self.proto()
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
# Creation timestamp of the cached proto; the expiry check in
# _refresh_cached_registry_if_necessary adds the TTL to it.
self.cached_registry_proto_created = datetime.utcnow()
# Lock taken around the expiry check in _refresh_cached_registry_if_necessary.
self._refresh_lock = Lock()
# A None TTL is stored as 0 seconds, which the expiry check treats as
# "never expires".
self.cached_registry_proto_ttl = timedelta(
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
)
# Pre-change: remember the flag and start the background refresh when set.
self.allow_async_cache = allow_async_cache
if allow_async_cache:
self._start_async_refresh(cache_ttl_seconds)
# Post-change: remember the mode; only the exact string "thread" starts the
# background refresh.
# NOTE(review): the raw cache_ttl_seconds is forwarded here, not the
# None-to-0 normalised value computed above -- confirm the helper copes
# with None.
self.cache_mode = cache_mode
if cache_mode == "thread":
self._start_thread_async_refresh(cache_ttl_seconds)

@abstractmethod
def _get_data_source(self, name: str, project: str) -> DataSource:
Expand Down Expand Up @@ -292,29 +290,28 @@ def refresh(self, project: Optional[str] = None):
self.cached_registry_proto_created = datetime.utcnow()

# Refresh the cached registry proto if it has expired.
# Shown as a diff hunk: the pre-change body (guarded by allow_async_cache) and
# the post-change body (guarded by cache_mode) are interleaved below without
# +/- markers.
def _refresh_cached_registry_if_necessary(self):
# Pre-change: do nothing at all when async caching is enabled.
if self.allow_async_cache:
return
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
datetime.utcnow()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
# Post-change: the expiry check runs only when cache_mode is exactly "sync";
# any other value leaves the cache untouched by this method.
if self.cache_mode == "sync":
with self._refresh_lock:
# Expired when the cached proto or its creation time is missing, or when a
# positive TTL has elapsed since the creation time.
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
datetime.utcnow()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
)
)
)
)

# When expired, log and call self.refresh(). The step appears twice because
# both sides of the diff are shown.
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()

def _start_async_refresh(self, cache_ttl_seconds):
def _start_thread_async_refresh(self, cache_ttl_seconds):
    """Refresh the cached registry periodically from a background daemon timer.

    The first refresh runs ``cache_ttl_seconds`` after this call. Because a
    ``threading.Timer`` fires only once, a new timer is armed after every
    run so that the cache keeps being refreshed instead of going stale
    after the first interval.

    Args:
        cache_ttl_seconds: Delay in seconds before each refresh. A
            non-positive value triggers a single refresh and is not
            re-armed, since re-arming would fire again immediately in a
            tight loop.
    """

    def _refresh_then_rearm():
        try:
            self.refresh()
        finally:
            # Re-arm even when refresh() raised, so that one transient
            # failure does not permanently stop background refreshing.
            if cache_ttl_seconds > 0:
                self._start_thread_async_refresh(cache_ttl_seconds)

    timer = threading.Timer(cache_ttl_seconds, _refresh_then_rearm)
    # Thread.setDaemon() is deprecated since Python 3.10; set the attribute.
    timer.daemon = True
    # Always points at the most recently armed timer, so cancelling it stops
    # the pending refresh.
    self.registry_refresh_thread = timer
    timer.start()
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
super().__init__(
project=project,
cache_ttl_seconds=registry_config.cache_ttl_seconds,
allow_async_cache=registry_config.allow_async_cache,
cache_mode=registry_config.cache_mode,
)

def teardown(self):
Expand Down
3 changes: 1 addition & 2 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
BaseModel,
ConfigDict,
Field,
StrictBool,
StrictInt,
StrictStr,
ValidationError,
Expand Down Expand Up @@ -131,7 +130,7 @@ class RegistryConfig(FeastBaseModel):
sqlalchemy_config_kwargs: Dict[str, Any] = {}
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

allow_async_cache: StrictBool = False
cache_mode: StrictStr = "sync"
Comment thread
stanconia marked this conversation as resolved.


class RepoConfig(FeastBaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,15 @@ def pg_registry_async():

container.start()

registry_config = _given_registry_config_for_pg_sql(container, 2, True)
registry_config = _given_registry_config_for_pg_sql(container, 2, "thread")

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_pg_sql(
container, cache_ttl_seconds=2, allow_async_cache=False
container, cache_ttl_seconds=2, cache_mode="sync"
):
log_string_to_wait_for = "database system is ready to accept connections"
waited = wait_for_logs(
Expand All @@ -178,7 +178,7 @@ def _given_registry_config_for_pg_sql(
return RegistryConfig(
registry_type="sql",
cache_ttl_seconds=cache_ttl_seconds,
allow_async_cache=allow_async_cache,
cache_mode=cache_mode,
path=f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)
Expand All @@ -201,16 +201,14 @@ def mysql_registry_async():
container = MySqlContainer("mysql:latest")
container.start()

registry_config = _given_registry_config_for_mysql(container, 2, True)
registry_config = _given_registry_config_for_mysql(container, 2, "thread")

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_mysql(
container, cache_ttl_seconds=2, allow_async_cache=False
):
def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"):
import sqlalchemy

engine = sqlalchemy.create_engine(
Expand All @@ -222,7 +220,7 @@ def _given_registry_config_for_mysql(
registry_type="sql",
path=container.get_connection_url(),
cache_ttl_seconds=cache_ttl_seconds,
allow_async_cache=allow_async_cache,
cache_mode=cache_mode,
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)

Expand Down Expand Up @@ -835,7 +833,7 @@ def test_registry_cache(test_registry):
"test_registry",
async_sql_fixtures,
)
def test_registry_cache_async(test_registry):
def test_registry_cache_thread_async(test_registry):
# Create Feature Views
batch_source = FileSource(
name="test_source",
Expand Down