fix: Added tests to the optimization changes
Signed-off-by: Bhargav Dodla <bdodla@expediagroup.com>
Bhargav Dodla committed Aug 20, 2024
commit 8de470e5d410c4b98ab1958c92b325e32b6956df
44 changes: 21 additions & 23 deletions sdk/python/feast/infra/registry/caching_registry.py
@@ -27,12 +27,12 @@ class CachingRegistry(BaseRegistry):
def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
self.cache_mode = cache_mode
self.cached_registry_proto = RegistryProto()
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()
self._refresh_lock = Lock()
self.cached_registry_proto_ttl = timedelta(
seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
)
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()
if cache_mode == "thread":
self._start_thread_async_refresh(cache_ttl_seconds)
atexit.register(self._exit_handler)
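
The reordering above matters because `self.proto()` can route through the cache-expiry check, which dereferences `cached_registry_proto_ttl` and compares the cache against the empty `RegistryProto()` placeholder; both must therefore exist before the first `proto()` call. A minimal sketch of the same ordering dependency, using hypothetical names rather than the actual Feast classes:

```python
# Hypothetical stand-in for CachingRegistry: the placeholder value and
# the TTL are assigned before populate(), because populate() may consult
# them through the expiry check.
from datetime import datetime, timedelta, timezone
from threading import Lock


class TtlCache:
    def __init__(self, ttl_seconds: int):
        self.value: dict = {}  # placeholder, analogous to RegistryProto()
        self._lock = Lock()
        self.ttl = timedelta(seconds=ttl_seconds if ttl_seconds is not None else 0)
        self.value = self.populate()  # safe: ttl and placeholder already set
        self.created = datetime.now(timezone.utc)

    def populate(self) -> dict:
        # Stand-in for CachingRegistry.proto().
        return {"populated": True}
```
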
@@ -332,34 +332,32 @@ def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return self._get_infra(project)

def refresh(self, project: Optional[str] = None):
if project:
project_metadata = proto_registry_utils.get_project_metadata(
registry_proto=self.cached_registry_proto, project=project
)
if not project_metadata:
proto_registry_utils.init_project_metadata(
self.cached_registry_proto, project
)
self.cached_registry_proto = self.proto()
self.cached_registry_proto_created = _utc_now()

def _refresh_cached_registry_if_necessary(self):
if self.cache_mode == "sync":
with self._refresh_lock:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
if self.cached_registry_proto == RegistryProto():
# Avoids the need to refresh the registry when cache is not populated yet
# Especially during the __init__ phase
# proto() will populate the cache with project metadata if no objects are registered
expired = False
else:
expired = (
self.cached_registry_proto is None
or self.cached_registry_proto_created is None
) or (
self.cached_registry_proto_ttl.total_seconds()
> 0 # 0 ttl means infinity
and (
_utc_now()
> (
self.cached_registry_proto_created
+ self.cached_registry_proto_ttl
)
)
)
)
if expired:
logger.info("Registry cache expired, so refreshing")
self.refresh()
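
Pulled out of the class, the expiry rule reads as follows; this is a sketch under the semantics stated in the diff (a TTL of 0 means the cache never expires, and a cache with no creation timestamp is always considered expired):

```python
# Standalone sketch of the expiry predicate used above.
from datetime import datetime, timedelta, timezone


def cache_expired(created: datetime | None, ttl: timedelta) -> bool:
    if created is None:
        return True  # no creation timestamp: force a refresh
    if ttl.total_seconds() <= 0:
        return False  # 0 TTL means infinity: never expire
    return datetime.now(timezone.utc) > created + ttl


# Example: a 2-second TTL, populated 5 seconds ago -> expired.
created = datetime.now(timezone.utc) - timedelta(seconds=5)
assert cache_expired(created, timedelta(seconds=2))
assert not cache_expired(created, timedelta(seconds=0))  # infinite TTL
```
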
@@ -371,7 +369,7 @@ def _start_thread_async_refresh(self, cache_ttl_seconds):
self.registry_refresh_thread = threading.Timer(
cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds]
)
self.registry_refresh_thread.setDaemon(True)
self.registry_refresh_thread.daemon = True
self.registry_refresh_thread.start()

def _exit_handler(self):
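
The `setDaemon` change tracks the standard library: `Thread.setDaemon()` has been deprecated since Python 3.10 in favor of assigning the `daemon` attribute, which `threading.Timer` inherits from `Thread`. A minimal sketch:

```python
# daemon is the supported spelling; setDaemon() emits a
# DeprecationWarning on Python 3.10+.
import threading

timer = threading.Timer(2.0, lambda: print("refresh fired"))
timer.daemon = True  # preferred over timer.setDaemon(True)
timer.start()
```
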
36 changes: 28 additions & 8 deletions sdk/python/feast/infra/registry/sql.py
@@ -205,10 +205,6 @@ class SqlRegistryConfig(RegistryConfig):
""" Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """


# Number of workers in ThreadPoolExecutor
MAX_WORKERS = 5


class SqlRegistry(CachingRegistry):
def __init__(
self,
@@ -221,6 +217,9 @@ def __init__(
self.engine: Engine = create_engine(
registry_config.path, **registry_config.sqlalchemy_config_kwargs
)
self.thread_pool_executor_worker_count = (
registry_config.thread_pool_executor_worker_count
)
metadata.create_all(self.engine)

self._maybe_init_project_metadata(project)
@@ -368,6 +367,23 @@ def _list_entities(
entities, project, EntityProto, Entity, "entity_proto", tags=tags
)

# TODO: Add to BaseRegistry
def delete_project(self, project: str):
with self.engine.begin() as conn:
for t in {
entities,
data_sources,
feature_views,
feature_services,
on_demand_feature_views,
saved_datasets,
validation_references,
managed_infra,
feast_metadata,
}:
stmt = delete(t).where(t.c.project_id == project)
conn.execute(stmt)
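
The new `delete_project` issues one SQLAlchemy Core `delete` per registry table inside a single `engine.begin()` transaction, so either every table is purged of the project's rows or none is. A minimal sketch of the same pattern on a toy table (the table and column names here are illustrative, not Feast's schema):

```python
# Toy reproduction of the delete-by-project pattern used above.
from sqlalchemy import Column, MetaData, String, Table, create_engine, delete

metadata = MetaData()
toy_entities = Table(
    "toy_entities",
    metadata,
    Column("entity_name", String, primary_key=True),
    Column("project_id", String, primary_key=True),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

# engine.begin() opens a transaction that commits on success and
# rolls back if any statement raises.
with engine.begin() as conn:
    conn.execute(delete(toy_entities).where(toy_entities.c.project_id == "project"))
```
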

def delete_entity(self, name: str, project: str, commit: bool = True):
return self._delete_object(
entities, name, project, "entity_name", EntityNotFoundException
@@ -740,10 +756,14 @@ def process_project(project_metadata: ProjectMetadata):

project_metadata_list = self.get_all_projects()

with ThreadPoolExecutor(
max_workers=MAX_WORKERS
) as executor: # Adjust max_workers as needed. Defaults to 5
executor.map(process_project, project_metadata_list)
if self.thread_pool_executor_worker_count == 0:
for project_metadata in project_metadata_list:
process_project(project_metadata)
else:
with ThreadPoolExecutor(
max_workers=self.thread_pool_executor_worker_count
) as executor:
executor.map(process_project, project_metadata_list)

if last_updated_timestamps:
r.last_updated.FromDatetime(max(last_updated_timestamps))
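
With `thread_pool_executor_worker_count == 0` the projects are processed serially on the calling thread; any other value fans the per-project reads out across a pool. A sketch of the branch in isolation (`process_project` here is a stand-in for the real per-project registry read). One caveat worth noting: `executor.map` is lazy, so exceptions raised inside the pool only surface when its result iterator is consumed, which the sketch forces with `list()`:

```python
from concurrent.futures import ThreadPoolExecutor


def process_project(project: str) -> None:
    print(f"loading registry objects for {project}")


def load_all(projects: list[str], worker_count: int) -> None:
    if worker_count == 0:
        # Serial path: no pool at all, useful for debugging or for
        # backends that dislike concurrent connections.
        for project in projects:
            process_project(project)
    else:
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            # list() drains the iterator so worker exceptions propagate.
            list(executor.map(process_project, projects))


load_all(["project_a", "project_b"], worker_count=0)  # serial
load_all(["project_a", "project_b"], worker_count=3)  # pooled
```
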
4 changes: 2 additions & 2 deletions sdk/python/feast/project_metadata.py
@@ -98,8 +98,8 @@ def from_proto(cls, project_metadata_proto: ProjectMetadataProto):
project_metadata = cls(
project_name=project_metadata_proto.project,
project_uuid=project_metadata_proto.project_uuid,
last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime().astimezone(
tz=timezone.utc
last_updated_timestamp=project_metadata_proto.last_updated_timestamp.ToDatetime().replace(
tzinfo=timezone.utc
),
)

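
The `from_proto` fix hinges on a protobuf detail: `Timestamp.ToDatetime()` returns a naive `datetime` whose wall-clock value is already UTC. `replace(tzinfo=timezone.utc)` merely attaches the correct label, whereas the old `astimezone(tz=timezone.utc)` first interprets the naive value as local time and shifts it, skewing the result on any machine not running in UTC. A small demonstration:

```python
from datetime import datetime, timezone

naive_utc = datetime(2024, 8, 20, 12, 0, 0)  # shape of ToDatetime()'s output

labeled = naive_utc.replace(tzinfo=timezone.utc)
# 2024-08-20 12:00:00+00:00 on every machine

shifted = naive_utc.astimezone(timezone.utc)
# Interprets 12:00 as *local* time before converting; equal to
# `labeled` only when the machine's local zone is UTC.
print(labeled, shifted)
```
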
3 changes: 3 additions & 0 deletions sdk/python/feast/repo_config.py
@@ -128,6 +128,9 @@ class RegistryConfig(FeastBaseModel):
cache_mode: StrictStr = "sync"
""" str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)"""

thread_pool_executor_worker_count: StrictInt = 0
""" int: Number of worker threads to use for asynchronous caching in SQL Registry. If set to 0, it doesn't use ThreadPoolExecutor. """

@field_validator("path")
def validate_path(cls, path: str, values: ValidationInfo) -> str:
if values.data.get("registry_type") == "sql":
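
Wiring the new knob up from Python looks roughly like this; a hedged sketch using the field added in this diff, with the connection string and other values purely illustrative:

```python
from feast.repo_config import RegistryConfig

registry_config = RegistryConfig(
    registry_type="sql",
    path="postgresql+psycopg://user:pass@localhost:5432/feast",  # illustrative
    cache_ttl_seconds=60,
    cache_mode="thread",
    thread_pool_executor_worker_count=3,  # 0 (the default) skips the pool
)
```
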
124 changes: 118 additions & 6 deletions sdk/python/tests/integration/registration/test_universal_registry.py
@@ -14,7 +14,7 @@
import logging
import os
import time
from datetime import timedelta, timezone
from datetime import datetime, timedelta, timezone
from tempfile import mkstemp
from unittest import mock

@@ -155,15 +155,18 @@ def pg_registry_async():

container.start()

registry_config = _given_registry_config_for_pg_sql(container, 2, "thread")
registry_config = _given_registry_config_for_pg_sql(container, 2, "thread", 3)

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_pg_sql(
container, cache_ttl_seconds=2, cache_mode="sync"
container,
cache_ttl_seconds=2,
cache_mode="sync",
thread_pool_executor_worker_count=0,
):
log_string_to_wait_for = "database system is ready to accept connections"
waited = wait_for_logs(
@@ -180,6 +183,7 @@ def _given_registry_config_for_pg_sql(
registry_type="sql",
cache_ttl_seconds=cache_ttl_seconds,
cache_mode=cache_mode,
thread_pool_executor_worker_count=thread_pool_executor_worker_count,
# The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
# to understand that we are using psycopg3.
path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
@@ -204,14 +208,19 @@ def mysql_registry_async():
container = MySqlContainer("mysql:latest")
container.start()

registry_config = _given_registry_config_for_mysql(container, 2, "thread")
registry_config = _given_registry_config_for_mysql(container, 2, "thread", 3)

yield SqlRegistry(registry_config, "project", None)

container.stop()


def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"):
def _given_registry_config_for_mysql(
container,
cache_ttl_seconds=2,
cache_mode="sync",
thread_pool_executor_worker_count=0,
):
import sqlalchemy

engine = sqlalchemy.create_engine(
@@ -224,11 +233,12 @@ def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode=
path=container.get_connection_url(),
cache_ttl_seconds=cache_ttl_seconds,
cache_mode=cache_mode,
thread_pool_executor_worker_count=thread_pool_executor_worker_count,
sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
)


@pytest.fixture(scope="session")
@pytest.fixture(scope="function")
def sqlite_registry():
registry_config = RegistryConfig(
registry_type="sql",
@@ -342,6 +352,7 @@ def test_apply_entity_success(test_registry):
project_uuid = project_metadata[0].project_uuid
assert len(project_metadata[0].project_uuid) == 36
assert_project_uuid(project, project_uuid, test_registry)
assert project_metadata[0].last_updated_timestamp is not None

entities = test_registry.list_entities(project, tags=entity.tags)
assert_project_uuid(project, project_uuid, test_registry)
@@ -1343,3 +1354,104 @@ def validate_project_uuid(project_uuid, test_registry):
assert len(test_registry.cached_registry_proto.project_metadata) == 1
project_metadata = test_registry.cached_registry_proto.project_metadata[0]
assert project_metadata.project_uuid == project_uuid


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
sql_fixtures,
)
def test_project_metadata_success(test_registry):
project = "project"
project_metadata = test_registry.get_project_metadata(project)
assert project_metadata.project_name == project
assert project_metadata.last_updated_timestamp == datetime.fromtimestamp(
1, tz=timezone.utc
)

last_refresh_timestamp = project_metadata.last_updated_timestamp

entity = Entity(
name="test_project_metadata_success",
description="test_project_metadata_success",
tags={"team": "matchmaking"},
)

# Register Entity
test_registry.apply_entity(entity, project)

project_metadata = test_registry.get_project_metadata(project)
assert project_metadata.project_name == project
assert project_metadata.last_updated_timestamp > last_refresh_timestamp

project_metadata_list = test_registry.get_all_projects()
assert len(project_metadata_list) == 1

test_registry.delete_project(project)

project_metadata = test_registry.get_project_metadata(project)
assert project_metadata is None

project_metadata_list = test_registry.get_all_projects()
assert len(project_metadata_list) == 0

test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
sql_fixtures,
)
def test_project_metadata_from_cache_on_init_success(test_registry):
# During setup, proto() does not execute fully because of lazy fixtures, so force the call here
test_registry.cached_registry_proto = test_registry.proto()
project = "project"
project_metadata = test_registry.get_project_metadata(project, allow_cache=True)
assert project_metadata.project_name == project
assert project_metadata.last_updated_timestamp == datetime.fromtimestamp(
1, tz=timezone.utc
)
last_refresh_timestamp = project_metadata.last_updated_timestamp

entity = Entity(
name="test_project_metadata_from_cache_on_init_success",
description="test_project_metadata_from_cache_on_init_success",
tags={"team": "matchmaking"},
)
# Register Entity
test_registry.apply_entity(entity, project)

project_metadata = test_registry.get_project_metadata(project)
assert project_metadata.project_name == project
assert project_metadata.last_updated_timestamp > last_refresh_timestamp

test_registry.refresh()
project_metadata = test_registry.get_project_metadata(project, allow_cache=True)
assert project_metadata.project_name == project
assert project_metadata.last_updated_timestamp > last_refresh_timestamp

project_metadata_list = test_registry.get_all_projects()
assert len(project_metadata_list) == 1

test_registry.teardown()


@pytest.mark.integration
@pytest.mark.parametrize(
"test_registry",
async_sql_fixtures,
)
def test_registry_cache_project_metadata_thread_async(test_registry):
project = "project"
# Wait for cache to be refreshed
time.sleep(4)
# Now objects exist
project_metadata = test_registry.get_project_metadata(project, allow_cache=True)
assert project_metadata is not None
assert project_metadata.project_name == project

project_metadata_list = test_registry.get_all_projects()
assert len(project_metadata_list) == 1

test_registry.teardown()
9 changes: 4 additions & 5 deletions sdk/python/tests/unit/test_on_demand_feature_view.py
@@ -251,15 +251,14 @@ def test_from_proto_backwards_compatible_udf():
proto.spec.feature_transformation.user_defined_function.body_text
)

# And now we're going to null the feature_transformation proto object before reserializing the entire proto
# proto.spec.user_defined_function.body_text = on_demand_feature_view.transformation.udf_string
proto.spec.feature_transformation.user_defined_function.name = ""
proto.spec.feature_transformation.user_defined_function.body = b""
proto.spec.feature_transformation.user_defined_function.body_text = ""
# For objects that are already registered, feature_transformation and mode are not set
proto.spec.feature_transformation.Clear()
proto.spec.ClearField("mode")

# And now we expect to get the same object back under feature_transformation
reserialized_proto = OnDemandFeatureView.from_proto(proto)
assert (
reserialized_proto.feature_transformation.udf_string
== on_demand_feature_view.feature_transformation.udf_string
)
assert reserialized_proto.mode == "pandas"
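
The rewritten test leans on protobuf's own clearing API instead of blanking sub-fields by hand: `Clear()` resets every field of a message, while `ClearField(name)` unsets a single named field on its parent. A minimal illustration on a well-known type (the semantics carry over to the `OnDemandFeatureView` spec used above):

```python
from google.protobuf.duration_pb2 import Duration

d = Duration(seconds=5, nanos=100)
d.ClearField("nanos")  # unsets just the named field
assert d.seconds == 5 and d.nanos == 0

d.Clear()  # resets the whole message to defaults
assert d.seconds == 0
```
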