Skip to content
Prev Previous commit
Next Next commit
format
Signed-off-by: cmuhao <sduxuhao@gmail.com>
  • Loading branch information
HaoXuAI committed May 9, 2024
commit d3619f65abbbe4f237e1091dab1838b4dbe0ec0c
73 changes: 37 additions & 36 deletions sdk/python/feast/infra/online_stores/contrib/elastichsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def _get_client(self, config: RepoConfig) -> Elasticsearch:
online_store_config = config.online_store
assert isinstance(online_store_config, ElasticsearchOnlineStoreConfig)

if not self._client:
if self._client:
return self._client
else:
self._client = Elasticsearch(
hosts=[
{
Expand All @@ -48,6 +50,7 @@ def _get_client(self, config: RepoConfig) -> Elasticsearch:
],
http_auth=(online_store_config.user, online_store_config.password),
)
return self._client

def create_index(self, config: RepoConfig, table: FeatureView):
pass
Expand All @@ -61,13 +64,13 @@ def _bulk_batch_actions(self, batch):
}

def online_write_batch(
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
insert_values = []
for entity_key, values, timestamp, created_ts in data:
Expand All @@ -93,16 +96,16 @@ def online_write_batch(

batch_size = config.online_config.batch_size
for i in range(0, len(insert_values), batch_size):
batch = insert_values[i : i + batch_size]
batch = insert_values[i: i + batch_size]
actions = self._bulk_batch_actions(batch)
helpers.bulk(self._client, actions)

def online_read(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
if not requested_features:
body = {
Expand All @@ -121,27 +124,25 @@ def online_read(
}
},
}
response = self._client.search(index=self._index, body=body)
response = self._get_client(config).search(index=self._index, body=body)
results: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
for hit in response["hits"]["hits"]:
results.append(
(
hit["_source"]["timestamp"],
{
hit["_source"]["feature_name"]: hit["_source"]["feature_value"]
},
{hit["_source"]["feature_name"]: hit["_source"]["feature_value"]},
)
)
return results

def update(
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
):
# implement the update method
for table in tables_to_delete:
Expand All @@ -150,20 +151,20 @@ def update(
self.create_index(config, table)

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
pass

def retrieve_online_documents(
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
self,
config: RepoConfig,
table: FeatureView,
requested_feature: str,
embedding: List[float],
top_k: int,
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -180,7 +181,7 @@ def retrieve_online_documents(
Optional[ValueProto],
]
] = []
reponse = self._client.search(
reponse = self._get_client(config).search(
index=self._index,
knn={
"field": requested_feature,
Expand Down