Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: allow for some duplication to make await cleaner/more obvious
Signed-off-by: robhowley <rhowley@seatgeek.com>
  • Loading branch information
robhowley committed May 31, 2024
commit 8c164c36aecfae2de9705df957bbc2c107992590
168 changes: 89 additions & 79 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,59 +208,6 @@ def online_write_batch(
)
self._write_batch_non_duplicates(table_instance, data, progress, config)

def _read_batches(
self, online_config, entity_ids, table_name, batch_get_item
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

batch_size = online_config.batch_size
entity_ids_iter = iter(entity_ids)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))
batch_result: List[
Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
] = []
# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = {
table_name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}
response = batch_get_item(
RequestItems=batch_entity_ids,
)
response = response.get("Responses")
table_responses = response.get(table_name)
if table_responses:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append(
(datetime.fromisoformat(tbl_res["event_ts"]), res)
)
entity_idx += 1

# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
result.extend(batch_result)
return result

def online_read(
self,
config: RepoConfig,
Expand All @@ -278,27 +225,36 @@ def online_read(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)

entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

return self._read_batches(
online_config,
entity_ids,
table_instance.name,
dynamodb_resource.batch_get_item,
)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = self._to_batch_get_payload(
online_config, table_instance.name, batch
)
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_instance.name, response, entity_ids, batch
)
result.extend(batch_result)
return result

async def online_read_async(
self,
Expand All @@ -324,21 +280,30 @@ async def online_read_async(
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
table_name = _get_table_name(online_config, config, table)

async with self._get_aiodynamodb_client(online_config.region) as client:
return self._read_batches(
online_config,
entity_ids,
_get_table_name(online_config, config, table),
lambda **kwargs: (await client(**kwargs)),
)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = self._to_batch_get_payload(
online_config, table_name, batch
)
response = await client.batch_get_item(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_name, response, entity_ids, batch
)
result.extend(batch_result)
return result

def _get_aioboto_session(self):
if self._aioboto_session is None:
Expand Down Expand Up @@ -403,6 +368,51 @@ def _write_batch_non_duplicates(
if progress:
progress(1)

def _process_batch_get_response(self, table_name, response, entity_ids, batch):
response = response.get("Responses")
table_responses = response.get(table_name)

batch_result = []
if table_responses:
table_responses = self._sort_dynamodb_response(table_responses, entity_ids)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res))
entity_idx += 1
# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
return batch_result

@staticmethod
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
return [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]

@staticmethod
def _to_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}


def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
return boto3.client(
Expand Down