Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix custom registry store parsing
Signed-off-by: Danny Chiao <danny@tecton.ai>
  • Loading branch information
adchia committed Mar 12, 2023
commit 10e0bdc2d88faedc8d393088ceb69704a8f9c4da
8 changes: 5 additions & 3 deletions docs/tutorials/azure/notebooks/src/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import json
import joblib
from feast import FeatureStore, RepoConfig
from feast.infra.registry.registry import RegistryConfig
from feast.repo_config import RegistryConfig

from feast.infra.offline_stores.contrib.mssql_offline_store.mssql import MsSqlServerOfflineStoreConfig
from feast.infra.offline_stores.contrib.mssql_offline_store.mssql import (
MsSqlServerOfflineStoreConfig,
)
from feast.infra.online_stores.redis import RedisOnlineStoreConfig, RedisOnlineStore


Expand Down Expand Up @@ -73,4 +75,4 @@ def run(raw_data):
y_hat = model.predict(data)
return y_hat.tolist()
else:
return 0.0
return 0.0
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def refresh_registry(self):
greater than 0, then once the cache becomes stale (more time than the TTL has passed), a new cache will be
downloaded synchronously, which may increase latencies if the triggering method is get_online_features().
"""
registry_config = self.config.get_registry_config()
registry_config = self.config.registry
registry = Registry(
self.config.project, registry_config, repo_path=self.repo_path
)
Expand Down
4 changes: 1 addition & 3 deletions sdk/python/feast/infra/registry/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ class FeastMetadataKeys(Enum):
class SnowflakeRegistryConfig(RegistryConfig):
"""Registry config for Snowflake"""

registry_type: Literal["snowflake.registry"] = Field(
"snowflake.registry", alias="type"
)
registry_type: Literal["snowflake.registry"] = "snowflake.registry"
""" Registry type selector """

type: Literal["snowflake.registry"] = "snowflake.registry"
Expand Down
28 changes: 19 additions & 9 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
# - existing values for the online store type in featurestore.yaml files continue to work in a backwards compatible way
# - first party and third party implementations can use the same class loading code path.
REGISTRY_CLASS_FOR_TYPE = {
"file": "feast.infra.registry.Registry",
"sql": "feast.infra.registry.SqlRegistry",
"file": "feast.infra.registry.registry.Registry",
"sql": "feast.infra.registry.sql.SqlRegistry",
"snowflake.registry": "feast.infra.registry.snowflake.SnowflakeRegistry",
}

Expand Down Expand Up @@ -109,14 +109,15 @@ class RegistryConfig(FeastBaseModel):
"""Metadata Store Configuration. Configuration that relates to reading from and writing to the Feast registry."""

registry_type: StrictStr = "file"
""" str: Provider name or a class name that implements Registry.
If specified, registry_store_type should be redundant."""
""" str: Provider name or a class name that implements Registry."""

registry_store_type: Optional[StrictStr]
""" str: Provider name or a class name that implements RegistryStore. """

path: StrictStr = ""
""" str: Path to metadata store. Can be a local path, or remote object storage path, e.g. a GCS URI """
""" str: Path to metadata store.
If registry_type is 'file', then an be a local path, or remote object storage path, e.g. a GCS URI
If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """

cache_ttl_seconds: StrictInt = 600
"""int: The cache TTL is the amount of time registry state will be cached in memory. If this TTL is exceeded then
Expand All @@ -141,7 +142,12 @@ class RepoConfig(FeastBaseModel):
""" str: local or gcp or aws """

_registry_config: Any = Field(alias="registry", default="data/registry.db")
""" str: Path to metadata store. Can be a local path, or remote object storage path, e.g. a GCS URI """
""" Configures the registry.
Can be:
1. str: a path to a file based registry (a local path, or remote object storage path, e.g. a GCS URI)
2. RegistryConfig: A fully specified file based registry or SQL based registry
3. SnowflakeRegistryConfig: Using a Snowflake table to store the registry
"""

_online_config: Any = Field(alias="online_store")
""" OnlineStoreConfig: Online store configuration (optional depending on provider) """
Expand Down Expand Up @@ -240,9 +246,13 @@ def __init__(self, **data: Any):
def registry(self):
if not self._registry:
if isinstance(self._registry_config, Dict):
self._registry = get_registry_config_from_type(
self._registry_config["type"]
)(**self._registry_config)
if "registry_type" in self._registry_config:
self._registry = get_registry_config_from_type(
self._registry_config["registry_type"]
)(**self._registry_config)
else:
# This may be a custom registry store, which does not need a 'registry_type'
self._registry = RegistryConfig(**self._registry_config)
elif isinstance(self._registry_config, str):
# User passed in just a path to file registry
self._registry = get_registry_config_from_type("file")(
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/tests/unit/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def setup_third_party_provider_repo(provider_name: str):
type: sqlite
offline_store:
type: file
entity_key_serialization_version: 2
"""
)
)
Expand Down Expand Up @@ -159,6 +160,7 @@ def setup_third_party_registry_store_repo(
type: sqlite
offline_store:
type: file
entity_key_serialization_version: 2
"""
)
)
Expand Down
35 changes: 28 additions & 7 deletions sdk/python/tests/unit/infra/test_inference_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def test_feature_view_inference_respects_basic_inference():
[feature_view_1],
[entity1],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)
assert len(feature_view_1.schema) == 2
Expand All @@ -209,7 +212,10 @@ def test_feature_view_inference_respects_basic_inference():
[feature_view_2],
[entity1, entity2],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)
assert len(feature_view_2.schema) == 3
Expand Down Expand Up @@ -240,7 +246,10 @@ def test_feature_view_inference_on_entity_value_types():
[feature_view_1],
[entity1],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)

Expand Down Expand Up @@ -310,7 +319,10 @@ def test_feature_view_inference_on_entity_columns(simple_dataset_1):
[feature_view_1],
[entity1],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)

Expand Down Expand Up @@ -345,7 +357,10 @@ def test_feature_view_inference_on_feature_columns(simple_dataset_1):
[feature_view_1],
[entity1],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)

Expand Down Expand Up @@ -397,7 +412,10 @@ def test_update_feature_services_with_inferred_features(simple_dataset_1):
[feature_view_1, feature_view_2],
[entity1],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)
feature_service.infer_features(
Expand Down Expand Up @@ -454,7 +472,10 @@ def test_update_feature_services_with_specified_features(simple_dataset_1):
[feature_view_1, feature_view_2],
[entity1],
RepoConfig(
provider="local", project="test", entity_key_serialization_version=2
provider="local",
project="test",
entity_key_serialization_version=2,
registry="dummy_registry.pb",
),
)
assert len(feature_view_1.features) == 1
Expand Down
12 changes: 6 additions & 6 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_online() -> None:
fs_fast_ttl = FeatureStore(
config=RepoConfig(
registry=RegistryConfig(
path=store.config.registry, cache_ttl_seconds=cache_ttl
path=store.config.registry.path, cache_ttl_seconds=cache_ttl
),
online_store=store.config.online_store,
project=store.project,
Expand All @@ -161,7 +161,7 @@ def test_online() -> None:
assert result["trips"] == [7]

# Rename the registry.db so that it cant be used for refreshes
os.rename(store.config.registry, store.config.registry + "_fake")
os.rename(store.config.registry.path, store.config.registry.path + "_fake")

# Wait for registry to expire
time.sleep(cache_ttl)
Expand All @@ -180,7 +180,7 @@ def test_online() -> None:
).to_dict()

# Restore registry.db so that we can see if it actually reloads registry
os.rename(store.config.registry + "_fake", store.config.registry)
os.rename(store.config.registry.path + "_fake", store.config.registry.path)

# Test if registry is actually reloaded and whether results return
result = fs_fast_ttl.get_online_features(
Expand All @@ -200,7 +200,7 @@ def test_online() -> None:
fs_infinite_ttl = FeatureStore(
config=RepoConfig(
registry=RegistryConfig(
path=store.config.registry, cache_ttl_seconds=0
path=store.config.registry.path, cache_ttl_seconds=0
),
online_store=store.config.online_store,
project=store.project,
Expand All @@ -227,7 +227,7 @@ def test_online() -> None:
time.sleep(2)

# Rename the registry.db so that it cant be used for refreshes
os.rename(store.config.registry, store.config.registry + "_fake")
os.rename(store.config.registry.path, store.config.registry.path + "_fake")

# TTL is infinite so this method should use registry cache
result = fs_infinite_ttl.get_online_features(
Expand All @@ -248,7 +248,7 @@ def test_online() -> None:
fs_infinite_ttl.refresh_registry()

# Restore registry.db so that teardown works
os.rename(store.config.registry + "_fake", store.config.registry)
os.rename(store.config.registry.path + "_fake", store.config.registry.path)


def test_online_to_df():
Expand Down