Skip to content

Commit 029d9ab

Browse files
authored
Accept Pandas dataframe as input for historical feature retrieval (feast-dev#1071)
* Accept Pandas dataframe as input for historical feature retrieval Signed-off-by: Khor Shu Heng <khor.heng@gojek.com> * Add missing import Signed-off-by: Khor Shu Heng <khor.heng@gojek.com> Co-authored-by: Khor Shu Heng <khor.heng@gojek.com>
1 parent 0d5cb8d commit 029d9ab

3 files changed

Lines changed: 43 additions & 10 deletions

File tree

sdk/python/feast/client.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@
1313
# limitations under the License.
1414
import logging
1515
import multiprocessing
16+
import os
1617
import shutil
18+
import tempfile
19+
import uuid
1720
from datetime import datetime
1821
from itertools import groupby
1922
from typing import Any, Dict, List, Optional, Union
23+
from urllib.parse import urlparse
2024

2125
import grpc
2226
import pandas as pd
@@ -34,6 +38,7 @@
3438
CONFIG_SERVING_URL_KEY,
3539
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT,
3640
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION,
41+
CONFIG_SPARK_STAGING_LOCATION,
3742
FEAST_DEFAULT_OPTIONS,
3843
)
3944
from feast.core.CoreService_pb2 import (
@@ -88,6 +93,7 @@
8893
GetOnlineFeaturesRequestV2,
8994
)
9095
from feast.serving.ServingService_pb2_grpc import ServingServiceStub
96+
from feast.staging.storage_client import get_staging_client
9197

9298
_logger = logging.getLogger(__name__)
9399

@@ -780,7 +786,7 @@ def get_online_features(
780786
def get_historical_features(
781787
self,
782788
feature_refs: List[str],
783-
entity_source: Union[FileSource, BigQuerySource],
789+
entity_source: Union[pd.DataFrame, FileSource, BigQuerySource],
784790
project: str = None,
785791
) -> RetrievalJob:
786792
"""
@@ -791,9 +797,14 @@ def get_historical_features(
791797
Each feature reference should have the following format:
792798
"feature_table:feature" where "feature_table" & "feature" refer to
793799
the feature and feature table names respectively.
794-
entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
795-
The user needs to make sure that the source is accessible from the Spark cluster
796-
that will be used for the retrieval job.
800+
entity_source (Union[pd.DataFrame, FileSource, BigQuerySource]): Source for the entity rows.
801+
If entity_source is a Pandas DataFrame, the dataframe will be exported to the staging
802+
location as parquet file. It is also assumed that the column event_timestamp is present
803+
in the dataframe, and is of type datetime without timezone information.
804+
805+
The user needs to make sure that the source (or staging location, if entity_source is
806+
a Pandas DataFrame) is accessible from the Spark cluster that will be used for the
807+
retrieval job.
797808
project: Specifies the project that contains the feature tables
798809
which the requested features belong to.
799810
@@ -821,6 +832,29 @@ def get_historical_features(
821832
)
822833
output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT)
823834

835+
if isinstance(entity_source, pd.DataFrame):
836+
staging_location = self._config.get(CONFIG_SPARK_STAGING_LOCATION)
837+
entity_staging_uri = urlparse(
838+
os.path.join(staging_location, str(uuid.uuid4()))
839+
)
840+
staging_client = get_staging_client(entity_staging_uri.scheme)
841+
with tempfile.NamedTemporaryFile() as df_export_path:
842+
entity_source.to_parquet(df_export_path.name)
843+
bucket = (
844+
None
845+
if entity_staging_uri.scheme == "fs"
846+
else entity_staging_uri.netloc
847+
)
848+
staging_client.upload_file(
849+
df_export_path.name, bucket, entity_staging_uri.path
850+
)
851+
entity_source = FileSource(
852+
"event_timestamp",
853+
"created_timestamp",
854+
ParquetFormat(),
855+
entity_staging_uri.path,
856+
)
857+
824858
return start_historical_feature_retrieval_job(
825859
self, entity_source, feature_tables, output_format, output_location
826860
)

sdk/python/feast/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class AuthProvider(Enum):
6767
# Spark Job Config
6868
CONFIG_SPARK_LAUNCHER = "spark_launcher" # standalone, dataproc, emr
6969

70+
CONFIG_SPARK_STAGING_LOCATION = "spark_staging_location"
71+
7072
CONFIG_SPARK_INGESTION_JOB_JAR = "spark_ingestion_jar"
7173

7274
CONFIG_SPARK_STANDALONE_MASTER = "spark_standalone_master"
@@ -75,7 +77,6 @@ class AuthProvider(Enum):
7577
CONFIG_SPARK_DATAPROC_CLUSTER_NAME = "dataproc_cluster_name"
7678
CONFIG_SPARK_DATAPROC_PROJECT = "dataproc_project"
7779
CONFIG_SPARK_DATAPROC_REGION = "dataproc_region"
78-
CONFIG_SPARK_DATAPROC_STAGING_LOCATION = "dataproc_staging_location"
7980

8081
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT = "historical_feature_output_format"
8182
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION = "historical_feature_output_location"
@@ -87,7 +88,6 @@ class AuthProvider(Enum):
8788
CONFIG_SPARK_EMR_REGION = "emr_region"
8889
CONFIG_SPARK_EMR_CLUSTER_ID = "emr_cluster_id"
8990
CONFIG_SPARK_EMR_CLUSTER_TEMPLATE_PATH = "emr_cluster_template_path"
90-
CONFIG_SPARK_EMR_STAGING_LOCATION = "emr_staging_location"
9191
CONFIG_SPARK_EMR_LOG_LOCATION = "emr_log_location"
9292

9393

sdk/python/feast/pyspark/launcher.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
CONFIG_SPARK_DATAPROC_CLUSTER_NAME,
1313
CONFIG_SPARK_DATAPROC_PROJECT,
1414
CONFIG_SPARK_DATAPROC_REGION,
15-
CONFIG_SPARK_DATAPROC_STAGING_LOCATION,
1615
CONFIG_SPARK_EMR_CLUSTER_ID,
1716
CONFIG_SPARK_EMR_CLUSTER_TEMPLATE_PATH,
1817
CONFIG_SPARK_EMR_LOG_LOCATION,
1918
CONFIG_SPARK_EMR_REGION,
20-
CONFIG_SPARK_EMR_STAGING_LOCATION,
2119
CONFIG_SPARK_HOME,
2220
CONFIG_SPARK_INGESTION_JOB_JAR,
2321
CONFIG_SPARK_LAUNCHER,
22+
CONFIG_SPARK_STAGING_LOCATION,
2423
CONFIG_SPARK_STANDALONE_MASTER,
2524
)
2625
from feast.data_source import BigQuerySource, DataSource, FileSource, KafkaSource
@@ -54,7 +53,7 @@ def _dataproc_launcher(config: Config) -> JobLauncher:
5453

5554
return gcloud.DataprocClusterLauncher(
5655
config.get(CONFIG_SPARK_DATAPROC_CLUSTER_NAME),
57-
config.get(CONFIG_SPARK_DATAPROC_STAGING_LOCATION),
56+
config.get(CONFIG_SPARK_STAGING_LOCATION),
5857
config.get(CONFIG_SPARK_DATAPROC_REGION),
5958
config.get(CONFIG_SPARK_DATAPROC_PROJECT),
6059
)
@@ -71,7 +70,7 @@ def _get_optional(option):
7170
region=config.get(CONFIG_SPARK_EMR_REGION),
7271
existing_cluster_id=_get_optional(CONFIG_SPARK_EMR_CLUSTER_ID),
7372
new_cluster_template_path=_get_optional(CONFIG_SPARK_EMR_CLUSTER_TEMPLATE_PATH),
74-
staging_location=config.get(CONFIG_SPARK_EMR_STAGING_LOCATION),
73+
staging_location=config.get(CONFIG_SPARK_STAGING_LOCATION),
7574
emr_log_location=config.get(CONFIG_SPARK_EMR_LOG_LOCATION),
7675
)
7776

0 commit comments

Comments
 (0)