Skip to content

Commit 5390d7b

Browse files
committed
fix: Adopt connection pooling for HBase
Signed-off-by: Hai Nguyen <quanghai.ng1512@gmail.com>
1 parent 2192e65 commit 5390d7b

File tree

2 files changed

+87
-51
lines changed

2 files changed

+87
-51
lines changed

sdk/python/feast/infra/online_stores/contrib/hbase_online_store/hbase.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime
44
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
55

6-
from happybase import Connection
6+
from happybase import Connection, ConnectionPool
77
from pydantic.typing import Literal
88

99
from feast import Entity
@@ -29,6 +29,9 @@ class HbaseOnlineStoreConfig(FeastConfigBaseModel):
2929
port: str
3030
"""Port in which Hbase Thrift server is running"""
3131

32+
connection_pool_size: int = 4
33+
"""Number of connections to Hbase Thrift server to keep in the connection pool"""
34+
3235

3336
class HbaseConnection:
3437
"""
@@ -62,7 +65,7 @@ class HbaseOnlineStore(OnlineStore):
6265
_conn: Happybase Connection to connect to hbase thrift server.
6366
"""
6467

65-
_conn: Connection = None
68+
_conn: ConnectionPool = None
6669

6770
def _get_conn(self, config: RepoConfig):
6871
"""
@@ -76,7 +79,11 @@ def _get_conn(self, config: RepoConfig):
7679
assert isinstance(store_config, HbaseOnlineStoreConfig)
7780

7881
if not self._conn:
79-
self._conn = Connection(host=store_config.host, port=int(store_config.port))
82+
self._conn = ConnectionPool(
83+
host=store_config.host,
84+
port=int(store_config.port),
85+
size=int(store_config.connection_pool_size),
86+
)
8087
return self._conn
8188

8289
@log_exceptions_and_usage(online_store="hbase")

sdk/python/feast/infra/utils/hbase_utils.py

Lines changed: 77 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from typing import List
22

3-
from happybase import Connection
4-
5-
from feast.infra.key_encoding_utils import serialize_entity_key
6-
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
3+
from happybase import ConnectionPool
74

85

96
class HbaseConstants:
@@ -40,14 +37,22 @@ class HbaseUtils:
4037
"""
4138

4239
def __init__(
43-
self, conn: Connection = None, host: str = None, port: int = None, timeout=None
40+
self,
41+
pool: ConnectionPool = None,
42+
host: str = None,
43+
port: int = None,
44+
connection_pool_size: int = 4,
4445
):
45-
if conn is None:
46+
if pool is None:
4647
self.host = host
4748
self.port = port
48-
self.conn = Connection(host=host, port=port, timeout=timeout)
49+
self.pool = ConnectionPool(
50+
host=host,
51+
port=port,
52+
size=connection_pool_size,
53+
)
4954
else:
50-
self.conn = conn
55+
self.pool = pool
5156

5257
def create_table(self, table_name: str, colm_family: List[str]):
5358
"""
@@ -60,7 +65,9 @@ def create_table(self, table_name: str, colm_family: List[str]):
6065
cf_dict: dict = {}
6166
for cf in colm_family:
6267
cf_dict[cf] = dict()
63-
return self.conn.create_table(table_name, cf_dict)
68+
69+
with self.pool.connection() as conn:
70+
return conn.create_table(table_name, cf_dict)
6471

6572
def create_table_with_default_cf(self, table_name: str):
6673
"""
@@ -69,7 +76,8 @@ def create_table_with_default_cf(self, table_name: str):
6976
Arguments:
7077
table_name: Name of the Hbase table.
7178
"""
72-
return self.conn.create_table(table_name, {"default": dict()})
79+
with self.pool.connection() as conn:
80+
return conn.create_table(table_name, {"default": dict()})
7381

7482
def check_if_table_exist(self, table_name: str):
7583
"""
@@ -78,16 +86,18 @@ def check_if_table_exist(self, table_name: str):
7886
Arguments:
7987
table_name: Name of the Hbase table.
8088
"""
81-
return bytes(table_name, "utf-8") in self.conn.tables()
89+
with self.pool.connection() as conn:
90+
return bytes(table_name, "utf-8") in conn.tables()
8291

8392
def batch(self, table_name: str):
8493
"""
85-
Returns a 'Batch' instance that can be used for mass data manipulation in the hbase table.
94+
Returns a "Batch" instance that can be used for mass data manipulation in the hbase table.
8695
8796
Arguments:
8897
table_name: Name of the Hbase table.
8998
"""
90-
return self.conn.table(table_name).batch()
99+
with self.pool.connection() as conn:
100+
return conn.table(table_name).batch()
91101

92102
def put(self, table_name: str, row_key: str, data: dict):
93103
"""
@@ -98,8 +108,9 @@ def put(self, table_name: str, row_key: str, data: dict):
98108
row_key: Row key of the row to be inserted to hbase table.
99109
data: Mapping of column family name:column name to column values
100110
"""
101-
table = self.conn.table(table_name)
102-
table.put(row_key, data)
111+
with self.pool.connection() as conn:
112+
table = conn.table(table_name)
113+
table.put(row_key, data)
103114

104115
def row(
105116
self,
@@ -119,8 +130,9 @@ def row(
119130
timestamp: timestamp specifies the maximum version the cells can have.
120131
include_timestamp: specifies if (column, timestamp) to be return instead of only column.
121132
"""
122-
table = self.conn.table(table_name)
123-
return table.row(row_key, columns, timestamp, include_timestamp)
133+
with self.pool.connection() as conn:
134+
table = conn.table(table_name)
135+
return table.row(row_key, columns, timestamp, include_timestamp)
124136

125137
def rows(
126138
self,
@@ -140,52 +152,69 @@ def rows(
140152
timestamp: timestamp specifies the maximum version the cells can have.
141153
include_timestamp: specifies if (column, timestamp) to be return instead of only column.
142154
"""
143-
table = self.conn.table(table_name)
144-
return table.rows(row_keys, columns, timestamp, include_timestamp)
155+
with self.pool.connection() as conn:
156+
table = conn.table(table_name)
157+
return table.rows(row_keys, columns, timestamp, include_timestamp)
145158

146159
def print_table(self, table_name):
147160
"""Prints the table scanning all the rows of the hbase table."""
148-
table = self.conn.table(table_name)
149-
scan_data = table.scan()
150-
for row_key, cols in scan_data:
151-
print(row_key.decode("utf-8"), cols)
161+
with self.pool.connection() as conn:
162+
table = conn.table(table_name)
163+
scan_data = table.scan()
164+
for row_key, cols in scan_data:
165+
print(row_key.decode("utf-8"), cols)
152166

153167
def delete_table(self, table: str):
154168
"""Deletes the hbase table given the table name."""
155169
if self.check_if_table_exist(table):
156-
self.conn.delete_table(table, disable=True)
170+
with self.pool.connection() as conn:
171+
conn.delete_table(table, disable=True)
157172

158173
def close_conn(self):
159174
"""Closes the happybase connection."""
160-
self.conn.close()
175+
with self.pool.connection() as conn:
176+
conn.close()
161177

162178

163179
def main():
180+
from feast.infra.key_encoding_utils import serialize_entity_key
181+
from feast.protos.feast.types.EntityKey_pb2 import EntityKey
164182
from feast.protos.feast.types.Value_pb2 import Value
165183

166-
connection = Connection(host="localhost", port=9090)
167-
table = connection.table("test_hbase_driver_hourly_stats")
168-
row_keys = [
169-
serialize_entity_key(
170-
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]),
171-
entity_key_serialization_version=2,
172-
).hex(),
173-
serialize_entity_key(
174-
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]),
175-
entity_key_serialization_version=2,
176-
).hex(),
177-
serialize_entity_key(
178-
EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]),
179-
entity_key_serialization_version=2,
180-
).hex(),
181-
]
182-
rows = table.rows(row_keys)
183-
184-
for row_key, row in rows:
185-
for key, value in row.items():
186-
col_name = bytes.decode(key, "utf-8").split(":")[1]
187-
print(col_name, value)
188-
print()
184+
pool = ConnectionPool(
185+
host="localhost",
186+
port=9090,
187+
size=2,
188+
)
189+
with pool.connection() as connection:
190+
table = connection.table("test_hbase_driver_hourly_stats")
191+
row_keys = [
192+
serialize_entity_key(
193+
EntityKey(
194+
join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]
195+
),
196+
entity_key_serialization_version=2,
197+
).hex(),
198+
serialize_entity_key(
199+
EntityKey(
200+
join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]
201+
),
202+
entity_key_serialization_version=2,
203+
).hex(),
204+
serialize_entity_key(
205+
EntityKey(
206+
join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]
207+
),
208+
entity_key_serialization_version=2,
209+
).hex(),
210+
]
211+
rows = table.rows(row_keys)
212+
213+
for _, row in rows:
214+
for key, value in row.items():
215+
col_name = bytes.decode(key, "utf-8").split(":")[1]
216+
print(col_name, value)
217+
print()
189218

190219

191220
if __name__ == "__main__":

0 commit comments

Comments
 (0)