1717from typing import Callable , Dict , Iterable , Optional , Tuple
1818
1919from pyarrow .parquet import ParquetFile
20+ from tenacity import retry , retry_unless_exception_type , wait_exponential
2021
2122from feast import type_map
2223from feast .data_format import FileFormat , StreamFormat
23- from feast .errors import DataSourceNotFoundException
24+ from feast .errors import (
25+ DataSourceNotFoundException ,
26+ RedshiftCredentialsError ,
27+ RedshiftQueryError ,
28+ )
2429from feast .protos .feast .core .DataSource_pb2 import DataSource as DataSourceProto
30+ from feast .repo_config import RepoConfig
2531from feast .value_type import ValueType
2632
2733
@@ -477,6 +483,15 @@ def from_proto(data_source):
477483 date_partition_column = data_source .date_partition_column ,
478484 query = data_source .bigquery_options .query ,
479485 )
486+ elif data_source .redshift_options .table or data_source .redshift_options .query :
487+ data_source_obj = RedshiftSource (
488+ field_mapping = data_source .field_mapping ,
489+ table = data_source .redshift_options .table ,
490+ event_timestamp_column = data_source .event_timestamp_column ,
491+ created_timestamp_column = data_source .created_timestamp_column ,
492+ date_partition_column = data_source .date_partition_column ,
493+ query = data_source .redshift_options .query ,
494+ )
480495 elif (
481496 data_source .kafka_options .bootstrap_servers
482497 and data_source .kafka_options .topic
@@ -520,12 +535,27 @@ def to_proto(self) -> DataSourceProto:
520535 """
521536 raise NotImplementedError
522537
    def validate(self, config: RepoConfig):
        """
        Validates the underlying data source.

        Args:
            config: Configuration object used to configure a feature store.

        Raises:
            NotImplementedError: always; concrete DataSource subclasses
                must override this method.
        """
        raise NotImplementedError
528543
    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """
        Get the callable method that returns Feast type given the raw column type.

        Raises:
            NotImplementedError: always; concrete DataSource subclasses
                must override this method.
        """
        raise NotImplementedError
550+
    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Get the list of column names and raw column types.

        Args:
            config: Configuration object used to configure a feature store.

        Raises:
            NotImplementedError: always; concrete DataSource subclasses
                must override this method.
        """
        raise NotImplementedError
558+
529559
530560class FileSource (DataSource ):
531561 def __init__ (
@@ -622,15 +652,17 @@ def to_proto(self) -> DataSourceProto:
622652
623653 return data_source_proto
624654
    def validate(self, config: RepoConfig):
        """
        Validates the file data source.

        Args:
            config: Configuration object used to configure a feature store
                (unused here; accepted to match the DataSource interface).
        """
        # TODO: validate a FileSource
        pass
628658
    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """Returns the converter from PyArrow column types to Feast ValueType."""
        return type_map.pa_to_feast_value_type
632662
633- def get_table_column_names_and_types (self ) -> Iterable [Tuple [str , str ]]:
663+ def get_table_column_names_and_types (
664+ self , config : RepoConfig
665+ ) -> Iterable [Tuple [str , str ]]:
634666 schema = ParquetFile (self .path ).schema_arrow
635667 return zip (schema .names , map (str , schema .types ))
636668
@@ -703,7 +735,7 @@ def to_proto(self) -> DataSourceProto:
703735
704736 return data_source_proto
705737
706- def validate (self ):
738+ def validate (self , config : RepoConfig ):
707739 if not self .query :
708740 from google .api_core .exceptions import NotFound
709741 from google .cloud import bigquery
@@ -725,7 +757,9 @@ def get_table_query_string(self) -> str:
725757 def source_datatype_to_feast_value_type () -> Callable [[str ], ValueType ]:
726758 return type_map .bq_to_feast_value_type
727759
728- def get_table_column_names_and_types (self ) -> Iterable [Tuple [str , str ]]:
760+ def get_table_column_names_and_types (
761+ self , config : RepoConfig
762+ ) -> Iterable [Tuple [str , str ]]:
729763 from google .cloud import bigquery
730764
731765 client = bigquery .Client ()
@@ -875,3 +909,223 @@ def to_proto(self) -> DataSourceProto:
875909 data_source_proto .date_partition_column = self .date_partition_column
876910
877911 return data_source_proto
912+
913+
class RedshiftOptions:
    """
    DataSource Redshift options used to source features from a Redshift
    table or query.
    """

    def __init__(self, table: Optional[str], query: Optional[str]):
        self._table = table
        self._query = query

    @property
    def query(self):
        """Returns the Redshift SQL query referenced by this source."""
        return self._query

    @query.setter
    def query(self, query):
        """Sets the Redshift SQL query referenced by this source."""
        self._query = query

    @property
    def table(self):
        """Returns the table name of this Redshift table."""
        return self._table

    @table.setter
    def table(self, table_name):
        """Sets the table ref of this Redshift table."""
        self._table = table_name

    @classmethod
    def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
        """
        Creates a RedshiftOptions from a protobuf representation of a Redshift option.

        Args:
            redshift_options_proto: A protobuf representation of a DataSource

        Returns:
            A RedshiftOptions object based on the redshift_options protobuf
        """
        return cls(
            table=redshift_options_proto.table,
            query=redshift_options_proto.query,
        )

    def to_proto(self) -> DataSourceProto.RedshiftOptions:
        """
        Converts this RedshiftOptions object to its protobuf representation.

        Returns:
            A RedshiftOptions protobuf carrying the table and query.
        """
        return DataSourceProto.RedshiftOptions(table=self.table, query=self.query)
982+
983+
class RedshiftSource(DataSource):
    """
    A DataSource backed by a Redshift table or SQL query.

    If ``table`` is set it is used to reference the source; otherwise
    ``query`` is (see get_table_query_string / get_table_column_names_and_types).
    """

    def __init__(
        self,
        event_timestamp_column: Optional[str] = "",
        table: Optional[str] = None,
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = "",
        query: Optional[str] = None,
    ):
        super().__init__(
            event_timestamp_column,
            created_timestamp_column,
            field_mapping,
            date_partition_column,
        )

        self._redshift_options = RedshiftOptions(table=table, query=query)

    def __eq__(self, other):
        if not isinstance(other, RedshiftSource):
            raise TypeError(
                "Comparisons should only involve RedshiftSource class objects."
            )

        return (
            self.redshift_options.table == other.redshift_options.table
            and self.redshift_options.query == other.redshift_options.query
            and self.event_timestamp_column == other.event_timestamp_column
            and self.created_timestamp_column == other.created_timestamp_column
            and self.field_mapping == other.field_mapping
        )

    @property
    def table(self):
        """Returns the table name of this Redshift source."""
        return self._redshift_options.table

    @property
    def query(self):
        """Returns the SQL query of this Redshift source."""
        return self._redshift_options.query

    @property
    def redshift_options(self):
        """
        Returns the Redshift options of this data source
        """
        return self._redshift_options

    @redshift_options.setter
    def redshift_options(self, _redshift_options):
        """
        Sets the Redshift options of this data source
        """
        self._redshift_options = _redshift_options

    def to_proto(self) -> DataSourceProto:
        """Converts this RedshiftSource to its protobuf representation."""
        data_source_proto = DataSourceProto(
            type=DataSourceProto.BATCH_REDSHIFT,
            field_mapping=self.field_mapping,
            redshift_options=self.redshift_options.to_proto(),
        )

        data_source_proto.event_timestamp_column = self.event_timestamp_column
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        data_source_proto.date_partition_column = self.date_partition_column

        return data_source_proto

    def validate(self, config: RepoConfig):
        """
        Validates this data source against the live Redshift cluster.

        As long as the query gets successfully executed, or the table exists,
        the data source is validated. We don't need the results though.

        Raises:
            DataSourceNotFoundException: if ``table`` is set but has no columns.
            RedshiftQueryError: if ``query`` fails to execute.
            RedshiftCredentialsError: on credential/validation errors.
        """
        # Fix: the actual validation call was commented out behind a
        # "TODO: uncomment this" and a leftover debug print() dumped the
        # schema to stdout. Execute the call; discard the result.
        self.get_table_column_names_and_types(config)

    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL"""
        if self.table:
            # NOTE(review): backtick quoting is BigQuery syntax; Redshift
            # identifiers are normally double-quoted. Kept as-is to preserve
            # behavior — confirm how downstream SQL builders consume this.
            return f"`{self.table}`"
        else:
            return f"({self.query})"

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """Returns the converter from Redshift column types to Feast ValueType."""
        return type_map.redshift_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Fetches (column name, upper-cased Redshift type name) pairs via the
        Redshift Data API.

        Args:
            config: RepoConfig whose ``offline_store`` must be a
                RedshiftOfflineStoreConfig supplying cluster_id, database,
                user, and region.

        Raises:
            DataSourceNotFoundException: if ``table`` is set but the API
                returns an empty column list (table does not exist).
            RedshiftQueryError: if the probing query does not finish
                successfully.
            RedshiftCredentialsError: if the API raises a ValidationException.
        """
        import boto3
        from botocore.config import Config
        from botocore.exceptions import ClientError

        from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        client = boto3.client(
            "redshift-data", config=Config(region_name=config.offline_store.region)
        )

        try:
            if self.table is not None:
                table = client.describe_table(
                    ClusterIdentifier=config.offline_store.cluster_id,
                    Database=config.offline_store.database,
                    DbUser=config.offline_store.user,
                    Table=self.table,
                )
                # The API returns valid JSON with empty column list when the table doesn't exist
                if len(table["ColumnList"]) == 0:
                    raise DataSourceNotFoundException(self.table)

                columns = table["ColumnList"]
            else:
                statement = client.execute_statement(
                    ClusterIdentifier=config.offline_store.cluster_id,
                    Database=config.offline_store.database,
                    DbUser=config.offline_store.user,
                    Sql=f"SELECT * FROM ({self.query}) LIMIT 1",
                )

                # Need to retry client.describe_statement(...) until the task is
                # finished. We don't want to bombard Redshift with queries, and
                # neither do we want to wait for a long time on the initial call.
                # The solution is exponential backoff. The backoff starts with
                # 0.1 seconds and doubles exponentially until reaching 30
                # seconds, at which point the backoff is fixed.
                @retry(
                    wait=wait_exponential(multiplier=0.1, max=30),
                    retry=retry_unless_exception_type(RedshiftQueryError),
                )
                def wait_for_statement():
                    desc = client.describe_statement(Id=statement["Id"])
                    # Any exception other than RedshiftQueryError triggers a
                    # retry under retry_unless_exception_type above.
                    if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"):
                        raise Exception  # Retry
                    if desc["Status"] != "FINISHED":
                        raise RedshiftQueryError(desc)  # Don't retry. Raise exception.

                wait_for_statement()

                result = client.get_statement_result(Id=statement["Id"])

                columns = result["ColumnMetadata"]
        except ClientError as e:
            if e.response["Error"]["Code"] == "ValidationException":
                raise RedshiftCredentialsError() from e
            raise

        return [(column["name"], column["typeName"].upper()) for column in columns]
0 commit comments