Skip to content
Merged
Prev Previous commit
Next Next commit
fix tests
Signed-off-by: Achal Shah <achals@gmail.com>
  • Loading branch information
achals committed Jul 19, 2022
commit 7ab893cdecd88379753d7870cf347886ccd005c5
25 changes: 21 additions & 4 deletions go/internal/feast/onlinestore/redisonlinestore_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package onlinestore

import (
"github.com/feast-dev/feast/go/internal/feast/registry"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -10,7 +11,11 @@ func TestNewRedisOnlineStore(t *testing.T) {
var config = map[string]interface{}{
"connection_string": "redis://localhost:6379",
}
store, err := NewRedisOnlineStore("test", config)
rc := &registry.RepoConfig{
OnlineStore: config,
EntityKeySerializationVersion: 2,
}
store, err := NewRedisOnlineStore("test", rc, config)
assert.Nil(t, err)
var opts = store.client.Options()
assert.Equal(t, opts.Addr, "redis://localhost:6379")
Expand All @@ -23,7 +28,11 @@ func TestNewRedisOnlineStoreWithPassword(t *testing.T) {
var config = map[string]interface{}{
"connection_string": "redis://localhost:6379,password=secret",
}
store, err := NewRedisOnlineStore("test", config)
rc := &registry.RepoConfig{
OnlineStore: config,
EntityKeySerializationVersion: 2,
}
store, err := NewRedisOnlineStore("test", rc, config)
assert.Nil(t, err)
var opts = store.client.Options()
assert.Equal(t, opts.Addr, "redis://localhost:6379")
Expand All @@ -34,7 +43,11 @@ func TestNewRedisOnlineStoreWithDB(t *testing.T) {
var config = map[string]interface{}{
"connection_string": "redis://localhost:6379,db=1",
}
store, err := NewRedisOnlineStore("test", config)
rc := &registry.RepoConfig{
OnlineStore: config,
EntityKeySerializationVersion: 2,
}
store, err := NewRedisOnlineStore("test", rc, config)
assert.Nil(t, err)
var opts = store.client.Options()
assert.Equal(t, opts.Addr, "redis://localhost:6379")
Expand All @@ -45,7 +58,11 @@ func TestNewRedisOnlineStoreWithSsl(t *testing.T) {
var config = map[string]interface{}{
"connection_string": "redis://localhost:6379,ssl=true",
}
store, err := NewRedisOnlineStore("test", config)
rc := &registry.RepoConfig{
OnlineStore: config,
EntityKeySerializationVersion: 2,
}
store, err := NewRedisOnlineStore("test", rc, config)
assert.Nil(t, err)
var opts = store.client.Options()
assert.Equal(t, opts.Addr, "redis://localhost:6379")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ def online_write_batch(

b = hbase.batch(table_name)
for entity_key, values, timestamp, created_ts in data:
row_key = serialize_entity_key(entity_key,
entity_key_serialization_version=config.entity_key_serialization_version).hex()
row_key = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
values_dict = {}
for feature_name, val in values.items():
values_dict[
Expand Down Expand Up @@ -155,8 +157,10 @@ def online_read(
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

row_keys = [
serialize_entity_key(entity_key,
entity_key_serialization_version=config.entity_key_serialization_version).hex()
serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for entity_key in entity_keys
]
rows = hbase.rows(table_name, row_keys=row_keys)
Expand Down
14 changes: 9 additions & 5 deletions sdk/python/feast/infra/online_stores/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,14 @@ def _write_minibatch(
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
config: RepoConfig
config: RepoConfig,
):
entities = []
for entity_key, features, timestamp, created_ts in data:
document_id = compute_entity_id(entity_key,
entity_key_serialization_version=config.entity_key_serialization_version)
document_id = compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)

key = client.key(
"Project", project, "Table", table.name, "Row", document_id,
Expand Down Expand Up @@ -243,8 +245,10 @@ def online_read(
keys: List[Key] = []
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
for entity_key in entity_keys:
document_id = compute_entity_id(entity_key,
entity_key_serialization_version=config.entity_key_serialization_version)
document_id = compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
key = client.key(
"Project", feast_project, "Table", table.name, "Row", document_id
)
Expand Down
16 changes: 11 additions & 5 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,13 @@ def online_read(
)

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
entity_ids = [compute_entity_id(entity_key,
entity_key_serialization_version=config.entity_key_serialization_version)
for entity_key in entity_keys]
entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids_iter = iter(entity_ids)
while True:
Expand Down Expand Up @@ -307,8 +311,10 @@ def _write_batch_non_duplicates(
"""Deduplicate write batch request items on ``entity_id`` primary key."""
with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
for entity_key, features, timestamp, created_ts in data:
entity_id = compute_entity_id(entity_key,
entity_key_serialization_version=config.entity_key_serialization_version)
entity_id = compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
batch.put_item(
Item={
"entity_id": entity_id, # PartitionKey
Expand Down
14 changes: 10 additions & 4 deletions sdk/python/feast/infra/online_stores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ def online_write_batch(
# TODO: investigate if check and set is a better approach rather than pulling all entity ts and then setting
# it may be significantly slower but avoids potential (rare) race conditions
for entity_key, _, _, _ in data:
redis_key_bin = _redis_key(project, entity_key,
entity_key_serialization_version=config.entity_key_serialization_version)
redis_key_bin = _redis_key(
project,
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
keys.append(redis_key_bin)
pipe.hmget(redis_key_bin, ts_key)
prev_event_timestamps = pipe.execute()
Expand Down Expand Up @@ -269,8 +272,11 @@ def online_read(

keys = []
for entity_key in entity_keys:
redis_key_bin = _redis_key(project, entity_key,
entity_key_serialization_version=config.entity_key_serialization_version)
redis_key_bin = _redis_key(
project,
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
keys.append(redis_key_bin)
with client.pipeline(transaction=False) as pipe:
for redis_key_bin in keys:
Expand Down
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/utils/hbase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def main():
row_keys = [
serialize_entity_key(
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]),
entity_key_serialization_version=2
entity_key_serialization_version=2,
).hex(),
serialize_entity_key(
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]),
entity_key_serialization_version=2
entity_key_serialization_version=2,
).hex(),
serialize_entity_key(
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]),
entity_key_serialization_version=2
entity_key_serialization_version=2,
).hex(),
]
rows = table.rows(row_keys)
Expand Down