Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Refactor go feature server
Signed-off-by: Kevin Zhang <kzhang@tecton.ai>
  • Loading branch information
kevjumba committed Jul 29, 2022
commit c16d9ade777ddd926c6bb92ad40978f93e275901
19 changes: 5 additions & 14 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
FileDataSourceCreator,
)

from tests.utils.http_utils import check_port_open, free_port

logger = logging.getLogger(__name__)

level = logging.INFO
Expand Down Expand Up @@ -327,7 +329,7 @@ def feature_server_endpoint(environment):
yield environment.feature_store.get_feature_server_endpoint()
return

port = _free_port()
port = free_port()

proc = Process(
target=start_test_local_server,
Expand All @@ -340,7 +342,7 @@ def feature_server_endpoint(environment):
proc.start()
# Wait for server to start
wait_retry_backoff(
lambda: (None, _check_port_open("localhost", port)),
lambda: (None, check_port_open("localhost", port)),
timeout_secs=10,
)

Expand All @@ -353,23 +355,12 @@ def feature_server_endpoint(environment):
wait_retry_backoff(
lambda: (
None,
not _check_port_open("localhost", environment.get_local_server_port()),
not check_port_open("localhost", environment.get_local_server_port()),
),
timeout_secs=30,
)


def _check_port_open(host, port) -> bool:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
return sock.connect_ex((host, port)) == 0


def _free_port():
sock = socket.socket()
sock.bind(("", 0))
return sock.getsockname()[1]


@pytest.fixture
def universal_data_sources(environment) -> TestData:
return construct_universal_test_data(environment)
Expand Down
96 changes: 28 additions & 68 deletions sdk/python/tests/integration/e2e/test_go_feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,10 @@
location,
)

from tests.utils.http_utils import free_port, check_port_open
from tests.utils.feature_utils import generate_expected_logs, get_latest_rows

@pytest.fixture
def initialized_registry(environment, universal_data_sources):
    """Apply the universal feature repo plus a logging-enabled
    "driver_features" feature service, then materialize the store so the
    online store has data to serve."""
    fs = environment.feature_store

    _, _, data_sources = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)

    feature_service = FeatureService(
        name="driver_features",
        features=[feature_views.driver],
        # sample_rate=1.0: log every served row to the environment's
        # logged-features destination.
        logging_config=LoggingConfig(
            destination=environment.data_source_creator.create_logged_features_destination(),
            sample_rate=1.0,
        ),
    )
    feast_objects: List[FeastObject] = [feature_service]
    feast_objects.extend(feature_views.values())
    feast_objects.extend([driver(), customer(), location()])

    fs.apply(feast_objects)
    fs.materialize(environment.start_date, environment.end_date)


def server_port(environment, server_type: str):
def _server_port(environment, server_type: str):
if not environment.test_repo_config.go_feature_serving:
pytest.skip("Only for Go path")

Expand Down Expand Up @@ -106,15 +85,38 @@ def server_port(environment, server_type: str):
# wait for graceful stop
time.sleep(5)

# Go test fixtures

@pytest.fixture
def initialized_registry(environment, universal_data_sources):
    """Apply the universal feature repo — including a logging-enabled
    "driver_features" feature service — and materialize the store so
    the online path has data to serve."""
    store = environment.feature_store
    _, _, data_sources = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)

    driver_service = FeatureService(
        name="driver_features",
        features=[feature_views.driver],
        logging_config=LoggingConfig(
            destination=environment.data_source_creator.create_logged_features_destination(),
            sample_rate=1.0,
        ),
    )

    objects: List[FeastObject] = [driver_service]
    objects.extend(feature_views.values())
    objects.extend([driver(), customer(), location()])

    store.apply(objects)
    store.materialize(environment.start_date, environment.end_date)

@pytest.fixture
def grpc_server_port(environment, initialized_registry):
    # Yields the port of the feature server started in "grpc" mode;
    # `yield from` forwards pytest's fixture teardown (generator close/throw)
    # into _server_port so the server is shut down after the test.
    yield from _server_port(environment, "grpc")


@pytest.fixture
def http_server_port(environment, initialized_registry):
    # Same as grpc_server_port but starts the server in "http" mode;
    # `yield from` forwards pytest's fixture teardown into _server_port.
    yield from _server_port(environment, "http")


@pytest.fixture
Expand Down Expand Up @@ -252,45 +254,3 @@ def retrieve():
persisted_logs = persisted_logs.sort_values(by="driver_id").reset_index(drop=True)
persisted_logs = persisted_logs[expected_logs.columns]
pd.testing.assert_frame_equal(expected_logs, persisted_logs, check_dtype=False)


def free_port():
    """Bind to port 0 so the OS picks an unused TCP port; return that number."""
    probe = socket.socket()
    probe.bind(("", 0))
    _, port = probe.getsockname()
    return port


def check_port_open(host, port) -> bool:
    """Return True when a TCP connect to host:port succeeds (server up)."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        return sock.connect_ex((host, port)) == 0
    finally:
        sock.close()


def get_latest_rows(df, join_key, entity_values):
    """For each requested entity value, return the row with the newest
    event_timestamp."""
    matching = df[df[join_key].isin(entity_values)]
    newest_idx = matching.groupby(join_key)["event_timestamp"].idxmax()
    return matching.loc[newest_idx]


def generate_expected_logs(
    df: pd.DataFrame,
    feature_view: FeatureView,
    features: List[str],
    join_keys: List[str],
    timestamp_column: str,
):
    """Build the log DataFrame the logging infrastructure is expected to
    produce for `df` served through `feature_view`: join keys, then per
    feature a value column plus `__timestamp` and `__status` columns
    (OUTSIDE_MAX_AGE when older than the view's TTL)."""
    logs = pd.DataFrame()
    for key in join_keys:
        logs[key] = df[key]

    for feature_name in features:
        prefixed = f"{feature_view.name}__{feature_name}"
        logs[prefixed] = df[feature_name]
        logs[f"{prefixed}__timestamp"] = df[timestamp_column]
        logs[f"{prefixed}__status"] = FieldStatus.PRESENT
        if feature_view.ttl:
            expiry = datetime.utcnow().replace(tzinfo=pytz.UTC) - feature_view.ttl
            logs[f"{prefixed}__status"] = logs[f"{prefixed}__status"].mask(
                df[timestamp_column] < expiry,
                FieldStatus.OUTSIDE_MAX_AGE,
            )

    return logs.sort_values(by=join_keys).reset_index(drop=True)
120 changes: 120 additions & 0 deletions sdk/python/tests/utils/feature_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import List
from datetime import datetime
import pytz

import contextlib
import datetime
import tempfile
import uuid
from pathlib import Path
from typing import Iterator, Union

import numpy as np
import pandas as pd
import pyarrow

from feast import FeatureService, FeatureStore, FeatureView
from feast.errors import FeatureViewNotFoundException
from feast.feature_logging import LOG_DATE_FIELD, LOG_TIMESTAMP_FIELD, REQUEST_ID_FIELD
from feast.protos.feast.serving.ServingService_pb2 import FieldStatus

"""
Return latest rows in a dataframe based on join key and entity values.
"""
def get_latest_rows(df, join_key, entity_values):
rows = df[df[join_key].isin(entity_values)]
return rows.loc[rows.groupby(join_key)["event_timestamp"].idxmax()]

"""
Given dataframe and feature view, generate the expected logging dataframes that would be otherwise generated by our logging infrastructure.
"""
def generate_expected_logs(
df: pd.DataFrame,
feature_view: FeatureView,
features: List[str],
join_keys: List[str],
timestamp_column: str,
):
logs = pd.DataFrame()
for join_key in join_keys:
logs[join_key] = df[join_key]

for feature in features:
col = f"{feature_view.name}__{feature}"
logs[col] = df[feature]
logs[f"{col}__timestamp"] = df[timestamp_column]
logs[f"{col}__status"] = FieldStatus.PRESENT
if feature_view.ttl:
logs[f"{col}__status"] = logs[f"{col}__status"].mask(
df[timestamp_column]
< datetime.utcnow().replace(tzinfo=pytz.UTC) - feature_view.ttl,
FieldStatus.OUTSIDE_MAX_AGE,
)

return logs.sort_values(by=join_keys).reset_index(drop=True)


def prepare_logs(
    source_df: pd.DataFrame, feature_service: FeatureService, store: FeatureStore
) -> pd.DataFrame:
    """Build the feature-log DataFrame expected for `source_df` served
    through `feature_service` against `store`.

    Adds request-id / log-timestamp / log-date bookkeeping columns, then
    for every feature-view projection in the service copies the entity
    join keys (or raw request-source fields for on-demand views) and each
    feature value with per-feature `__timestamp` and `__status` columns.
    """
    num_rows = source_df.shape[0]

    logs_df = pd.DataFrame()
    # One synthetic request id per source row.
    logs_df[REQUEST_ID_FIELD] = [str(uuid.uuid4()) for _ in range(num_rows)]
    # Random log timestamps spread over the last 7 days (not deterministic).
    logs_df[LOG_TIMESTAMP_FIELD] = pd.Series(
        np.random.randint(0, 7 * 24 * 3600, num_rows)
    ).map(lambda secs: pd.Timestamp.utcnow() - datetime.timedelta(seconds=secs))
    logs_df[LOG_DATE_FIELD] = logs_df[LOG_TIMESTAMP_FIELD].dt.date

    for projection in feature_service.feature_view_projections:
        try:
            view = store.get_feature_view(projection.name)
        except FeatureViewNotFoundException:
            # On-demand view: log the raw request-source fields instead of
            # entity join keys.
            view = store.get_on_demand_feature_view(projection.name)
            for source in view.source_request_sources.values():
                for field in source.schema:
                    logs_df[field.name] = source_df[field.name]
        else:
            # Regular view: one logged column per entity join key.
            for entity_name in view.entities:
                entity = store.get_entity(entity_name)
                logs_df[entity.join_key] = source_df[entity.join_key]

        for feature in projection.features:
            # Source column may be the bare feature name or already
            # prefixed with the projection name; destination is always
            # prefixed.
            source_field = (
                feature.name
                if feature.name in source_df.columns
                else f"{projection.name_to_use()}__{feature.name}"
            )
            destination_field = f"{projection.name_to_use()}__{feature.name}"
            logs_df[destination_field] = source_df[source_field]
            # Event timestamps are logged at second precision and made
            # timezone-naive to match the logging format.
            logs_df[f"{destination_field}__timestamp"] = source_df[
                "event_timestamp"
            ].dt.floor("s")
            if logs_df[f"{destination_field}__timestamp"].dt.tz:
                logs_df[f"{destination_field}__timestamp"] = logs_df[
                    f"{destination_field}__timestamp"
                ].dt.tz_convert(None)
            logs_df[f"{destination_field}__status"] = FieldStatus.PRESENT
            # Values older than the (regular) view's TTL are expected to be
            # logged as OUTSIDE_MAX_AGE rather than PRESENT.
            if isinstance(view, FeatureView) and view.ttl:
                logs_df[f"{destination_field}__status"] = logs_df[
                    f"{destination_field}__status"
                ].mask(
                    logs_df[f"{destination_field}__timestamp"]
                    < (datetime.datetime.utcnow() - view.ttl),
                    FieldStatus.OUTSIDE_MAX_AGE,
                )

    return logs_df


@contextlib.contextmanager
def to_logs_dataset(
    table: pyarrow.Table, pass_as_path: bool
) -> Iterator[Union[pyarrow.Table, Path]]:
    """Yield the logs either as the in-memory Arrow table (pass_as_path
    False) or as the path of a temporary parquet dataset written from it
    (pass_as_path True). The temp directory is removed on exit.
    """
    if not pass_as_path:
        yield table
        return

    # Fix: `import pyarrow` alone does not load the `pyarrow.parquet`
    # submodule; import it explicitly instead of relying on some other
    # module having imported it first.
    import pyarrow.parquet

    with tempfile.TemporaryDirectory() as temp_dir:
        pyarrow.parquet.write_to_dataset(table, root_path=temp_dir)
        yield Path(temp_dir)
12 changes: 12 additions & 0 deletions sdk/python/tests/utils/http_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import socket
from contextlib import closing

def free_port() -> int:
    """Return a TCP port number that was free at the time of the call.

    Fix: the original never closed the probe socket, relying on CPython
    refcount GC to release the bound port; close it deterministically
    (via `closing`, already imported in this module) so the port is
    actually available for the caller to bind.
    """
    with closing(socket.socket()) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def check_port_open(host, port) -> bool:
    """Report whether a TCP connection to ``host:port`` can be opened."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        connected = sock.connect_ex((host, port)) == 0
    return connected