Skip to content

Commit 7c53177

Browse files
authored
Fix inference of BigQuery ARRAY types. (feast-dev#2245)
* Support more BigQuery ARRAY types Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com> * Correctly infer BigQuery ARRAY types Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
1 parent 2080fa3 commit 7c53177

2 files changed

Lines changed: 21 additions & 15 deletions

File tree

sdk/python/feast/infra/offline_stores/bigquery_source.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Dict, Iterable, Optional, Tuple
1+
from typing import Callable, Dict, Iterable, List, Optional, Tuple
22

33
from feast import type_map
44
from feast.data_source import DataSource
@@ -123,18 +123,20 @@ def get_table_column_names_and_types(
123123

124124
client = bigquery.Client()
125125
if self.table_ref is not None:
126-
table_schema = client.get_table(self.table_ref).schema
127-
if not isinstance(table_schema[0], bigquery.schema.SchemaField):
126+
schema = client.get_table(self.table_ref).schema
127+
if not isinstance(schema[0], bigquery.schema.SchemaField):
128128
raise TypeError("Could not parse BigQuery table schema.")
129-
130-
name_type_pairs = [(field.name, field.field_type) for field in table_schema]
131129
else:
132130
bq_columns_query = f"SELECT * FROM ({self.query}) LIMIT 1"
133131
queryRes = client.query(bq_columns_query).result()
134-
name_type_pairs = [
135-
(schema_field.name, schema_field.field_type)
136-
for schema_field in queryRes.schema
137-
]
132+
schema = queryRes.schema
133+
134+
name_type_pairs: List[Tuple[str, str]] = []
135+
for field in schema:
136+
bq_type_as_str = field.field_type
137+
if field.mode == "REPEATED":
138+
bq_type_as_str = "ARRAY<" + bq_type_as_str + ">"
139+
name_type_pairs.append((field.name, bq_type_as_str))
138140

139141
return name_type_pairs
140142

sdk/python/feast/type_map.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,11 @@ def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType:
442442

443443

444444
def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:
445+
is_list = False
446+
if bq_type_as_str.startswith("ARRAY<"):
447+
is_list = True
448+
bq_type_as_str = bq_type_as_str[6:-1]
449+
445450
type_map: Dict[str, ValueType] = {
446451
"DATETIME": ValueType.UNIX_TIMESTAMP,
447452
"TIMESTAMP": ValueType.UNIX_TIMESTAMP,
@@ -453,15 +458,14 @@ def bq_to_feast_value_type(bq_type_as_str: str) -> ValueType:
453458
"BYTES": ValueType.BYTES,
454459
"BOOL": ValueType.BOOL,
455460
"BOOLEAN": ValueType.BOOL, # legacy sql data type
456-
"ARRAY<INT64>": ValueType.INT64_LIST,
457-
"ARRAY<FLOAT64>": ValueType.DOUBLE_LIST,
458-
"ARRAY<STRING>": ValueType.STRING_LIST,
459-
"ARRAY<BYTES>": ValueType.BYTES_LIST,
460-
"ARRAY<BOOL>": ValueType.BOOL_LIST,
461461
"NULL": ValueType.NULL,
462462
}
463463

464-
return type_map[bq_type_as_str]
464+
value_type = type_map[bq_type_as_str]
465+
if is_list:
466+
value_type = ValueType[value_type.name + "_LIST"]
467+
468+
return value_type
465469

466470

467471
def redshift_to_feast_value_type(redshift_type_as_str: str) -> ValueType:

0 commit comments

Comments
 (0)