|
| 1 | +import json |
| 2 | +import struct |
| 3 | +from datetime import datetime |
| 4 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union |
| 5 | + |
| 6 | +import mmh3 |
| 7 | +import pandas as pd |
| 8 | +from google.protobuf.timestamp_pb2 import Timestamp |
| 9 | + |
| 10 | +try: |
| 11 | + from redis import Redis |
| 12 | + from rediscluster import RedisCluster |
| 13 | +except ImportError as e: |
| 14 | + from feast.errors import FeastExtrasDependencyImportError |
| 15 | + |
| 16 | + raise FeastExtrasDependencyImportError("redis", str(e)) |
| 17 | + |
| 18 | +from tqdm import tqdm |
| 19 | + |
| 20 | +from feast import FeatureTable, utils |
| 21 | +from feast.entity import Entity |
| 22 | +from feast.feature_view import FeatureView |
| 23 | +from feast.infra.offline_stores.helpers import get_offline_store_from_config |
| 24 | +from feast.infra.provider import ( |
| 25 | + Provider, |
| 26 | + RetrievalJob, |
| 27 | + _convert_arrow_to_proto, |
| 28 | + _get_column_names, |
| 29 | + _run_field_mapping, |
| 30 | +) |
| 31 | +from feast.protos.feast.storage.Redis_pb2 import RedisKeyV2 as RedisKeyProto |
| 32 | +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto |
| 33 | +from feast.protos.feast.types.Value_pb2 import Value as ValueProto |
| 34 | +from feast.registry import Registry |
| 35 | +from feast.repo_config import RedisOnlineStoreConfig, RedisType, RepoConfig |
| 36 | + |
| 37 | +EX_SECONDS = 253402300799 |
| 38 | + |
| 39 | + |
class RedisProvider(Provider):
    """Feast provider that uses Redis (single node or cluster) as the online
    store and delegates offline retrieval to the configured offline store.

    Online layout: one Redis hash per (project, entity key). Hash fields are
    murmur3_32 hashes of "<feature_view>:<feature_name>"; two bookkeeping
    fields per feature view hold the event timestamp ("_ts:<view>") and the
    expiry sentinel ("_ex:<view>").
    """

    _redis_type: Optional[RedisType]
    _connection_string: str

    def __init__(self, config: RepoConfig):
        assert isinstance(config.online_store, RedisOnlineStoreConfig)
        # Assign unconditionally so both attributes always exist: the previous
        # conditional assignment left them unset for falsy config values,
        # raising AttributeError later in _get_client().
        self._redis_type = config.online_store.redis_type
        self._connection_string = config.online_store.connection_string or ""
        self.offline_store = get_offline_store_from_config(config.offline_store)

    def update_infra(
        self,
        project: str,
        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ):
        # Redis needs no per-table infrastructure; keys are created lazily on
        # write and removed in teardown_infra.
        pass

    def teardown_infra(
        self,
        project: str,
        tables: Sequence[Union[FeatureTable, FeatureView]],
        entities: Sequence[Entity],
    ) -> None:
        """Delete all online data for the given tables.

        According to repo_operations.py we can delete the whole project: keys
        are prefix-scanned by the serialized (project, join_keys) prefix and
        removed asynchronously with UNLINK.
        """
        client = self._get_client()

        tables_join_keys = [[e for e in t.entities] for t in tables]
        for table_join_keys in tables_join_keys:
            # Serialized RedisKeyProto with empty entity_values acts as the
            # common binary prefix of every key of this table.
            # NOTE(review): binary prefixes containing glob metacharacters
            # (*, ?, [) would need escaping for MATCH — confirm upstream.
            redis_key_bin = _redis_key(
                project, EntityKeyProto(join_keys=table_join_keys)
            )
            keys = [k for k in client.scan_iter(match=f"{redis_key_bin}*", count=100)]
            if keys:
                client.unlink(*keys)

    def online_write_batch(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """Write a batch of feature rows into per-entity Redis hashes.

        Args:
            project: Feast project name (part of the Redis key).
            table: Feature table/view the rows belong to.
            data: Rows as (entity_key, {feature: value}, event_ts, created_ts).
            progress: Optional callback invoked with 1 after each row.
        """
        client = self._get_client()
        feature_view = table.name

        # Expiry sentinel shared by every row: the max protobuf Timestamp
        # (9999-12-31T23:59:59Z), i.e. effectively "never expires".
        ex = Timestamp()
        ex.seconds = EX_SECONDS
        ex_str = ex.SerializeToString()

        for entity_key, values, timestamp, created_ts in data:
            redis_key_bin = _redis_key(project, entity_key)
            ts = Timestamp()
            ts.seconds = int(utils.make_tzaware(timestamp).timestamp())
            # Build the mapping fresh per entity. Reusing one dict across
            # iterations (as before) leaked stale fields from the previous
            # entity into the next HSET whenever feature sets differed.
            entity_hset = {
                f"_ts:{feature_view}": ts.SerializeToString(),
                f"_ex:{feature_view}": ex_str,
            }
            for feature_name, val in values.items():
                f_key = _mmh3(f"{feature_view}:{feature_name}")
                entity_hset[f_key] = val.SerializeToString()

            client.hset(redis_key_bin, mapping=entity_hset)
            if progress:
                progress(1)

    def online_read(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read the latest feature values for each entity key.

        Returns one (event_timestamp, {feature: value}) pair per entity key,
        or (None, None) when nothing is stored for that key.
        """
        client = self._get_client()
        feature_view = table.name

        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

        if not requested_features:
            requested_features = [f.name for f in table.features]

        ts_key = f"_ts:{feature_view}"
        # Hash fields are identical for every entity key: compute them once,
        # outside the loop. The previous code appended ts_key to
        # requested_features inside the loop, which both mutated the caller's
        # list and hashed "_ts:<view>" as a feature name for every entity
        # after the first, corrupting those reads.
        hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]
        hset_keys.append(ts_key)
        result_fields = requested_features + [ts_key]

        for entity_key in entity_keys:
            redis_key_bin = _redis_key(project, entity_key)
            values = client.hmget(redis_key_bin, hset_keys)
            res_val = dict(zip(result_fields, values))

            res_ts = Timestamp()
            ts_val = res_val.pop(ts_key)
            if ts_val:
                res_ts.ParseFromString(ts_val)

            res = {}
            for feature_name, val_bin in res_val.items():
                val = ValueProto()
                # Missing fields come back as None; keep an empty ValueProto
                # so the result dict still covers every requested feature.
                if val_bin:
                    val.ParseFromString(val_bin)
                res[feature_name] = val

            if not res:
                result.append((None, None))
            else:
                # NOTE(review): fromtimestamp without tz yields a naive local
                # datetime — presumably what callers expect; confirm.
                timestamp = datetime.fromtimestamp(res_ts.seconds)
                result.append((timestamp, res))
        return result

    def materialize_single_feature_view(
        self,
        feature_view: FeatureView,
        start_date: datetime,
        end_date: datetime,
        registry: Registry,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ) -> None:
        """Pull the latest rows for [start_date, end_date] from the offline
        store and write them to Redis, recording the materialized interval.
        """
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        (
            join_key_columns,
            feature_name_columns,
            event_timestamp_column,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        start_date = utils.make_tzaware(start_date)
        end_date = utils.make_tzaware(end_date)

        table = self.offline_store.pull_latest_from_table_or_query(
            data_source=feature_view.input,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            event_timestamp_column=event_timestamp_column,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
        )

        # Rename source columns to feature names when a mapping is configured.
        if feature_view.input.field_mapping is not None:
            table = _run_field_mapping(table, feature_view.input.field_mapping)

        join_keys = [entity.join_key for entity in entities]
        rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys)

        with tqdm_builder(len(rows_to_write)) as pbar:
            self.online_write_batch(
                project, feature_view, rows_to_write, lambda x: pbar.update(x)
            )

        # Persist the materialized window so later runs can resume from here.
        feature_view.materialization_intervals.append((start_date, end_date))
        registry.apply_feature_view(feature_view, project)

    def _parse_connection_string(self):
        """
        Reads Redis connections string using format
        for RedisCluster:
            redis1:6379,redis2:6379,decode_responses=true,skip_full_coverage_check=true,ssl=true,password=...
        for Redis:
            redis_master:6379,db=0,ssl=true,password=...

        Returns:
            (startup_nodes, params): host/port dicts for the plain
            "host:port" items, and a kwargs dict for the "key=value" items
            (values JSON-decoded when possible, kept as strings otherwise).
        """
        connection_string = self._connection_string
        startup_nodes = [
            dict(zip(["host", "port"], c.split(":")))
            for c in connection_string.split(",")
            if "=" not in c
        ]
        params = {}
        for c in connection_string.split(","):
            if "=" in c:
                # Split on the FIRST "=" only: the old split("=") truncated
                # values (e.g. passwords) that themselves contain "=".
                key, _, value = c.partition("=")
                try:
                    value = json.loads(value)
                except json.JSONDecodeError:
                    # Not valid JSON (e.g. a bare string) — keep it verbatim.
                    pass
                params[key] = value

        return startup_nodes, params

    def _get_client(self):
        """
        Creates the Redis client RedisCluster or Redis depending on configuration
        """
        startup_nodes, kwargs = self._parse_connection_string()
        if self._redis_type == RedisType.redis_cluster:
            kwargs["startup_nodes"] = startup_nodes
            return RedisCluster(**kwargs)
        else:
            # Single-node mode: only the first "host:port" entry is used.
            kwargs["host"] = startup_nodes[0]["host"]
            kwargs["port"] = startup_nodes[0]["port"]
            return Redis(**kwargs)

    def get_historical_features(
        self,
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pd.DataFrame, str],
        registry: Registry,
        project: str,
    ) -> RetrievalJob:
        """Delegate historical retrieval to the configured offline store."""
        return self.offline_store.get_historical_features(
            config=config,
            feature_views=feature_views,
            feature_refs=feature_refs,
            entity_df=entity_df,
            registry=registry,
            project=project,
        )
| 263 | + |
| 264 | + |
def _redis_key(project: str, entity_key: EntityKeyProto):
    """Serialize (project, entity names, entity values) into the binary
    Redis key under which the entity's feature hash is stored."""
    key_proto = RedisKeyProto(
        project=project,
        entity_names=entity_key.join_keys,
        entity_values=entity_key.entity_values,
    )
    return key_proto.SerializeToString()
| 272 | + |
| 273 | + |
def _mmh3(key: str):
    """
    Calculate murmur3_32 hash which is equal to scala version which is using little endian:
    https://stackoverflow.com/questions/29932956/murmur3-hash-different-result-between-python-and-java-implementation
    https://stackoverflow.com/questions/13141787/convert-decimal-int-to-little-endian-string-x-x

    Returns the 4-byte little-endian encoding of the unsigned 32-bit hash.
    """
    key_hash = mmh3.hash(key, signed=False)
    # key_hash < 2**32 (mmh3 32-bit unsigned), so packing it directly as a
    # little-endian uint32 yields exactly the first 4 bytes of the old
    # pack("<Q", ...).hex()[:8] / fromhex roundtrip — without the detour
    # through a hex string.
    return struct.pack("<I", key_hash)
0 commit comments