Skip to content
Prev Previous commit
Next Next commit
add interface
Signed-off-by: hao-affirm <104030690+hao-affirm@users.noreply.github.com>
  • Loading branch information
hao-affirm committed Sep 7, 2022
commit 7d98a4a85b532c99f6ce74e59ee77380cd0b4e4b
215 changes: 215 additions & 0 deletions sdk/python/feast/infra/online_stores/contrib/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from __future__ import absolute_import

from datetime import datetime
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
)

import pymysql
import pytz
from feast import (
Entity,
FeatureView,
RepoConfig,
)
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel
from pydantic import StrictStr
from pymysql.connections import Connection


class MySQLOnlineStoreConfig(FeastConfigBaseModel):
    """
    Configuration for the MySQL online store.
    NOTE: The class *must* end with the `OnlineStoreConfig` suffix.

    All connection settings are optional and default to ``None``.
    """

    # Online store type selector.
    type = "mysql"

    # Connection settings handed to the MySQL client.
    host: Optional[StrictStr] = None
    user: Optional[StrictStr] = None
    password: Optional[StrictStr] = None
    database: Optional[StrictStr] = None


class MySQLOnlineStore(OnlineStore):
    """
    An online store implementation that uses MySQL.
    NOTE: The class *must* end with the `OnlineStore` suffix.

    Each feature view is stored in its own table (named by ``_table_id``)
    with one row per ``(entity_key, feature_name)`` pair.
    """

    # Opened lazily by _get_conn(). The explicit ``None`` default matters:
    # a bare annotation creates no attribute, so the first connection check
    # would raise AttributeError.
    _conn: Optional[Connection] = None

    def _get_conn(self, config: RepoConfig) -> Connection:
        """Return the cached connection, opening it on first use.

        Unset config values fall back to host ``127.0.0.1``, user ``root``
        and database ``feast``. The connection is opened with autocommit on.
        """
        online_store_config = config.online_store
        assert isinstance(online_store_config, MySQLOnlineStoreConfig)

        if self._conn is None:
            self._conn = pymysql.connect(
                host=online_store_config.host or "127.0.0.1",
                user=online_store_config.user or "root",
                password=online_store_config.password,
                database=online_store_config.database or "feast",
                autocommit=True,
            )
        return self._conn

    def online_write_batch(
        self,
        config: RepoConfig,
        table: FeatureView,
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """Upsert a batch of feature rows into the table of ``table``.

        Args:
            config: Repo config providing the project name and connection settings.
            table: Feature view whose table receives the rows.
            data: Tuples of (entity key, feature name -> value, event timestamp,
                optional created timestamp).
            progress: Optional callback, called with ``1`` after each entity row.
        """
        conn = self._get_conn(config)
        project = config.project

        with conn.cursor() as cur:
            for entity_key, values, timestamp, created_ts in data:
                entity_key_bin = serialize_entity_key(entity_key).hex()
                # Timestamps are stored as naive UTC values.
                timestamp = _to_naive_utc(timestamp)
                if created_ts is not None:
                    created_ts = _to_naive_utc(created_ts)

                for feature_name, val in values.items():
                    self.write_to_table(
                        created_ts,
                        cur,
                        entity_key_bin,
                        feature_name,
                        project,
                        table,
                        timestamp,
                        val,
                    )
                conn.commit()
                if progress:
                    progress(1)

    @staticmethod
    def write_to_table(
        created_ts, cur, entity_key_bin, feature_name, project, table, timestamp, val
    ) -> None:
        """Insert one feature value, or overwrite it if the row already exists.

        A single ``INSERT ... ON DUPLICATE KEY UPDATE`` is used because the
        table's primary key is ``(entity_key, feature_name)``: an UPDATE
        followed by an unconditional INSERT fails with a duplicate-key error
        whenever the row is already present.
        """
        serialized = val.SerializeToString()
        cur.execute(
            f"""
            INSERT INTO {_table_id(project, table)}
            (entity_key, feature_name, value, event_ts, created_ts)
            VALUES (%s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            value = %s, event_ts = %s, created_ts = %s
            """,
            (
                # INSERT
                entity_key_bin,
                feature_name,
                serialized,
                timestamp,
                created_ts,
                # ON DUPLICATE KEY UPDATE
                serialized,
                timestamp,
                created_ts,
            ),
        )

    @staticmethod
    def _coerce_event_ts(ts: Any) -> Optional[datetime]:
        """Return ``ts`` as a datetime, parsing it only if it arrived as text."""
        if ts is None or isinstance(ts, datetime):
            return ts
        return datetime.strptime(ts, "%Y-%m-%d %H:%M:%S")

    def online_read(
        self,
        config: RepoConfig,
        table: FeatureView,
        entity_keys: List[EntityKeyProto],
        requested_features: Optional[List[str]] = None,
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """Read the stored feature values for each entity key.

        Returns one tuple per entity key, in input order: ``(None, None)``
        when no rows exist for the key, otherwise ``(event_ts, features)``
        where ``event_ts`` comes from the last row fetched for that key.

        NOTE: ``requested_features`` is not used to narrow the query; every
        stored feature of the entity is returned.
        """
        conn = self._get_conn(config)
        project = config.project

        result: List[Tuple[Optional[datetime], Optional[Dict[str, Any]]]] = []

        with conn.cursor() as cur:
            for entity_key in entity_keys:
                entity_key_bin = serialize_entity_key(entity_key).hex()

                cur.execute(
                    f"SELECT feature_name, value, event_ts FROM {_table_id(project, table)} WHERE entity_key = %s",
                    (entity_key_bin,),
                )

                res: Dict[str, ValueProto] = {}
                res_ts: Optional[datetime] = None
                for feature_name, val_bin, ts in cur.fetchall():
                    val = ValueProto()
                    val.ParseFromString(val_bin)
                    res[feature_name] = val
                    # The driver normally hands back a datetime already;
                    # calling strptime on it would raise TypeError.
                    res_ts = self._coerce_event_ts(ts)

                if not res:
                    result.append((None, None))
                else:
                    result.append((res_ts, res))
        return result

    def update(
        self,
        config: RepoConfig,
        tables_to_delete: Sequence[FeatureView],
        tables_to_keep: Sequence[FeatureView],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ) -> None:
        """Create tables for kept feature views and drop those being deleted.

        Safe to call repeatedly: the table is created only if missing and the
        ``<table>_ek`` index is added only when it does not exist yet.
        """
        conn = self._get_conn(config)
        project = config.project

        with conn.cursor() as cur:
            # We don't create any special state for the entities in this implementation.
            for table in tables_to_keep:
                table_name = _table_id(project, table)
                index_name = f"{table_name}_ek"

                cur.execute(
                    f"""CREATE TABLE IF NOT EXISTS {table_name} (entity_key VARCHAR(512),
                    feature_name VARCHAR(256),
                    value BLOB,
                    event_ts timestamp NULL DEFAULT NULL,
                    created_ts timestamp NULL DEFAULT NULL,
                    PRIMARY KEY(entity_key, feature_name))"""
                )

                # ADD INDEX fails if the index is already there, so check first.
                cur.execute(
                    "SELECT COUNT(1) FROM INFORMATION_SCHEMA.STATISTICS "
                    "WHERE table_schema = DATABASE() AND table_name = %s AND index_name = %s",
                    (table_name, index_name),
                )
                (index_count,) = cur.fetchone()
                if index_count == 0:
                    cur.execute(
                        f"ALTER TABLE {table_name} ADD INDEX {index_name} (entity_key);"
                    )

            for table in tables_to_delete:
                self._drop_table(cur, project, table)

    def teardown(
        self,
        config: RepoConfig,
        tables: Sequence[FeatureView],
        entities: Sequence[Entity],
    ) -> None:
        """Drop the tables of all given feature views."""
        conn = self._get_conn(config)
        project = config.project

        with conn.cursor() as cur:
            for table in tables:
                self._drop_table(cur, project, table)

    @staticmethod
    def _drop_table(cur, project: str, table: FeatureView) -> None:
        """Drop a feature view's table if it exists.

        Dropping the table removes its indexes too, so no separate DROP INDEX
        is issued; that statement would fail when the table is already gone.
        """
        cur.execute(f"DROP TABLE IF EXISTS {_table_id(project, table)}")


def _table_id(project: str, table: FeatureView) -> str:
return f"{project}_{table.name}"


def _to_naive_utc(ts: datetime) -> datetime:
if ts.tzinfo is None:
return ts
else:
return ts.astimezone(pytz.utc).replace(tzinfo=None)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.mysql import (
MySQLOnlineStoreCreator,
)

# Integration-test repo configurations: a single config using the MySQL
# online store creator.
FULL_REPO_CONFIGS = [
    IntegrationTestRepoConfig(
        online_store_creator=MySQLOnlineStoreCreator,
    ),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Dict

from testcontainers.mysql import MySqlContainer
from testcontainers.core.waiting_utils import wait_for_logs

from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
)


class MySQLOnlineStoreCreator(OnlineStoreCreator):
    """Runs a throwaway MySQL container to back the online store in tests."""

    # Log line to wait for. mysqld announces readiness with
    # "ready for connections"; "Ready to accept connections" is the Redis
    # message and never appears in MySQL logs, so waiting on it always
    # times out.
    _READY_LOG_PATTERN = "ready for connections"
    # First-time data directory initialisation can exceed 10 seconds.
    _STARTUP_TIMEOUT_SECONDS = 120

    def __init__(self, project_name: str, **kwargs):
        super().__init__(project_name)
        self.container = (
            MySqlContainer("mysql:latest", platform="linux/amd64")
            .with_exposed_ports("3306")
            .with_env("MYSQL_USER", "root")
            .with_env("MYSQL_DATABASE", "feast")
        )

    def create_online_store(self) -> Dict[str, str]:
        """Start the container, wait for MySQL, and return the store config.

        NOTE(review): the returned config carries no host or port, so the
        store connects with its defaults. Confirm the container's port 3306
        is reachable there, or propagate the mapped port.
        """
        self.container.start()
        wait_for_logs(
            container=self.container,
            predicate=self._READY_LOG_PATTERN,
            timeout=self._STARTUP_TIMEOUT_SECONDS,
        )
        return {"type": "mysql", "user": "root", "password": "test", "database": "feast"}

    def teardown(self):
        """Stop the MySQL container."""
        self.container.stop()