Skip to content

Commit fda56ad

Browse files
authored
Add project whitelist (#57)
* Add project whitelist Signed-off-by: Terence Lim <terencelimxp@gmail.com> * Shift project whitelist logic to jobservice Signed-off-by: Terence Lim <terencelimxp@gmail.com> * Fix tests Signed-off-by: Terence Lim <terencelimxp@gmail.com> * Add whitelist project tests Signed-off-by: Terence Lim <terencelimxp@gmail.com> * Fix flaky test Signed-off-by: Terence Lim <terencelimxp@gmail.com> * Use config instead of js client Signed-off-by: Terence Lim <terencelimxp@gmail.com> * Remove unnecessary code Signed-off-by: Terence Lim <terencelimxp@gmail.com>
1 parent 5730cf4 commit fda56ad

File tree

4 files changed

+90
-26
lines changed

4 files changed

+90
-26
lines changed

python/feast_spark/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ class ConfigOptions(metaclass=ConfigMeta):
157157
#: Log path of EMR cluster
158158
EMR_LOG_LOCATION: Optional[str] = None
159159

160+
#: Whitelisted Feast projects
161+
WHITELISTED_PROJECTS: Optional[str] = None
162+
160163
def defaults(self):
161164
return {
162165
k: getattr(self, k)

python/feast_spark/job_service.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import traceback
77
from concurrent.futures import ThreadPoolExecutor
8-
from typing import Dict, List, Tuple, cast
8+
from typing import Dict, List, Optional, Tuple, cast
99

1010
import grpc
1111
from google.api_core.exceptions import FailedPrecondition
@@ -98,10 +98,29 @@ class JobServiceServicer(JobService_pb2_grpc.JobServiceServicer):
9898
def __init__(self, client: Client):
9999
self.client = client
100100

101+
@property
102+
def _whitelisted_projects(self) -> Optional[List[str]]:
103+
if self.client.config.exists(opt.WHITELISTED_PROJECTS):
104+
whitelisted_projects = self.client.config.get(opt.WHITELISTED_PROJECTS)
105+
return whitelisted_projects.split(",")
106+
return None
107+
108+
def is_whitelisted(self, project: str):
109+
# Whitelisted projects not specified, allow all projects
110+
if not self._whitelisted_projects:
111+
return True
112+
return project in self._whitelisted_projects
113+
101114
def StartOfflineToOnlineIngestionJob(
102115
self, request: StartOfflineToOnlineIngestionJobRequest, context
103116
):
104117
"""Start job to ingest data from offline store into online store"""
118+
119+
if not self.is_whitelisted(request.project):
120+
raise ValueError(
121+
f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
122+
)
123+
105124
feature_table = self.client.feature_store.get_feature_table(
106125
request.table_name, request.project
107126
)
@@ -125,6 +144,12 @@ def StartOfflineToOnlineIngestionJob(
125144

126145
def GetHistoricalFeatures(self, request: GetHistoricalFeaturesRequest, context):
127146
"""Produce a training dataset, return a job id that will provide a file reference"""
147+
148+
if not self.is_whitelisted(request.project):
149+
raise ValueError(
150+
f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
151+
)
152+
128153
job = start_historical_feature_retrieval_job(
129154
client=self.client,
130155
project=request.project,
@@ -152,6 +177,11 @@ def StartStreamToOnlineIngestionJob(
152177
):
153178
"""Start job to ingest data from stream into online store"""
154179

180+
if not self.is_whitelisted(request.project):
181+
raise ValueError(
182+
f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
183+
)
184+
155185
feature_table = self.client.feature_store.get_feature_table(
156186
request.table_name, request.project
157187
)
@@ -196,6 +226,12 @@ def StartStreamToOnlineIngestionJob(
196226

197227
def ListJobs(self, request, context):
198228
"""List all types of jobs"""
229+
230+
if not self.is_whitelisted(request.project):
231+
raise ValueError(
232+
f"Project {request.project} is not whitelisted. Please contact your Feast administrator to whitelist it."
233+
)
234+
199235
jobs = list_jobs(
200236
include_terminated=request.include_terminated,
201237
project=request.project,
@@ -326,6 +362,13 @@ def ensure_stream_ingestion_jobs(client: Client, all_projects: bool):
326362
if all_projects
327363
else [client.feature_store.project]
328364
)
365+
if client.config.exists(opt.WHITELISTED_PROJECTS):
366+
whitelisted_projects = client.config.get(opt.WHITELISTED_PROJECTS)
367+
if whitelisted_projects:
368+
whitelisted_projects = whitelisted_projects.split(",")
369+
projects = [
370+
project for project in projects if project in whitelisted_projects
371+
]
329372

330373
expected_job_hash_to_tables = _get_expected_job_hash_to_tables(client, projects)
331374

python/tests/test_streaming_job_scheduling.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020
@pytest.fixture
2121
def feast_client():
22-
c = FeastClient(job_service_pause_between_jobs=0)
23-
c.list_projects = Mock(return_value=["default"])
22+
c = FeastClient(
23+
job_service_pause_between_jobs=0,
24+
options={"whitelisted_projects": "default,ride"},
25+
)
26+
c.list_projects = Mock(return_value=["default", "ride", "invalid_project"])
2427
c.list_feature_tables = Mock()
2528

2629
yield c
@@ -51,15 +54,18 @@ def feature_table():
5154

5255

5356
class SimpleStreamingIngestionJob(StreamIngestionJob):
54-
def __init__(self, id: str, feature_table: FeatureTable, status: SparkJobStatus):
57+
def __init__(
58+
self, id: str, project: str, feature_table: FeatureTable, status: SparkJobStatus
59+
):
5560
self._id = id
5661
self._feature_table = feature_table
62+
self._project = project
5763
self._status = status
5864
self._hash = hash
5965

6066
def get_hash(self) -> str:
6167
source = _source_to_argument(self._feature_table.stream_source, Config())
62-
feature_table = _feature_table_to_argument(None, "default", self._feature_table) # type: ignore
68+
feature_table = _feature_table_to_argument(None, self._project, self._feature_table) # type: ignore
6369

6470
job_json = json.dumps(
6571
{"source": source, "feature_table": feature_table}, sort_keys=True,
@@ -90,18 +96,21 @@ def test_new_job_creation(spark_client, feature_table):
9096

9197
ensure_stream_ingestion_jobs(spark_client, all_projects=True)
9298

93-
spark_client.start_stream_to_online_ingestion.assert_called_once_with(
94-
feature_table, [], project="default"
95-
)
99+
assert spark_client.start_stream_to_online_ingestion.call_count == 2
96100

97101

98102
def test_no_changes(spark_client, feature_table):
99103
""" Feature Table spec is the same """
100104

101-
job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.IN_PROGRESS)
105+
job = SimpleStreamingIngestionJob(
106+
"", "default", feature_table, SparkJobStatus.IN_PROGRESS
107+
)
108+
job2 = SimpleStreamingIngestionJob(
109+
"", "ride", feature_table, SparkJobStatus.IN_PROGRESS
110+
)
102111

103112
spark_client.feature_store.list_feature_tables.return_value = [feature_table]
104-
spark_client.list_jobs.return_value = [job]
113+
spark_client.list_jobs.return_value = [job, job2]
105114

106115
ensure_stream_ingestion_jobs(spark_client, all_projects=True)
107116

@@ -114,41 +123,43 @@ def test_update_existing_job(spark_client, feature_table):
114123

115124
new_ft = copy.deepcopy(feature_table)
116125
new_ft.stream_source._kafka_options.topic = "new_t"
117-
job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.IN_PROGRESS)
126+
job = SimpleStreamingIngestionJob(
127+
"", "default", feature_table, SparkJobStatus.IN_PROGRESS
128+
)
118129

119130
spark_client.feature_store.list_feature_tables.return_value = [new_ft]
120131
spark_client.list_jobs.return_value = [job]
121132

122133
ensure_stream_ingestion_jobs(spark_client, all_projects=True)
123134

124135
assert job.get_status() == SparkJobStatus.COMPLETED
125-
spark_client.start_stream_to_online_ingestion.assert_called_once_with(
126-
new_ft, [], project="default"
127-
)
136+
assert spark_client.start_stream_to_online_ingestion.call_count == 2
128137

129138

130139
def test_not_cancelling_starting_job(spark_client, feature_table):
131140
""" Feature Table spec was updated but previous version is still starting """
132141

133142
new_ft = copy.deepcopy(feature_table)
134143
new_ft.stream_source._kafka_options.topic = "new_t"
135-
job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.STARTING)
144+
job = SimpleStreamingIngestionJob(
145+
"", "default", feature_table, SparkJobStatus.STARTING
146+
)
136147

137148
spark_client.feature_store.list_feature_tables.return_value = [new_ft]
138149
spark_client.list_jobs.return_value = [job]
139150

140151
ensure_stream_ingestion_jobs(spark_client, all_projects=True)
141152

142153
assert job.get_status() == SparkJobStatus.STARTING
143-
spark_client.start_stream_to_online_ingestion.assert_called_once_with(
144-
new_ft, [], project="default"
145-
)
154+
assert spark_client.start_stream_to_online_ingestion.call_count == 2
146155

147156

148157
def test_not_retrying_failed_job(spark_client, feature_table):
149158
""" Job has failed on previous try """
150159

151-
job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.FAILED)
160+
job = SimpleStreamingIngestionJob(
161+
"", "default", feature_table, SparkJobStatus.FAILED
162+
)
152163

153164
spark_client.feature_store.list_feature_tables.return_value = [feature_table]
154165
spark_client.list_jobs.return_value = [job]
@@ -157,29 +168,33 @@ def test_not_retrying_failed_job(spark_client, feature_table):
157168

158169
spark_client.list_jobs.assert_called_once_with(include_terminated=True)
159170
assert job.get_status() == SparkJobStatus.FAILED
160-
spark_client.start_stream_to_online_ingestion.assert_not_called()
171+
spark_client.start_stream_to_online_ingestion.assert_called_once_with(
172+
feature_table, [], project="ride"
173+
)
161174

162175

163176
def test_restarting_completed_job(spark_client, feature_table):
164177
""" Job has succesfully finished on previous try """
165-
job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.COMPLETED)
178+
job = SimpleStreamingIngestionJob(
179+
"", "default", feature_table, SparkJobStatus.COMPLETED
180+
)
166181

167182
spark_client.feature_store.list_feature_tables.return_value = [feature_table]
168183
spark_client.list_jobs.return_value = [job]
169184

170185
ensure_stream_ingestion_jobs(spark_client, all_projects=True)
171186

172-
spark_client.start_stream_to_online_ingestion.assert_called_once_with(
173-
feature_table, [], project="default"
174-
)
187+
assert spark_client.start_stream_to_online_ingestion.call_count == 2
175188

176189

177190
def test_stopping_running_job(spark_client, feature_table):
178191
""" Streaming source was deleted """
179192
new_ft = copy.deepcopy(feature_table)
180193
new_ft.stream_source = None
181194

182-
job = SimpleStreamingIngestionJob("", feature_table, SparkJobStatus.IN_PROGRESS)
195+
job = SimpleStreamingIngestionJob(
196+
"", "default", feature_table, SparkJobStatus.IN_PROGRESS
197+
)
183198

184199
spark_client.feature_store.list_feature_tables.return_value = [new_ft]
185200
spark_client.list_jobs.return_value = [job]
@@ -194,7 +209,9 @@ def test_restarting_failed_jobs(feature_table):
194209
""" If configured - restart failed jobs """
195210

196211
feast_client = FeastClient(
197-
job_service_pause_between_jobs=0, job_service_retry_failed_jobs=True
212+
job_service_pause_between_jobs=0,
213+
job_service_retry_failed_jobs=True,
214+
options={"whitelisted_projects": "default,ride"},
198215
)
199216
feast_client.list_projects = Mock(return_value=["default"])
200217
feast_client.list_feature_tables = Mock()

spark/ingestion/src/test/scala/feast/ingestion/SparkSpec.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class SparkSpec extends UnitSpec with BeforeAndAfter {
3030
val sparkConf = new SparkConf()
3131
.setMaster("local[4]")
3232
.setAppName("Testing")
33+
.set("spark.driver.bindAddress", "localhost")
3334
.set("spark.default.parallelism", "8")
3435
.set(
3536
"spark.metrics.conf.*.sink.statsd.class",

0 commit comments

Comments (0)