|
26 | 26 | from feast.infra.online_stores.helpers import compute_entity_id |
27 | 27 | from feast.infra.online_stores.online_store import OnlineStore |
28 | 28 | from feast.infra.supported_async_methods import SupportedAsyncMethods |
| 29 | +from feast.infra.utils.aws_utils import dynamo_write_items_async |
29 | 30 | from feast.protos.feast.core.DynamoDBTable_pb2 import ( |
30 | 31 | DynamoDBTable as DynamoDBTableProto, |
31 | 32 | ) |
@@ -103,7 +104,7 @@ async def close(self): |
103 | 104 |
|
104 | 105 | @property |
105 | 106 | def async_supported(self) -> SupportedAsyncMethods: |
106 | | - return SupportedAsyncMethods(read=True) |
| 107 | + return SupportedAsyncMethods(read=True, write=True) |
107 | 108 |
|
108 | 109 | def update( |
109 | 110 | self, |
@@ -238,6 +239,42 @@ def online_write_batch( |
238 | 239 | ) |
239 | 240 | self._write_batch_non_duplicates(table_instance, data, progress, config) |
240 | 241 |
|
| 242 | + async def online_write_batch_async( |
| 243 | + self, |
| 244 | + config: RepoConfig, |
| 245 | + table: FeatureView, |
| 246 | + data: List[ |
| 247 | + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] |
| 248 | + ], |
| 249 | + progress: Optional[Callable[[int], Any]], |
| 250 | + ) -> None: |
| 251 | + """ |
| 252 | + Writes a batch of feature rows to the online store asynchronously. |
| 253 | +
|
| 254 | + If a tz-naive timestamp is passed to this method, it is assumed to be UTC. |
| 255 | +
|
| 256 | + Args: |
| 257 | + config: The config for the current feature store. |
| 258 | + table: Feature view to which these feature rows correspond. |
| 259 | + data: A list of quadruplets containing feature data. Each quadruplet contains an entity |
| 260 | + key, a dict containing feature values, an event timestamp for the row, and the created |
| 261 | + timestamp for the row if it exists. |
| 262 | + progress: Function to be called once a batch of rows is written to the online store, used |
| 263 | + to show progress. |
| 264 | + """ |
| 265 | + online_config = config.online_store |
| 266 | + assert isinstance(online_config, DynamoDBOnlineStoreConfig) |
| 267 | + |
| 268 | + table_name = _get_table_name(online_config, config, table) |
| 269 | + items = [ |
| 270 | + _to_write_item(config, entity_key, features, timestamp) |
| 271 | + for entity_key, features, timestamp, _ in data |
| 272 | + ] |
| 273 | + client = _get_aiodynamodb_client( |
| 274 | + online_config.region, config.online_store.max_pool_connections |
| 275 | + ) |
| 276 | + await dynamo_write_items_async(client, table_name, items) |
| 277 | + |
241 | 278 | def online_read( |
242 | 279 | self, |
243 | 280 | config: RepoConfig, |
@@ -419,19 +456,8 @@ def _write_batch_non_duplicates( |
419 | 456 | """Deduplicate write batch request items on ``entity_id`` primary key.""" |
420 | 457 | with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch: |
421 | 458 | for entity_key, features, timestamp, created_ts in data: |
422 | | - entity_id = compute_entity_id( |
423 | | - entity_key, |
424 | | - entity_key_serialization_version=config.entity_key_serialization_version, |
425 | | - ) |
426 | 459 | batch.put_item( |
427 | | - Item={ |
428 | | - "entity_id": entity_id, # PartitionKey |
429 | | - "event_ts": str(utils.make_tzaware(timestamp)), |
430 | | - "values": { |
431 | | - k: v.SerializeToString() |
432 | | - for k, v in features.items() # Serialized Features |
433 | | - }, |
434 | | - } |
| 460 | + Item=_to_write_item(config, entity_key, features, timestamp) |
435 | 461 | ) |
436 | 462 | if progress: |
437 | 463 | progress(1) |
@@ -675,3 +701,18 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None |
675 | 701 | region, endpoint_url |
676 | 702 | ) |
677 | 703 | return self._dynamodb_resource |
| 704 | + |
| 705 | + |
| 706 | +def _to_write_item(config, entity_key, features, timestamp): |
| 707 | + entity_id = compute_entity_id( |
| 708 | + entity_key, |
| 709 | + entity_key_serialization_version=config.entity_key_serialization_version, |
| 710 | + ) |
| 711 | + return { |
| 712 | + "entity_id": entity_id, # PartitionKey |
| 713 | + "event_ts": str(utils.make_tzaware(timestamp)), |
| 714 | + "values": { |
| 715 | + k: v.SerializeToString() |
| 716 | + for k, v in features.items() # Serialized Features |
| 717 | + }, |
| 718 | + } |
0 commit comments