Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
converting create_data_source to take in arbitrary arguments
Signed-off-by: Chester Ong <chester.ong.ch@gmail.com>
  • Loading branch information
bushwhackr committed Feb 11, 2024
commit 3ecff669dc4f1c817c7a7804feca495f5ff35bae
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ def __init__(self, project_name: str, *args, **kwargs):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
suffix: Optional[str] = None,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

table_name = destination_name
s3_target = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ def create_offline_store_config(self) -> MsSqlServerOfflineStoreConfig:
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

# Make sure the field mapping is correct and convert the datetime datasources.
if timestamp_field in df:
df[timestamp_field] = pd.to_datetime(df[timestamp_field], utc=True).fillna(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,15 @@ def __init__(
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
suffix: Optional[str] = None,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

destination_name = self.get_prefixed_table_name(destination_name)

if self.offline_store_config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,15 @@ def create_offline_store_config(self):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

if timestamp_field in df:
df[timestamp_field] = pd.to_datetime(df[timestamp_field], utc=True)
# Make sure the field mapping is correct and convert the datetime datasources.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,15 @@ def teardown(self):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
suffix: Optional[str] = None,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

destination_name = self.get_prefixed_table_name(destination_name)
self.client.execute_query(
f"CREATE SCHEMA IF NOT EXISTS memory.{self.project_name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ def __init__(self, project_name: str, *args, **kwargs):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
event_timestamp_column="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
timestamp_field: Optional[str] = None,
**kwargs,
) -> DataSource:
Comment on lines 16 to 20
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functionality change: create_data_source needs to accept an indeterminate number of arguments, so the base class should accept **kwargs, and each concrete class should then check whether the caller supplied sufficient inputs.

"""
Create a data source based on the dataframe. Implementing this method requires the underlying implementation to
Expand All @@ -30,13 +26,7 @@ def create_data_source(

Args:
df: The dataframe to be used to create the data source.
destination_name: This str is used by the implementing classes to
isolate the multiple dataframes from each other.
event_timestamp_column: (Deprecated) Pass through for the underlying data source.
created_timestamp_column: Pass through for the underlying data source.
field_mapping: Pass through for the underlying data source.
timestamp_field: Pass through for the underlying data source.

kwargs: Additional arguments to be passed to the underlying data source.

Returns:
A Data source object, pointing to a table or file that is uploaded/persisted for the purpose of the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ def create_offline_store_config(self):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:

destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)
destination_name = self.get_prefixed_table_name(destination_name)

self.create_dataset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ def __init__(self, project_name: str, *args, **kwargs):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

destination_name = self.get_prefixed_table_name(destination_name)

Expand Down Expand Up @@ -93,11 +96,14 @@ class FileParquetDatasetSourceCreator(FileDataSourceCreator):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

destination_name = self.get_prefixed_table_name(destination_name)

Expand Down Expand Up @@ -167,12 +173,15 @@ def _upload_parquet_file(self, df, file_name, minio_endpoint):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: Optional[str] = None,
suffix: Optional[str] = None,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

filename = f"{destination_name}.parquet"
port = self.minio.get_exposed_port("9000")
host = self.minio.get_container_host_ip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ def __init__(self, project_name: str, *args, **kwargs):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
suffix: Optional[str] = None,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_field = kwargs.get("timestamp_field", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

destination_name = self.get_prefixed_table_name(destination_name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ def __init__(self, project_name: str, *args, **kwargs):
def create_data_source(
self,
df: pd.DataFrame,
destination_name: str,
suffix: Optional[str] = None,
timestamp_field="ts",
created_timestamp_column="created_ts",
field_mapping: Optional[Dict[str, str]] = None,
**kwargs,
) -> DataSource:
destination_name = kwargs.get("destination_name")
if not destination_name:
raise ValueError("destination_name is required")
timestamp_column = kwargs.get("timestamp_column", "ts")
created_timestamp_column = kwargs.get("created_timestamp_column", "created_ts")
field_mapping = kwargs.get("field_mapping", None)

destination_name = self.get_prefixed_table_name(destination_name)

Expand All @@ -63,7 +65,7 @@ def create_data_source(

return SnowflakeSource(
table=destination_name,
timestamp_field=timestamp_field,
timestamp_field=timestamp_column,
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping or {"ts_1": "ts"},
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Dict

from google.cloud import bigtable
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

from feast.repo_config import FeastConfigBaseModel
from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
)
Expand All @@ -28,19 +28,19 @@ def __init__(self, project_name: str, **kwargs):
.with_exposed_ports(self.port)
)

def create_online_store(self) -> Dict[str, str]:
def create_online_store(self) -> FeastConfigBaseModel:
self.container.start()
log_string_to_wait_for = r"\[bigtable\] Cloud Bigtable emulator running"
wait_for_logs(
container=self.container, predicate=log_string_to_wait_for, timeout=10
)
exposed_port = self.container.get_exposed_port(self.port)
os.environ[bigtable.client.BIGTABLE_EMULATOR] = f"{self.host}:{exposed_port}"
return {
"type": "bigtable",
"project_id": self.gcp_project,
"instance": self.bt_instance,
}
return FeastConfigBaseModel(
type="bigtable",
project_id=self.gcp_project,
instance=self.bt_instance,
)

def teardown(self):
del os.environ[bigtable.client.BIGTABLE_EMULATOR]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

from feast.repo_config import FeastConfigBaseModel
from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
)
Expand All @@ -23,15 +24,15 @@ def __init__(self, project_name: str, **kwargs):
.with_exposed_ports("8081")
)

def create_online_store(self) -> Dict[str, str]:
def create_online_store(self) -> FeastConfigBaseModel:
self.container.start()
log_string_to_wait_for = r"\[datastore\] Dev App Server is now running"
wait_for_logs(
container=self.container, predicate=log_string_to_wait_for, timeout=10
)
exposed_port = self.container.get_exposed_port("8081")
os.environ[datastore.client.DATASTORE_EMULATOR_HOST] = f"0.0.0.0:{exposed_port}"
return {"type": "datastore", "project_id": "test-project"}
return FeastConfigBaseModel(type="datastore", project_id="test-project")

def teardown(self):
del os.environ[datastore.client.DATASTORE_EMULATOR_HOST]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

from feast.repo_config import FeastConfigBaseModel
from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
)
Expand All @@ -15,7 +16,7 @@ def __init__(self, project_name: str, **kwargs):
"amazon/dynamodb-local:latest"
).with_exposed_ports("8000")

def create_online_store(self) -> Dict[str, str]:
def create_online_store(self) -> FeastConfigBaseModel:
self.container.start()
log_string_to_wait_for = (
"Initializing DynamoDB Local with the following configuration:"
Expand All @@ -24,11 +25,11 @@ def create_online_store(self) -> Dict[str, str]:
container=self.container, predicate=log_string_to_wait_for, timeout=10
)
exposed_port = self.container.get_exposed_port("8000")
return {
"type": "dynamodb",
"endpoint_url": f"http://localhost:{exposed_port}",
"region": "us-west-2",
}
return FeastConfigBaseModel(
type="dynamodb",
endpoint_url=f"http://localhost:{exposed_port}",
region="us-west-2",
)

def teardown(self):
self.container.stop()