@@ -19,7 +19,7 @@
 import pyarrow
 import pyarrow as pa
 from dateutil import parser
-from pydantic import StrictStr
+from pydantic import StrictStr, root_validator
 from pydantic.typing import Literal
 from pytz import utc
 
@@ -51,15 +51,18 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
     type: Literal["redshift"] = "redshift"
     """ Offline store type selector"""
 
-    cluster_id: StrictStr
-    """ Redshift cluster identifier """
+    cluster_id: Optional[StrictStr]
+    """ Redshift cluster identifier, for provisioned clusters """
+
+    user: Optional[StrictStr]
+    """ Redshift user name, only required for provisioned clusters """
+
+    workgroup: Optional[StrictStr]
+    """ Redshift workgroup identifier, for serverless """
 
     region: StrictStr
     """ Redshift cluster's AWS region """
 
-    user: StrictStr
-    """ Redshift user name """
-
     database: StrictStr
     """ Redshift database name """
 
@@ -69,6 +72,26 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
     iam_role: StrictStr
     """ IAM Role for Redshift, granting it access to S3 """
 
+    @root_validator
+    def require_cluster_and_user_or_workgroup(cls, values):
+        """
+        Provisioned Redshift clusters: Require cluster_id and user, ignore workgroup
+        Serverless Redshift: Require workgroup, ignore cluster_id and user
+        """
+        cluster_id, user, workgroup = (
+            values.get("cluster_id"),
+            values.get("user"),
+            values.get("workgroup"),
+        )
+        if not (cluster_id and user) and not workgroup:
+            raise ValueError(
+                "please specify either cluster_id & user if using provisioned clusters, or workgroup if using serverless"
+            )
+        elif cluster_id and workgroup:
+            raise ValueError("cannot specify both cluster_id and workgroup")
+
+        return values
+
 
 class RedshiftOfflineStore(OfflineStore):
     @staticmethod
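The new validator makes the two connection modes mutually exclusive: provisioned clusters need `cluster_id` plus `user`, while serverless needs only `workgroup`. A minimal sketch of the resulting behavior, assuming the class is imported from `feast.infra.offline_stores.redshift` and the placeholder values below stand in for a real deployment:

```python
from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

common = dict(
    region="us-west-2",
    database="dev",
    s3_staging_location="s3://my-bucket/staging",  # hypothetical bucket
    iam_role="arn:aws:iam::123456789012:role/redshift-s3",  # hypothetical role
)

# Provisioned: cluster_id and user must both be set.
provisioned = RedshiftOfflineStoreConfig(
    cluster_id="my-cluster", user="admin", **common
)

# Serverless: workgroup alone selects the serverless path.
serverless = RedshiftOfflineStoreConfig(workgroup="my-workgroup", **common)

# Supplying neither identifier (or both) raises the ValueError from
# require_cluster_and_user_or_workgroup.
```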
@@ -248,6 +271,7 @@ def query_generator() -> Iterator[str]:
             aws_utils.execute_redshift_statement(
                 redshift_client,
                 config.offline_store.cluster_id,
+                config.offline_store.workgroup,
                 config.offline_store.database,
                 config.offline_store.user,
                 f"DROP TABLE IF EXISTS {table_name}",
@@ -294,6 +318,7 @@ def write_logged_features(
             table=data,
             redshift_data_client=redshift_client,
             cluster_id=config.offline_store.cluster_id,
+            workgroup=config.offline_store.workgroup,
             database=config.offline_store.database,
             user=config.offline_store.user,
             s3_resource=s3_resource,
@@ -336,8 +361,10 @@ def offline_write_batch(
             table=table,
             redshift_data_client=redshift_client,
             cluster_id=config.offline_store.cluster_id,
+            workgroup=config.offline_store.workgroup,
             database=redshift_options.database
-            or config.offline_store.database,  # Users can define database in the source if needed but it's not required.
+            # Users can define database in the source if needed but it's not required.
+            or config.offline_store.database,
             user=config.offline_store.user,
             s3_resource=s3_resource,
             s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
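The reshuffled comment above documents that a source-level database, when present, wins over the offline store's default. A hypothetical illustration, assuming `RedshiftSource` exposes the `database` field that `redshift_options.database` reads from:

```python
from feast import RedshiftSource

# database= here overrides config.offline_store.database for this source,
# per the fallback expression redshift_options.database or
# config.offline_store.database in the hunk above.
driver_stats = RedshiftSource(
    name="driver_stats",  # hypothetical source
    table="driver_stats",
    schema="public",
    database="analytics",
    timestamp_field="event_timestamp",
)
```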
@@ -405,6 +432,7 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
         return aws_utils.unload_redshift_query_to_df(
             self._redshift_client,
             self._config.offline_store.cluster_id,
+            self._config.offline_store.workgroup,
             self._config.offline_store.database,
             self._config.offline_store.user,
             self._s3_resource,
@@ -419,6 +447,7 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
         return aws_utils.unload_redshift_query_to_pa(
             self._redshift_client,
             self._config.offline_store.cluster_id,
+            self._config.offline_store.workgroup,
             self._config.offline_store.database,
             self._config.offline_store.user,
             self._s3_resource,
@@ -439,6 +468,7 @@ def to_s3(self) -> str:
         aws_utils.execute_redshift_query_and_unload_to_s3(
             self._redshift_client,
             self._config.offline_store.cluster_id,
+            self._config.offline_store.workgroup,
             self._config.offline_store.database,
             self._config.offline_store.user,
             self._s3_path,
@@ -455,6 +485,7 @@ def to_redshift(self, table_name: str) -> None:
         aws_utils.upload_df_to_redshift(
             self._redshift_client,
             self._config.offline_store.cluster_id,
+            self._config.offline_store.workgroup,
             self._config.offline_store.database,
             self._config.offline_store.user,
             self._s3_resource,
@@ -471,6 +502,7 @@ def to_redshift(self, table_name: str) -> None:
         aws_utils.execute_redshift_statement(
             self._redshift_client,
             self._config.offline_store.cluster_id,
+            self._config.offline_store.workgroup,
             self._config.offline_store.database,
             self._config.offline_store.user,
             query,
@@ -509,6 +541,7 @@ def _upload_entity_df(
         aws_utils.upload_df_to_redshift(
             redshift_client,
             config.offline_store.cluster_id,
+            config.offline_store.workgroup,
             config.offline_store.database,
             config.offline_store.user,
             s3_resource,
@@ -522,6 +555,7 @@ def _upload_entity_df(
         aws_utils.execute_redshift_statement(
             redshift_client,
             config.offline_store.cluster_id,
+            config.offline_store.workgroup,
             config.offline_store.database,
             config.offline_store.user,
             f"CREATE TABLE {table_name} AS ({entity_df})",
@@ -577,6 +611,7 @@ def _get_entity_df_event_timestamp_range(
         statement_id = aws_utils.execute_redshift_statement(
             redshift_client,
             config.offline_store.cluster_id,
+            config.offline_store.workgroup,
             config.offline_store.database,
             config.offline_store.user,
             f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max "