Skip to content

Commit 17ad500

Browse files
lokeshrangineni authored and redhatHameed committed
Added the arrow flight interceptor to inject the auth header. (feast-dev#68)
* * Added the arrow flight interceptor to inject the auth header. * Injecting grpc interceptor if it is needed when auth type is not NO_AUTH. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> * Fixing the failing integration test cases by setting the header in binary format. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> * Refactored method and moved to factory class to incorporate code review comment. Fixed lint error by removing the type of port, and other minor changes. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> * Incorporating code review comments from Daniel. Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> --------- Signed-off-by: Lokesh Rangineni <19699092+lokeshrangineni@users.noreply.github.com> Signed-off-by: Abdul Hameed <ahameed@redhat.com>
1 parent a1af4a1 commit 17ad500

9 files changed

Lines changed: 106 additions & 113 deletions

File tree

sdk/python/feast/infra/offline_stores/remote.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
RetrievalMetadata,
2828
)
2929
from feast.infra.registry.base_registry import BaseRegistry
30-
from feast.permissions.client.utils import create_flight_call_options
30+
from feast.permissions.client.arrow_flight_auth_interceptor import (
31+
build_arrow_flight_client,
32+
)
3133
from feast.repo_config import FeastConfigBaseModel, RepoConfig
3234
from feast.saved_dataset import SavedDatasetStorage
3335

@@ -47,7 +49,6 @@ class RemoteRetrievalJob(RetrievalJob):
4749
def __init__(
4850
self,
4951
client: fl.FlightClient,
50-
options: pa.flight.FlightCallOptions,
5152
api: str,
5253
api_parameters: Dict[str, Any],
5354
entity_df: Union[pd.DataFrame, str] = None,
@@ -56,7 +57,6 @@ def __init__(
5657
):
5758
# Initialize the client connection
5859
self.client = client
59-
self.options = options
6060
self.api = api
6161
self.api_parameters = api_parameters
6262
self.entity_df = entity_df
@@ -77,7 +77,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
7777
self.entity_df,
7878
self.table,
7979
self.client,
80-
self.options,
8180
)
8281

8382
@property
@@ -118,7 +117,6 @@ def persist(
118117
api=RemoteRetrievalJob.persist.__name__,
119118
api_parameters=api_parameters,
120119
client=self.client,
121-
options=self.options,
122120
table=self.table,
123121
entity_df=self.entity_df,
124122
)
@@ -137,9 +135,9 @@ def get_historical_features(
137135
) -> RemoteRetrievalJob:
138136
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
139137

140-
# Initialize the client connection
141-
client = RemoteOfflineStore.init_client(config)
142-
options = create_flight_call_options(config.auth_config)
138+
client = build_arrow_flight_client(
139+
config.offline_store.host, config.offline_store.port, config.auth_config
140+
)
143141

144142
feature_view_names = [fv.name for fv in feature_views]
145143
name_aliases = [fv.projection.name_alias for fv in feature_views]
@@ -154,7 +152,6 @@ def get_historical_features(
154152

155153
return RemoteRetrievalJob(
156154
client=client,
157-
options=options,
158155
api=OfflineStore.get_historical_features.__name__,
159156
api_parameters=api_parameters,
160157
entity_df=entity_df,
@@ -174,8 +171,9 @@ def pull_all_from_table_or_query(
174171
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
175172

176173
# Initialize the client connection
177-
client = RemoteOfflineStore.init_client(config)
178-
options = create_flight_call_options(config.auth_config)
174+
client = build_arrow_flight_client(
175+
config.offline_store.host, config.offline_store.port, config.auth_config
176+
)
179177

180178
api_parameters = {
181179
"data_source_name": data_source.name,
@@ -188,7 +186,6 @@ def pull_all_from_table_or_query(
188186

189187
return RemoteRetrievalJob(
190188
client=client,
191-
options=options,
192189
api=OfflineStore.pull_all_from_table_or_query.__name__,
193190
api_parameters=api_parameters,
194191
)
@@ -207,8 +204,9 @@ def pull_latest_from_table_or_query(
207204
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
208205

209206
# Initialize the client connection
210-
client = RemoteOfflineStore.init_client(config)
211-
options = create_flight_call_options(config.auth_config)
207+
client = build_arrow_flight_client(
208+
config.offline_store.host, config.offline_store.port, config.auth_config
209+
)
212210

213211
api_parameters = {
214212
"data_source_name": data_source.name,
@@ -222,7 +220,6 @@ def pull_latest_from_table_or_query(
222220

223221
return RemoteRetrievalJob(
224222
client=client,
225-
options=options,
226223
api=OfflineStore.pull_latest_from_table_or_query.__name__,
227224
api_parameters=api_parameters,
228225
)
@@ -242,8 +239,9 @@ def write_logged_features(
242239
data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)
243240

244241
# Initialize the client connection
245-
client = RemoteOfflineStore.init_client(config)
246-
options = create_flight_call_options(config.auth_config)
242+
client = build_arrow_flight_client(
243+
config.offline_store.host, config.offline_store.port, config.auth_config
244+
)
247245

248246
api_parameters = {
249247
"feature_service_name": source._feature_service.name,
@@ -253,7 +251,6 @@ def write_logged_features(
253251
api=OfflineStore.write_logged_features.__name__,
254252
api_parameters=api_parameters,
255253
client=client,
256-
options=options,
257254
table=data,
258255
entity_df=None,
259256
)
@@ -268,8 +265,9 @@ def offline_write_batch(
268265
assert isinstance(config.offline_store, RemoteOfflineStoreConfig)
269266

270267
# Initialize the client connection
271-
client = RemoteOfflineStore.init_client(config)
272-
options = create_flight_call_options(config.auth_config)
268+
client = build_arrow_flight_client(
269+
config.offline_store.host, config.offline_store.port, config.auth_config
270+
)
273271

274272
feature_view_names = [feature_view.name]
275273
name_aliases = [feature_view.projection.name_alias]
@@ -284,18 +282,10 @@ def offline_write_batch(
284282
api=OfflineStore.offline_write_batch.__name__,
285283
api_parameters=api_parameters,
286284
client=client,
287-
options=options,
288285
table=table,
289286
entity_df=None,
290287
)
291288

292-
@staticmethod
293-
def init_client(config):
294-
location = f"grpc://{config.offline_store.host}:{config.offline_store.port}"
295-
client = fl.connect(location=location)
296-
logger.info(f"Connecting FlightClient at {location}")
297-
return client
298-
299289

300290
def _create_retrieval_metadata(feature_refs: List[str], entity_df: pd.DataFrame):
301291
entity_schema = _get_entity_schema(
@@ -349,35 +339,31 @@ def _send_retrieve_remote(
349339
entity_df: Union[pd.DataFrame, str],
350340
table: pa.Table,
351341
client: fl.FlightClient,
352-
options: pa.flight.FlightCallOptions,
353342
):
354343
command_descriptor = _call_put(
355344
api,
356345
api_parameters,
357346
client,
358-
options,
359347
entity_df,
360348
table,
361349
)
362-
return _call_get(client, options, command_descriptor)
350+
return _call_get(client, command_descriptor)
363351

364352

365353
def _call_get(
366354
client: fl.FlightClient,
367-
options: pa.flight.FlightCallOptions,
368355
command_descriptor: fl.FlightDescriptor,
369356
):
370-
flight = client.get_flight_info(command_descriptor, options)
357+
flight = client.get_flight_info(command_descriptor)
371358
ticket = flight.endpoints[0].ticket
372-
reader = client.do_get(ticket, options)
359+
reader = client.do_get(ticket)
373360
return reader.read_all()
374361

375362

376363
def _call_put(
377364
api: str,
378365
api_parameters: Dict[str, Any],
379366
client: fl.FlightClient,
380-
options: pa.flight.FlightCallOptions,
381367
entity_df: Union[pd.DataFrame, str],
382368
table: pa.Table,
383369
):
@@ -397,7 +383,7 @@ def _call_put(
397383
)
398384
)
399385

400-
_put_parameters(command_descriptor, entity_df, table, client, options)
386+
_put_parameters(command_descriptor, entity_df, table, client)
401387
return command_descriptor
402388

403389

@@ -406,7 +392,6 @@ def _put_parameters(
406392
entity_df: Union[pd.DataFrame, str],
407393
table: pa.Table,
408394
client: fl.FlightClient,
409-
options: pa.flight.FlightCallOptions,
410395
):
411396
updatedTable: pa.Table
412397

@@ -417,7 +402,7 @@ def _put_parameters(
417402
else:
418403
updatedTable = _create_empty_table()
419404

420-
writer, _ = client.do_put(command_descriptor, updatedTable.schema, options)
405+
writer, _ = client.do_put(command_descriptor, updatedTable.schema)
421406

422407
writer.write_table(updatedTable)
423408
writer.close()

sdk/python/feast/infra/registry/remote.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from feast.infra.infra_object import Infra
1616
from feast.infra.registry.base_registry import BaseRegistry
1717
from feast.on_demand_feature_view import OnDemandFeatureView
18+
from feast.permissions.auth.auth_type import AuthType
1819
from feast.permissions.auth_model import (
1920
AuthConfig,
2021
NoAuthConfig,
@@ -48,13 +49,12 @@ def __init__(
4849
repo_path: Optional[Path],
4950
auth_config: AuthConfig = NoAuthConfig(),
5051
):
51-
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
5252
self.auth_config = auth_config
5353
channel = grpc.insecure_channel(registry_config.path)
54-
self.intercepted_channel = grpc.intercept_channel(
55-
channel, auth_header_interceptor
56-
)
57-
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.intercepted_channel)
54+
if self.auth_config.type != AuthType.NONE.value:
55+
auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config)
56+
channel = grpc.intercept_channel(channel, auth_header_interceptor)
57+
self.stub = RegistryServer_pb2_grpc.RegistryServerStub(channel)
5858

5959
def apply_entity(self, entity: Entity, project: str, commit: bool = True):
6060
request = RegistryServer_pb2.ApplyEntityRequest(
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import pyarrow.flight as fl
2+
3+
from feast.permissions.auth.auth_type import AuthType
4+
from feast.permissions.auth_model import AuthConfig
5+
from feast.permissions.client.auth_client_manager_factory import get_auth_token
6+
7+
8+
class FlightBearerTokenInterceptor(fl.ClientMiddleware):
9+
def __init__(self, auth_config: AuthConfig):
10+
super().__init__()
11+
self.auth_config = auth_config
12+
13+
def call_completed(self, exception):
14+
pass
15+
16+
def received_headers(self, headers):
17+
pass
18+
19+
def sending_headers(self):
20+
access_token = get_auth_token(self.auth_config)
21+
return {b"authorization": b"Bearer " + access_token.encode("utf-8")}
22+
23+
24+
class FlightAuthInterceptorFactory(fl.ClientMiddlewareFactory):
25+
def __init__(self, auth_config: AuthConfig):
26+
super().__init__()
27+
self.auth_config = auth_config
28+
29+
def start_call(self, info):
30+
return FlightBearerTokenInterceptor(self.auth_config)
31+
32+
33+
def build_arrow_flight_client(host: str, port, auth_config: AuthConfig):
34+
if auth_config.type != AuthType.NONE.value:
35+
middleware_factory = FlightAuthInterceptorFactory(auth_config)
36+
return fl.FlightClient(f"grpc://{host}:{port}", middleware=[middleware_factory])
37+
else:
38+
return fl.FlightClient(f"grpc://{host}:{port}")
Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,8 @@
11
from abc import ABC, abstractmethod
22

3-
from feast.permissions.auth.auth_type import AuthType
4-
from feast.permissions.auth_model import (
5-
AuthConfig,
6-
KubernetesAuthConfig,
7-
OidcAuthConfig,
8-
)
9-
103

114
class AuthenticationClientManager(ABC):
125
@abstractmethod
136
def get_token(self) -> str:
147
"""Retrieves the token based on the authentication type configuration"""
158
pass
16-
17-
18-
def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager:
19-
if auth_config.type == AuthType.OIDC.value:
20-
assert isinstance(auth_config, OidcAuthConfig)
21-
22-
from feast.permissions.client.oidc_authentication_client_manager import (
23-
OidcAuthClientManager,
24-
)
25-
26-
return OidcAuthClientManager(auth_config)
27-
elif auth_config.type == AuthType.KUBERNETES.value:
28-
assert isinstance(auth_config, KubernetesAuthConfig)
29-
30-
from feast.permissions.client.kubernetes_auth_client_manager import (
31-
KubernetesAuthClientManager,
32-
)
33-
34-
return KubernetesAuthClientManager(auth_config)
35-
else:
36-
raise RuntimeError(
37-
f"No Auth client manager implemented for the auth type:${auth_config.type}"
38-
)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from feast.permissions.auth.auth_type import AuthType
2+
from feast.permissions.auth_model import (
3+
AuthConfig,
4+
KubernetesAuthConfig,
5+
OidcAuthConfig,
6+
)
7+
from feast.permissions.client.auth_client_manager import AuthenticationClientManager
8+
from feast.permissions.client.kubernetes_auth_client_manager import (
9+
KubernetesAuthClientManager,
10+
)
11+
from feast.permissions.client.oidc_authentication_client_manager import (
12+
OidcAuthClientManager,
13+
)
14+
15+
16+
def get_auth_client_manager(auth_config: AuthConfig) -> AuthenticationClientManager:
17+
if auth_config.type == AuthType.OIDC.value:
18+
assert isinstance(auth_config, OidcAuthConfig)
19+
return OidcAuthClientManager(auth_config)
20+
elif auth_config.type == AuthType.KUBERNETES.value:
21+
assert isinstance(auth_config, KubernetesAuthConfig)
22+
return KubernetesAuthClientManager(auth_config)
23+
else:
24+
raise RuntimeError(
25+
f"No Auth client manager implemented for the auth type:${auth_config.type}"
26+
)
27+
28+
29+
def get_auth_token(auth_config: AuthConfig) -> str:
30+
return get_auth_client_manager(auth_config).get_token()

sdk/python/feast/permissions/client/grpc_client_auth_interceptor.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import grpc
44

5-
from feast.permissions.auth.auth_type import AuthType
65
from feast.permissions.auth_model import AuthConfig
7-
from feast.permissions.client.auth_client_manager import get_auth_client_manager
6+
from feast.permissions.client.auth_client_manager_factory import get_auth_token
87

98
logger = logging.getLogger(__name__)
109

@@ -43,16 +42,11 @@ def intercept_stream_stream(
4342
return continuation(client_call_details, request_iterator)
4443

4544
def _append_auth_header_metadata(self, client_call_details):
46-
if self._auth_type.type is not AuthType.NONE.value:
47-
logger.info(
48-
f"Intercepted the grpc api method {client_call_details.method} call to inject Authorization header "
49-
f"token. "
50-
)
51-
metadata = client_call_details.metadata or []
52-
auth_client_manager = get_auth_client_manager(self._auth_type)
53-
access_token = auth_client_manager.get_token()
54-
metadata.append(
55-
(b"authorization", b"Bearer " + access_token.encode("utf-8"))
56-
)
57-
client_call_details = client_call_details._replace(metadata=metadata)
45+
logger.debug(
46+
"Intercepted the grpc api method call to inject Authorization header "
47+
)
48+
metadata = client_call_details.metadata or []
49+
access_token = get_auth_token(self._auth_type)
50+
metadata.append((b"authorization", b"Bearer " + access_token.encode("utf-8")))
51+
client_call_details = client_call_details._replace(metadata=metadata)
5852
return client_call_details

0 commit comments

Comments
 (0)