diff --git a/docs/how-to-guides/dbt-integration.md b/docs/how-to-guides/dbt-integration.md index abaadbf8740..c85cf2508db 100644 --- a/docs/how-to-guides/dbt-integration.md +++ b/docs/how-to-guides/dbt-integration.md @@ -5,7 +5,6 @@ **Current Limitations**: - Supported data sources: BigQuery, Snowflake, and File-based sources only -- Single entity per model - Manual entity column specification required Breaking changes may occur in future releases. @@ -185,6 +184,53 @@ driver_features_fv = FeatureView( ``` {% endcode %} +## Multiple Entity Support + +The dbt integration supports feature views with multiple entities, useful for modeling relationships involving multiple keys. + +### Usage + +Specify multiple entity columns using repeated `-e` flags: + +```bash +feast dbt import \ + -m target/manifest.json \ + -e user_id \ + -e merchant_id \ + --tag feast \ + -o features/transactions.py +``` + +This creates a FeatureView with both `user_id` and `merchant_id` as entities, useful for: +- Transaction features keyed by both user and merchant +- Interaction features keyed by multiple parties +- Association tables in many-to-many relationships + +Single entity usage: +```bash +feast dbt import -m target/manifest.json -e driver_id --tag feast +``` + +### Requirements + +All specified entity columns must exist in each dbt model being imported. Models missing any entity column will be skipped with a warning. + +### Generated Code + +The `--output` flag generates code like: + +```python +user_id = Entity(name="user_id", join_keys=["user_id"], ...) +merchant_id = Entity(name="merchant_id", join_keys=["merchant_id"], ...) + +transaction_fv = FeatureView( + name="transactions", + entities=[user_id, merchant_id], # Multiple entities + schema=[...], + ... +) +``` + ## CLI Reference ### `feast dbt list` @@ -217,7 +263,7 @@ feast dbt import [OPTIONS] | Option | Description | Default | |--------|-------------|---------| -| `--entity-column`, `-e` | Column to use as entity key | (required) | +| `--entity-column`, `-e` | Entity column name (can be specified multiple times) | (required) | | `--data-source-type`, `-d` | Data source type: `bigquery`, `snowflake`, `file` | `bigquery` | | `--tag-filter`, `-t` | Filter models by dbt tag | None | | `--model`, `-m` | Import specific model(s) only | None | diff --git a/sdk/python/feast/cli/dbt_import.py b/sdk/python/feast/cli/dbt_import.py index b09fd90ec6d..c2e78b45c82 100644 --- a/sdk/python/feast/cli/dbt_import.py +++ b/sdk/python/feast/cli/dbt_import.py @@ -30,8 +30,10 @@ def dbt_cmd(): @click.option( "--entity-column", "-e", + "entity_columns", + multiple=True, required=True, - help="Primary key / entity column name (e.g., driver_id, customer_id)", + help="Entity column name (can be specified multiple times, e.g., -e user_id -e merchant_id)", ) @click.option( "--data-source-type", @@ -89,7 +91,7 @@ def dbt_cmd(): def import_command( ctx: click.Context, manifest_path: str, - entity_column: str, + entity_columns: tuple, data_source_type: str, timestamp_field: str, tag_filter: Optional[str], @@ -141,6 +143,28 @@ def import_command( if parser.project_name: click.echo(f" Project: {parser.project_name}") + # Convert tuple to list and validate + entity_cols: List[str] = list(entity_columns) if entity_columns else [] + + # Validation: At least one entity required (redundant with required=True but explicit) + if not entity_cols: + click.echo( + f"{Fore.RED}Error: At least one entity column required{Style.RESET_ALL}", + err=True, + ) + raise SystemExit(1) + + # Validation: No duplicate entity columns + if len(entity_cols) != len(set(entity_cols)): + duplicates = [col for col in entity_cols if entity_cols.count(col) > 1] + click.echo( + f"{Fore.RED}Error: Duplicate entity columns: {', '.join(set(duplicates))}{Style.RESET_ALL}", + err=True, + ) + raise SystemExit(1) + + click.echo(f"Entity columns: {', '.join(entity_cols)}") + # Get models with filters model_list: Optional[List[str]] = list(model_names) if model_names else None models = parser.get_models(model_names=model_list, tag_filter=tag_filter) @@ -188,24 +212,31 @@ def import_command( ) continue - # Validate entity column exists - if entity_column not in column_names: + # Validate ALL entity columns exist + missing_entities = [e for e in entity_cols if e not in column_names] + if missing_entities: click.echo( f"{Fore.YELLOW}Warning: Model '{model.name}' missing entity " - f"column '{entity_column}'. Skipping.{Style.RESET_ALL}" + f"column(s): {', '.join(missing_entities)}. Skipping.{Style.RESET_ALL}" ) continue - # Create or reuse entity - if entity_column not in entities_created: - entity = mapper.create_entity( - name=entity_column, - description="Entity key for dbt models", - ) - entities_created[entity_column] = entity - all_objects.append(entity) - else: - entity = entities_created[entity_column] + # Create or reuse entities (one per entity column) + model_entities: List[Any] = [] + for entity_col in entity_cols: + if entity_col not in entities_created: + # Use mapper's internal method for value type inference + entity_value_type = mapper._infer_entity_value_type(model, entity_col) + entity = mapper.create_entity( + name=entity_col, + description="Entity key for dbt models", + value_type=entity_value_type, + ) + entities_created[entity_col] = entity + all_objects.append(entity) + else: + entity = entities_created[entity_col] + model_entities.append(entity) # Create data source data_source = mapper.create_data_source( @@ -218,8 +249,8 @@ def import_command( feature_view = mapper.create_feature_view( model=model, source=data_source, - entity_column=entity_column, - entity=entity, + entity_columns=entity_cols, + entities=model_entities, timestamp_field=timestamp_field, ttl_days=ttl_days, exclude_columns=excluded, @@ -242,7 +273,7 @@ def import_command( m for m in models if timestamp_field in [c.name for c in m.columns] - and entity_column in [c.name for c in m.columns] + and all(e in [c.name for c in m.columns] for e in entity_cols) ] # Summary @@ -257,7 +288,7 @@ def import_command( code = generate_feast_code( models=valid_models, - entity_column=entity_column, + entity_columns=entity_cols, data_source_type=data_source_type, timestamp_field=timestamp_field, ttl_days=ttl_days, diff --git a/sdk/python/feast/dbt/codegen.py b/sdk/python/feast/dbt/codegen.py index 1c7acfb944c..affc38fe9e2 100644 --- a/sdk/python/feast/dbt/codegen.py +++ b/sdk/python/feast/dbt/codegen.py @@ -6,7 +6,7 @@ """ import logging -from typing import Any, List, Optional, Set +from typing import Any, List, Optional, Set, Union from jinja2 import BaseLoader, Environment @@ -106,7 +106,7 @@ {% for fv in feature_views %} {{ fv.var_name }} = FeatureView( name="{{ fv.name }}", - entities=[{{ fv.entity_var }}], + entities=[{{ fv.entity_vars | join(', ') }}], ttl=timedelta(days={{ fv.ttl_days }}), schema=[ {% for field in fv.fields %} @@ -220,7 +220,7 @@ def __init__( def generate( self, models: List[DbtModel], - entity_column: str, + entity_columns: Union[str, List[str]], manifest_path: str = "", project_name: str = "", exclude_columns: Optional[List[str]] = None, @@ -231,7 +231,7 @@ def generate( Args: models: List of DbtModel objects to generate code for - entity_column: The entity/primary key column name + entity_columns: Entity column name(s) - single string or list of strings manifest_path: Path to the dbt manifest (for documentation) project_name: dbt project name (for documentation) exclude_columns: Columns to exclude from features @@ -240,25 +240,36 @@ def generate( Returns: Generated Python code as a string """ - excluded = {entity_column, self.timestamp_field} + # Normalize entity_columns to list + entity_cols: List[str] = ( + [entity_columns] if isinstance(entity_columns, str) else entity_columns + ) + + if not entity_cols: + raise ValueError("At least one entity column must be specified") + + excluded = set(entity_cols) | {self.timestamp_field} if exclude_columns: excluded.update(exclude_columns) # Collect all Feast types used for imports type_imports: Set[str] = set() - # Prepare entity data + # Prepare entity data - create one entity per entity column entities = [] - entity_var = _make_var_name(entity_column) - entities.append( - { - "var_name": entity_var, - "name": entity_column, - "join_key": entity_column, - "description": "Entity key for dbt models", - "tags": {"source": "dbt"}, - } - ) + entity_vars = [] # Track variable names for feature views + for entity_col in entity_cols: + entity_var = _make_var_name(entity_col) + entity_vars.append(entity_var) + entities.append( + { + "var_name": entity_var, + "name": entity_col, + "join_key": entity_col, + "description": "Entity key for dbt models", + "tags": {"source": "dbt"}, + } + ) # Prepare data sources and feature views data_sources = [] @@ -269,7 +280,9 @@ def generate( column_names = [c.name for c in model.columns] if self.timestamp_field not in column_names: continue - if entity_column not in column_names: + + # Skip if ANY entity column is missing + if not all(e in column_names for e in entity_cols): continue # Build tags @@ -339,7 +352,7 @@ def generate( { "var_name": fv_var, "name": model.name, - "entity_var": entity_var, + "entity_vars": entity_vars, "source_var": source_var, "ttl_days": self.ttl_days, "fields": fields, @@ -366,7 +379,7 @@ def generate( def generate_feast_code( models: List[DbtModel], - entity_column: str, + entity_columns: Union[str, List[str]], data_source_type: str = "bigquery", timestamp_field: str = "event_timestamp", ttl_days: int = 1, @@ -380,7 +393,7 @@ def generate_feast_code( Args: models: List of DbtModel objects - entity_column: Primary key column name + entity_columns: Entity column name(s) - single string or list of strings data_source_type: Type of data source (bigquery, snowflake, file) timestamp_field: Timestamp column name ttl_days: TTL in days for feature views @@ -400,7 +413,7 @@ def generate_feast_code( return generator.generate( models=models, - entity_column=entity_column, + entity_columns=entity_columns, manifest_path=manifest_path, project_name=project_name, exclude_columns=exclude_columns, diff --git a/sdk/python/feast/dbt/mapper.py b/sdk/python/feast/dbt/mapper.py index 2d6d63fbd32..fd41d886cf1 100644 --- a/sdk/python/feast/dbt/mapper.py +++ b/sdk/python/feast/dbt/mapper.py @@ -26,6 +26,24 @@ ) from feast.value_type import ValueType +# Mapping from FeastType to ValueType for entity value inference +FEAST_TYPE_TO_VALUE_TYPE: Dict[FeastType, ValueType] = { + String: ValueType.STRING, + Int32: ValueType.INT64, + Int64: ValueType.INT64, + Float32: ValueType.DOUBLE, + Float64: ValueType.DOUBLE, + Bool: ValueType.BOOL, + Bytes: ValueType.BYTES, + UnixTimestamp: ValueType.UNIX_TIMESTAMP, +} + + +def feast_type_to_value_type(feast_type: FeastType) -> ValueType: + """Convert a FeastType to its corresponding ValueType for entities.""" + return FEAST_TYPE_TO_VALUE_TYPE.get(feast_type, ValueType.STRING) + + # Comprehensive mapping from dbt/warehouse types to Feast types # Covers BigQuery, Snowflake, Redshift, PostgreSQL, and common SQL types DBT_TO_FEAST_TYPE_MAP: Dict[str, FeastType] = { @@ -180,6 +198,14 @@ def __init__( self.timestamp_field = timestamp_field self.ttl_days = ttl_days + def _infer_entity_value_type(self, model: DbtModel, entity_col: str) -> ValueType: + """Infer entity ValueType from dbt model column type.""" + for column in model.columns: + if column.name == entity_col: + feast_type = map_dbt_type_to_feast_type(column.data_type) + return feast_type_to_value_type(feast_type) + return ValueType.UNKNOWN + def create_data_source( self, model: DbtModel, @@ -285,8 +311,8 @@ def create_feature_view( self, model: DbtModel, source: Any, - entity_column: str, - entity: Optional[Entity] = None, + entity_columns: Union[str, List[str]], + entities: Optional[Union[Entity, List[Entity]]] = None, timestamp_field: Optional[str] = None, ttl_days: Optional[int] = None, exclude_columns: Optional[List[str]] = None, @@ -298,8 +324,8 @@ def create_feature_view( Args: model: The DbtModel to create a FeatureView from source: The DataSource for this FeatureView - entity_column: The entity/primary key column name - entity: Optional pre-created Entity (created if not provided) + entity_columns: Entity column name(s) - single string or list of strings + entities: Optional pre-created Entity or list of Entities timestamp_field: Override the default timestamp field ttl_days: Override the default TTL in days exclude_columns: Additional columns to exclude from features @@ -308,15 +334,38 @@ def create_feature_view( Returns: A Feast FeatureView """ + # Normalize to lists + entity_cols: List[str] = ( + [entity_columns] + if isinstance(entity_columns, str) + else list(entity_columns) + ) + + entity_objs: List[Entity] = [] + if entities is not None: + entity_objs = [entities] if isinstance(entities, Entity) else list(entities) + + # Validate + if not entity_cols: + raise ValueError("At least one entity column must be specified") + + if entity_objs and len(entity_cols) != len(entity_objs): + raise ValueError( + f"Number of entity_columns ({len(entity_cols)}) must match " + f"number of entities ({len(entity_objs)})" + ) + ts_field = timestamp_field or self.timestamp_field ttl = timedelta(days=ttl_days if ttl_days is not None else self.ttl_days) - # Columns to exclude from features - excluded = {entity_column, ts_field} + # Columns to exclude from schema (timestamp + any explicitly excluded) + # Note: entity columns should NOT be excluded - FeatureView.__init__ + # expects entity columns to be in the schema and will extract them + excluded = {ts_field} if exclude_columns: excluded.update(exclude_columns) - # Create schema from model columns + # Create schema from model columns (includes entity columns) schema: List[Field] = [] for column in model.columns: if column.name not in excluded: @@ -329,12 +378,18 @@ def create_feature_view( ) ) - # Create entity if not provided - if entity is None: - entity = self.create_entity( - name=entity_column, - description=f"Entity for {model.name}", - ) + # Create entities if not provided + if not entity_objs: + entity_objs = [] + for entity_col in entity_cols: + # Infer entity value type from model column + entity_value_type = self._infer_entity_value_type(model, entity_col) + ent = self.create_entity( + name=entity_col, + description=f"Entity for {model.name}", + value_type=entity_value_type, + ) + entity_objs.append(ent) # Build tags from dbt metadata tags = { @@ -348,7 +403,7 @@ def create_feature_view( name=model.name, source=source, schema=schema, - entities=[entity], + entities=entity_objs, ttl=ttl, online=online, description=model.description, @@ -358,12 +413,12 @@ def create_feature_view( def create_all_from_model( self, model: DbtModel, - entity_column: str, + entity_columns: Union[str, List[str]], timestamp_field: Optional[str] = None, ttl_days: Optional[int] = None, exclude_columns: Optional[List[str]] = None, online: bool = True, - ) -> Dict[str, Union[Entity, Any, FeatureView]]: + ) -> Dict[str, Union[List[Entity], Any, FeatureView]]: """ Create all Feast objects (DataSource, Entity, FeatureView) from a dbt model. @@ -372,22 +427,34 @@ def create_all_from_model( Args: model: The DbtModel to create objects from - entity_column: The entity/primary key column name + entity_columns: Entity column name(s) - single string or list of strings timestamp_field: Override the default timestamp field ttl_days: Override the default TTL in days exclude_columns: Additional columns to exclude from features online: Whether to enable online serving Returns: - Dict with keys 'entity', 'data_source', 'feature_view' + Dict with keys 'entities', 'data_source', 'feature_view' """ - # Create entity - entity = self.create_entity( - name=entity_column, - description=f"Entity for {model.name}", - tags={"dbt.model": model.name}, + # Normalize to list + entity_cols: List[str] = ( + [entity_columns] + if isinstance(entity_columns, str) + else list(entity_columns) ) + # Create entities (plural) + entities_list = [] + for entity_col in entity_cols: + entity_value_type = self._infer_entity_value_type(model, entity_col) + entity = self.create_entity( + name=entity_col, + description=f"Entity for {model.name}", + tags={"dbt.model": model.name}, + value_type=entity_value_type, + ) + entities_list.append(entity) + # Create data source data_source = self.create_data_source( model=model, @@ -398,8 +465,8 @@ def create_all_from_model( feature_view = self.create_feature_view( model=model, source=data_source, - entity_column=entity_column, - entity=entity, + entity_columns=entity_cols, + entities=entities_list, timestamp_field=timestamp_field, ttl_days=ttl_days, exclude_columns=exclude_columns, @@ -407,7 +474,7 @@ def create_all_from_model( ) return { - "entity": entity, + "entities": entities_list, "data_source": data_source, "feature_view": feature_view, } diff --git a/sdk/python/tests/unit/dbt/test_mapper.py b/sdk/python/tests/unit/dbt/test_mapper.py index 809c4b43b8e..4f06ba47fc7 100644 --- a/sdk/python/tests/unit/dbt/test_mapper.py +++ b/sdk/python/tests/unit/dbt/test_mapper.py @@ -235,7 +235,7 @@ def test_create_feature_view(self, sample_model): fv = mapper.create_feature_view( model=sample_model, source=source, - entity_column="driver_id", + entity_columns="driver_id", ) assert fv.name == "driver_stats" @@ -262,7 +262,7 @@ def test_create_feature_view_with_exclude(self, sample_model): fv = mapper.create_feature_view( model=sample_model, source=source, - entity_column="driver_id", + entity_columns="driver_id", exclude_columns=["is_active"], ) @@ -276,14 +276,15 @@ def test_create_all_from_model(self, sample_model): mapper = DbtToFeastMapper(data_source_type="bigquery") result = mapper.create_all_from_model( model=sample_model, - entity_column="driver_id", + entity_columns="driver_id", ) - assert "entity" in result + assert "entities" in result assert "data_source" in result assert "feature_view" in result - assert result["entity"].name == "driver_id" + assert len(result["entities"]) == 1 + assert result["entities"][0].name == "driver_id" assert result["data_source"].name == "driver_stats_source" assert result["feature_view"].name == "driver_stats" @@ -294,7 +295,7 @@ def test_feature_type_mapping(self, sample_model): fv = mapper.create_feature_view( model=sample_model, source=source, - entity_column="driver_id", + entity_columns="driver_id", ) # Find specific features and check types