1313# limitations under the License.
1414import logging
1515import multiprocessing
16+ import os
1617import shutil
18+ import tempfile
19+ import uuid
1720from datetime import datetime
1821from itertools import groupby
1922from typing import Any , Dict , List , Optional , Union
23+ from urllib .parse import urlparse
2024
2125import grpc
2226import pandas as pd
3438 CONFIG_SERVING_URL_KEY ,
3539 CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT ,
3640 CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION ,
41+ CONFIG_SPARK_STAGING_LOCATION ,
3742 FEAST_DEFAULT_OPTIONS ,
3843)
3944from feast .core .CoreService_pb2 import (
8893 GetOnlineFeaturesRequestV2 ,
8994)
9095from feast .serving .ServingService_pb2_grpc import ServingServiceStub
96+ from feast .staging .storage_client import get_staging_client
9197
9298_logger = logging .getLogger (__name__ )
9399
@@ -780,7 +786,7 @@ def get_online_features(
780786 def get_historical_features (
781787 self ,
782788 feature_refs : List [str ],
783- entity_source : Union [FileSource , BigQuerySource ],
789+ entity_source : Union [pd . DataFrame , FileSource , BigQuerySource ],
784790 project : str = None ,
785791 ) -> RetrievalJob :
786792 """
@@ -791,9 +797,14 @@ def get_historical_features(
791797 Each feature reference should have the following format:
792798 "feature_table:feature" where "feature_table" & "feature" refer to
793799 the feature and feature table names respectively.
794- entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
795- The user needs to make sure that the source is accessible from the Spark cluster
796- that will be used for the retrieval job.
800+ entity_source (Union[pd.DataFrame, FileSource, BigQuerySource]): Source for the entity rows.
801+ If entity_source is a Pandas DataFrame, the dataframe will be exported to the staging
802+ location as a parquet file. It is also assumed that the column event_timestamp is present
803+ in the dataframe, and is of type datetime without timezone information.
804+
805+ The user needs to make sure that the source (or staging location, if entity_source is
806+ a Pandas DataFrame) is accessible from the Spark cluster that will be used for the
807+ retrieval job.
797808 project: Specifies the project that contains the feature tables
798809 which the requested features belong to.
799810
@@ -821,6 +832,29 @@ def get_historical_features(
821832 )
822833 output_format = self ._config .get (CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT )
823834
835+ if isinstance (entity_source , pd .DataFrame ):
836+ staging_location = self ._config .get (CONFIG_SPARK_STAGING_LOCATION )
837+ entity_staging_uri = urlparse (
838+ os .path .join (staging_location , str (uuid .uuid4 ()))
839+ )
840+ staging_client = get_staging_client (entity_staging_uri .scheme )
841+ with tempfile .NamedTemporaryFile () as df_export_path :
842+ entity_source .to_parquet (df_export_path .name )
843+ bucket = (
844+ None
845+ if entity_staging_uri .scheme == "fs"
846+ else entity_staging_uri .netloc
847+ )
848+ staging_client .upload_file (
849+ df_export_path .name , bucket , entity_staging_uri .path
850+ )
851+ entity_source = FileSource (
852+ "event_timestamp" ,
853+ "created_timestamp" ,
854+ ParquetFormat (),
855+ entity_staging_uri .path ,
856+ )
857+
824858 return start_historical_feature_retrieval_job (
825859 self , entity_source , feature_tables , output_format , output_location
826860 )
0 commit comments