Skip to content
Merged
Prev Previous commit
Next Next commit
move partition columns to destination config
Signed-off-by: pyalex <moskalenko.alexey@gmail.com>
  • Loading branch information
pyalex committed Apr 26, 2022
commit 9d1eb1a0c930431d295f7bb17a628f58fcba796b
9 changes: 7 additions & 2 deletions protos/feast/core/FeatureService.proto
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ message LoggingConfig {
message FileDestination {
string path = 1;
string s3_endpoint_override = 2;

// Optional column names used to partition the written log files.
repeated string partition_by = 3;
}

message BigQueryDestination {
Expand All @@ -75,11 +78,13 @@ message LoggingConfig {
}

message RedshiftDestination {
string table_ref = 1;
// Destination table name. The cluster ID and database are taken from the offline store config.
string table_name = 1;
}

message SnowflakeDestination {
string table = 1;
// Destination table name. The schema and database are taken from the offline store config.
string table_name = 1;
}

message CustomDestination {
Expand Down
9 changes: 0 additions & 9 deletions sdk/python/feast/feature_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ def get_schema(self, registry: "Registry") -> pa.Schema:
""" Generate schema for logs destination. """
raise NotImplementedError

@abc.abstractmethod
def get_partition_column(self, registry: "Registry") -> str:
""" Return partition column that must exist in generated schema. """
raise NotImplementedError

@abc.abstractmethod
def get_log_timestamp_column(self) -> str:
""" Return timestamp column that must exist in generated schema. """
Expand Down Expand Up @@ -99,15 +94,11 @@ def get_schema(self, registry: "Registry") -> pa.Schema:
# system columns
fields[REQUEST_ID_FIELD] = pa.string()
fields[LOG_TIMESTAMP_FIELD] = pa.timestamp("us", tz=UTC)
fields[LOG_DATE_FIELD] = pa.date32()

return pa.schema(
[pa.field(name, data_type) for name, data_type in fields.items()]
)

def get_partition_column(self, registry: "Registry") -> str:
return LOG_DATE_FIELD

def get_log_timestamp_column(self) -> str:
return LOG_TIMESTAMP_FIELD

Expand Down
9 changes: 4 additions & 5 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,11 @@ def write_logged_features(
job_config = bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.PARQUET,
schema=arrow_schema_to_bq_schema(source.get_schema(registry)),
time_partitioning=bigquery.TimePartitioning(
type_=bigquery.TimePartitioningType.DAY,
field=source.get_log_timestamp_column(),
),
)
partition_col = source.get_partition_column(registry)
if partition_col:
job_config.time_partitioning = bigquery.TimePartitioning(
type_=bigquery.TimePartitioningType.DAY, field=partition_col
)

with tempfile.TemporaryFile() as parquet_temp_file:
pyarrow.parquet.write_table(table=data, where=parquet_temp_file)
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/bigquery_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class BigQueryLoggingDestination(LoggingDestination):

table: str

def __init__(self, table_ref):
def __init__(self, *, table_ref):
self.table = table_ref

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def write_logged_features(
pyarrow.parquet.write_to_dataset(
data,
root_path=path,
partition_cols=[source.get_partition_column(registry)],
partition_cols=destination.partition_by,
filesystem=filesystem,
)

Expand Down
19 changes: 16 additions & 3 deletions sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, Dict, Iterable, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple

from pyarrow._fs import FileSystem
from pyarrow._s3fs import S3FileSystem
Expand Down Expand Up @@ -301,22 +301,35 @@ class FileLoggingDestination(LoggingDestination):

path: str
s3_endpoint_override: str
partition_by: Optional[List[str]]

def __init__(self, *, path: str, s3_endpoint_override=""):
def __init__(
self,
*,
path: str,
s3_endpoint_override="",
partition_by: Optional[List[str]] = None,
):
self.path = path
self.s3_endpoint_override = s3_endpoint_override
self.partition_by = partition_by

@classmethod
def from_proto(cls, config_proto: LoggingConfigProto) -> "LoggingDestination":
return FileLoggingDestination(
path=config_proto.file_destination.path,
s3_endpoint_override=config_proto.file_destination.s3_endpoint_override,
partition_by=list(config_proto.file_destination.partition_by)
if config_proto.file_destination.partition_by
else None,
)

def to_proto(self) -> LoggingConfigProto:
return LoggingConfigProto(
file_destination=LoggingConfigProto.FileDestination(
path=self.path, s3_endpoint_override=self.s3_endpoint_override,
path=self.path,
s3_endpoint_override=self.s3_endpoint_override,
partition_by=self.partition_by,
)
)

Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def write_logged_features(
s3_resource=s3_resource,
s3_path=s3_path,
iam_role=config.offline_store.iam_role,
table_name=destination.table,
table_name=destination.table_name,
schema=source.get_schema(registry),
fail_if_exists=False,
)
Expand Down
12 changes: 6 additions & 6 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,23 +336,23 @@ def to_data_source(self) -> DataSource:
class RedshiftLoggingDestination(LoggingDestination):
_proto_kind = "redshift_destination"

table: str
table_name: str

def __init__(self, table_ref: str):
self.table = table_ref
def __init__(self, *, table_name: str):
self.table_name = table_name

@classmethod
def from_proto(cls, config_proto: LoggingConfigProto) -> "LoggingDestination":
return RedshiftLoggingDestination(
table_ref=config_proto.redshift_destination.table_ref,
table_name=config_proto.redshift_destination.table_name,
)

def to_proto(self) -> LoggingConfigProto:
return LoggingConfigProto(
redshift_destination=LoggingConfigProto.RedshiftDestination(
table_ref=self.table
table_name=self.table_name
)
)

def to_data_source(self) -> DataSource:
return RedshiftSource(table=self.table)
return RedshiftSource(table=self.table_name)
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def write_logged_features(
write_pandas(
snowflake_conn,
data.to_pandas(),
table_name=logging_config.destination.table,
table_name=logging_config.destination.table_name,
auto_create_table=True,
)

Expand Down
12 changes: 6 additions & 6 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,25 +336,25 @@ def to_data_source(self) -> DataSource:


class SnowflakeLoggingDestination(LoggingDestination):
table: str
table_name: str

_proto_kind = "snowflake_destination"

def __init__(self, table: str):
self.table = table
def __init__(self, *, table_name: str):
self.table_name = table_name

@classmethod
def from_proto(cls, config_proto: LoggingConfigProto) -> "LoggingDestination":
return SnowflakeLoggingDestination(
table=config_proto.snowflake_destination.table,
table_name=config_proto.snowflake_destination.table_name,
)

def to_proto(self) -> LoggingConfigProto:
return LoggingConfigProto(
snowflake_destination=LoggingConfigProto.SnowflakeDestination(
table=self.table,
table_name=self.table_name,
)
)

def to_data_source(self) -> DataSource:
return SnowflakeSource(table=self.table,)
return SnowflakeSource(table=self.table_name,)
3 changes: 1 addition & 2 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ def retrieve_feature_service_logs(
schema = logging_source.get_schema(registry)
logging_config = feature_service.logging_config
ts_column = logging_source.get_log_timestamp_column()
partition_column = logging_source.get_partition_column(registry)
columns = list(set(schema.names) - {ts_column, partition_column})
columns = list(set(schema.names) - {ts_column})

return self.offline_store.pull_all_from_table_or_query(
config=config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def create_logged_features_destination(self) -> LoggingDestination:
)
self.tables.append(table)

return RedshiftLoggingDestination(table_ref=table)
return RedshiftLoggingDestination(table_name=table)

def create_offline_store_config(self) -> FeastConfigBaseModel:
return self.offline_store_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def create_logged_features_destination(self) -> LoggingDestination:
)
self.tables.append(table)

return SnowflakeLoggingDestination(table=table)
return SnowflakeLoggingDestination(table_name=table)

def create_offline_store_config(self) -> FeastConfigBaseModel:
return self.offline_store_config
Expand Down