@@ -27,6 +27,7 @@ def __init__(
2727 description : Optional [str ] = "" ,
2828 tags : Optional [Dict [str , str ]] = None ,
2929 owner : Optional [str ] = "" ,
30+ database : Optional [str ] = "" ,
3031 ):
3132 """
3233 Creates a RedshiftSource object.
@@ -47,11 +48,12 @@ def __init__(
4748 tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
4849 owner (optional): The owner of the redshift source, typically the email of the primary
4950 maintainer.
51+ database (optional): The Redshift database name.
5052 """
5153 # The default Redshift schema is named "public".
5254 _schema = "public" if table and not schema else schema
5355 self .redshift_options = RedshiftOptions (
54- table = table , schema = _schema , query = query
56+ table = table , schema = _schema , query = query , database = database
5557 )
5658
5759 if table is None and query is None :
@@ -102,6 +104,7 @@ def from_proto(data_source: DataSourceProto):
102104 description = data_source .description ,
103105 tags = dict (data_source .tags ),
104106 owner = data_source .owner ,
107+ database = data_source .redshift_options .database ,
105108 )
106109
107110 # Note: Python requires redefining hash in child classes that override __eq__
@@ -119,6 +122,7 @@ def __eq__(self, other):
119122 and self .redshift_options .table == other .redshift_options .table
120123 and self .redshift_options .schema == other .redshift_options .schema
121124 and self .redshift_options .query == other .redshift_options .query
125+ and self .redshift_options .database == other .redshift_options .database
122126 and self .event_timestamp_column == other .event_timestamp_column
123127 and self .created_timestamp_column == other .created_timestamp_column
124128 and self .field_mapping == other .field_mapping
@@ -139,9 +143,14 @@ def schema(self):
139143
140144 @property
141145 def query (self ):
142- """Returns the Redshift options of this Redshift source."""
146+ """Returns the Redshift query of this Redshift source."""
143147 return self .redshift_options .query
144148
149+ @property
150+ def database (self ):
151+ """Returns the Redshift database of this Redshift source."""
152+ return self .redshift_options .database
153+
145154 def to_proto (self ) -> DataSourceProto :
146155 """
147156 Converts a RedshiftSource object to its protobuf representation.
@@ -197,12 +206,15 @@ def get_table_column_names_and_types(
197206 assert isinstance (config .offline_store , RedshiftOfflineStoreConfig )
198207
199208 client = aws_utils .get_redshift_data_client (config .offline_store .region )
200-
201209 if self .table is not None :
202210 try :
203211 table = client .describe_table (
204212 ClusterIdentifier = config .offline_store .cluster_id ,
205- Database = config .offline_store .database ,
213+ Database = (
214+ self .database
215+ if self .database
216+ else config .offline_store .database
217+ ),
206218 DbUser = config .offline_store .user ,
207219 Table = self .table ,
208220 Schema = self .schema ,
@@ -221,7 +233,7 @@ def get_table_column_names_and_types(
221233 statement_id = aws_utils .execute_redshift_statement (
222234 client ,
223235 config .offline_store .cluster_id ,
224- config .offline_store .database ,
236+ self . database if self . database else config .offline_store .database ,
225237 config .offline_store .user ,
226238 f"SELECT * FROM ({ self .query } ) LIMIT 1" ,
227239 )
@@ -238,11 +250,16 @@ class RedshiftOptions:
238250 """
239251
240252 def __init__ (
241- self , table : Optional [str ], schema : Optional [str ], query : Optional [str ]
253+ self ,
254+ table : Optional [str ],
255+ schema : Optional [str ],
256+ query : Optional [str ],
257+ database : Optional [str ],
242258 ):
243259 self ._table = table
244260 self ._schema = schema
245261 self ._query = query
262+ self ._database = database
246263
247264 @property
248265 def query (self ):
@@ -274,6 +291,16 @@ def schema(self, schema):
274291 """Sets the schema of this Redshift table."""
275292 self ._schema = schema
276293
294+ @property
295+ def database (self ):
296+ """Returns the schema name of this Redshift table."""
297+ return self ._database
298+
299+ @database .setter
300+ def database (self , database ):
301+ """Sets the database name of this Redshift table."""
302+ self ._database = database
303+
277304 @classmethod
278305 def from_proto (cls , redshift_options_proto : DataSourceProto .RedshiftOptions ):
279306 """
@@ -289,6 +316,7 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
289316 table = redshift_options_proto .table ,
290317 schema = redshift_options_proto .schema ,
291318 query = redshift_options_proto .query ,
319+ database = redshift_options_proto .database ,
292320 )
293321
294322 return redshift_options
@@ -301,7 +329,10 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions:
301329 A RedshiftOptionsProto protobuf.
302330 """
303331 redshift_options_proto = DataSourceProto .RedshiftOptions (
304- table = self .table , schema = self .schema , query = self .query ,
332+ table = self .table ,
333+ schema = self .schema ,
334+ query = self .query ,
335+ database = self .database ,
305336 )
306337
307338 return redshift_options_proto
@@ -314,7 +345,7 @@ class SavedDatasetRedshiftStorage(SavedDatasetStorage):
314345
315346 def __init__ (self , table_ref : str ):
316347 self .redshift_options = RedshiftOptions (
317- table = table_ref , schema = None , query = None
348+ table = table_ref , schema = None , query = None , database = None
318349 )
319350
320351 @staticmethod
0 commit comments