@@ -16,6 +16,7 @@ class SnowflakeSource(DataSource):
1616 def __init__ (
1717 self ,
1818 database : Optional [str ] = None ,
19+ warehouse : Optional [str ] = None ,
1920 schema : Optional [str ] = None ,
2021 table : Optional [str ] = None ,
2122 query : Optional [str ] = None ,
@@ -33,6 +34,7 @@ def __init__(
3334
3435 Args:
3536 database (optional): Snowflake database where the features are stored.
37+ warehouse (optional): Snowflake warehouse where the database is stored.
3638 schema (optional): Snowflake schema in which the table is located.
3739 table (optional): Snowflake table where the features are stored.
3840 event_timestamp_column (optional): Event timestamp column used for point in
@@ -55,7 +57,11 @@ def __init__(
5557 _schema = "PUBLIC" if (database and table and not schema ) else schema
5658
5759 self .snowflake_options = SnowflakeOptions (
58- database = database , schema = _schema , table = table , query = query
60+ database = database ,
61+ schema = _schema ,
62+ table = table ,
63+ query = query ,
64+ warehouse = warehouse ,
5965 )
6066
6167 # If no name, use the table as the default name
@@ -107,6 +113,7 @@ def from_proto(data_source: DataSourceProto):
107113 database = data_source .snowflake_options .database ,
108114 schema = data_source .snowflake_options .schema ,
109115 table = data_source .snowflake_options .table ,
116+ warehouse = data_source .snowflake_options .warehouse ,
110117 event_timestamp_column = data_source .event_timestamp_column ,
111118 created_timestamp_column = data_source .created_timestamp_column ,
112119 query = data_source .snowflake_options .query ,
@@ -131,6 +138,7 @@ def __eq__(self, other):
131138 and self .snowflake_options .schema == other .snowflake_options .schema
132139 and self .snowflake_options .table == other .snowflake_options .table
133140 and self .snowflake_options .query == other .snowflake_options .query
141+ and self .snowflake_options .warehouse == other .snowflake_options .warehouse
134142 and self .event_timestamp_column == other .event_timestamp_column
135143 and self .created_timestamp_column == other .created_timestamp_column
136144 and self .field_mapping == other .field_mapping
@@ -159,6 +167,11 @@ def query(self):
159167 """Returns the snowflake options of this snowflake source."""
160168 return self .snowflake_options .query
161169
170+ @property
171+ def warehouse (self ):
172+ """Returns the warehouse of this snowflake source."""
173+ return self .snowflake_options .warehouse
174+
162175 def to_proto (self ) -> DataSourceProto :
163176 """
164177 Converts a SnowflakeSource object to its protobuf representation.
@@ -245,11 +258,13 @@ def __init__(
245258 schema : Optional [str ],
246259 table : Optional [str ],
247260 query : Optional [str ],
261+ warehouse : Optional [str ],
248262 ):
249263 self ._database = database
250264 self ._schema = schema
251265 self ._table = table
252266 self ._query = query
267+ self ._warehouse = warehouse
253268
254269 @property
255270 def query (self ):
@@ -291,6 +306,16 @@ def table(self, table):
291306 """Sets the table ref of this snowflake table."""
292307 self ._table = table
293308
309+ @property
310+ def warehouse (self ):
311+ """Returns the warehouse name of this snowflake table."""
312+ return self ._warehouse
313+
314+ @warehouse .setter
315+ def warehouse (self , warehouse ):
316+ """Sets the warehouse name of this snowflake table."""
317+ self ._warehouse = warehouse
318+
294319 @classmethod
295320 def from_proto (cls , snowflake_options_proto : DataSourceProto .SnowflakeOptions ):
296321 """
@@ -307,6 +332,7 @@ def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions):
307332 schema = snowflake_options_proto .schema ,
308333 table = snowflake_options_proto .table ,
309334 query = snowflake_options_proto .query ,
335+ warehouse = snowflake_options_proto .warehouse ,
310336 )
311337
312338 return snowflake_options
@@ -323,6 +349,7 @@ def to_proto(self) -> DataSourceProto.SnowflakeOptions:
323349 schema = self .schema ,
324350 table = self .table ,
325351 query = self .query ,
352+ warehouse = self .warehouse ,
326353 )
327354
328355 return snowflake_options_proto
@@ -335,7 +362,7 @@ class SavedDatasetSnowflakeStorage(SavedDatasetStorage):
335362
336363 def __init__ (self , table_ref : str ):
337364 self .snowflake_options = SnowflakeOptions (
338- database = None , schema = None , table = table_ref , query = None
365+ database = None , schema = None , table = table_ref , query = None , warehouse = None
339366 )
340367
341368 @staticmethod
0 commit comments