Skip to content

Commit e02c4b6

Browse files
Authored by David Cavazos, with co-authors rszper and kweinmeister
authored
dataflow: add streaming RunInference sample (GoogleCloudPlatform#9926)
* dataflow: add streaming RunInference sample * initial test setup * added new fixtures * add tests * simplify args * update dependencies * fix lint and support conftest * move tests * move tests * move requirements-test.txt * add pythonpath * fix cmd arguments * use pytest-parallel * fix load in vertex * use gsutil * switch back to pytest-xdist * create directory if not exists * convert gcs to local path * update dependencies * update python version * add docstrings * add documentation * disable streaming tests * fix tests * build image with cloud build * fix args * prebuild image * rename image * build image manually * clean up * clean up * experiment without prebuilding the container * remove custom container * update readme * clean up * add warning * adjust vertex machine type * remove shebang * do not use a custom container * rename script * remove container * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Update dataflow/run-inference/README.md Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Rebecca Szper <98840847+rszper@users.noreply.github.com> * uncomment temporary changes * update requirements * update flags to not oom * test on multiple inputs * address review comments --------- Co-authored-by: Rebecca Szper 
<98840847+rszper@users.noreply.github.com> Co-authored-by: Karl Weinmeister <11586922+kweinmeister@users.noreply.github.com>
1 parent 6b1b37f commit e02c4b6

File tree

10 files changed

+965
-6
lines changed

10 files changed

+965
-6
lines changed

dataflow/conftest.py

Lines changed: 272 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
from __future__ import annotations
1414

15-
from collections.abc import Callable, Iterable
15+
from collections.abc import Callable, Iterator
1616
from dataclasses import dataclass
17-
from google.api_core.exceptions import NotFound
17+
from datetime import datetime
1818
import itertools
1919
import json
2020
import logging
@@ -27,16 +27,276 @@
2727
from typing import Any
2828
import uuid
2929

30+
from google.api_core import retry
3031
import pytest
3132

33+
TIMEOUT_SEC = 30 * 60 # 30 minutes (in seconds)
34+
35+
36+
@pytest.fixture(scope="session")
def project() -> str:
    """Return the Google Cloud project ID and point gcloud at it."""
    # The testing infrastructure sets this variable.
    project_id = os.environ["GOOGLE_CLOUD_PROJECT"]
    run_cmd("gcloud", "config", "set", "project", project_id)

    # Everything else depends on the project, so configure it and print
    # some debugging information here.
    run_cmd("gcloud", "version")
    run_cmd("gcloud", "config", "list")
    return project_id
47+
48+
49+
@pytest.fixture(scope="session")
50+
def location() -> str:
51+
# Override for local testing.
52+
return os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1")
53+
54+
55+
@pytest.fixture(scope="session")
56+
def unique_id() -> str:
57+
id = uuid.uuid4().hex[0:6]
58+
print(f"unique_id: {id}")
59+
return id
60+
61+
62+
@pytest.fixture(scope="session")
63+
def unique_name(test_name: str, unique_id: str) -> str:
64+
return f"{test_name.replace('/', '-')}-{unique_id}"
65+
66+
67+
@pytest.fixture(scope="session")
def bucket_name(test_name: str, location: str, unique_id: str) -> Iterator[str]:
    """Yield a Cloud Storage bucket name, creating and tearing it down unless overridden.

    If GOOGLE_CLOUD_BUCKET is set (local testing), that bucket is yielded
    as-is and never deleted.
    """
    if "GOOGLE_CLOUD_BUCKET" in os.environ:
        override = os.environ["GOOGLE_CLOUD_BUCKET"]
        print(f"bucket_name: {override} (from GOOGLE_CLOUD_BUCKET)")
        yield override
        return

    from google.cloud import storage

    name = f"{test_name.replace('/', '-')}-{unique_id}"
    bucket = storage.Client().create_bucket(name, location=location)

    print(f"bucket_name: {name}")
    yield name

    # Empty the bucket before deleting it; deleting a bucket that still
    # holds too many objects results in an error.
    try:
        run_cmd("gsutil", "-m", "rm", "-rf", f"gs://{name}/*")
    except RuntimeError:
        # No files matched -- nothing to clean up.
        pass

    bucket.delete(force=True)
95+
96+
97+
@pytest.fixture(scope="session")
def pubsub_topic(
    test_name: str, project: str, unique_id: str
) -> Iterator[Callable[[str], str]]:
    """Yield a factory that creates Pub/Sub topics; all are deleted on teardown."""
    from google.cloud import pubsub

    publisher = pubsub.PublisherClient()
    topic_paths: list[str] = []

    def create_topic(name: str) -> str:
        # Namespaced so concurrent test runs don't collide.
        full_name = f"{test_name.replace('/', '-')}-{name}-{unique_id}"
        topic = publisher.create_topic(name=publisher.topic_path(project, full_name))

        print(f"pubsub_topic created: {topic.name}")
        topic_paths.append(topic.name)
        return topic.name

    yield create_topic

    for path in topic_paths:
        publisher.delete_topic(topic=path)
        print(f"pubsub_topic deleted: {path}")
120+
121+
122+
@pytest.fixture(scope="session")
def pubsub_subscription(
    test_name: str, project: str, unique_id: str
) -> Iterator[Callable[[str, str], str]]:
    """Yield a factory that creates Pub/Sub subscriptions; all are deleted on teardown."""
    from google.cloud import pubsub

    subscriber = pubsub.SubscriberClient()
    subscription_paths: list[str] = []

    def create_subscription(name: str, topic_path: str) -> str:
        # Namespaced so concurrent test runs don't collide.
        full_name = f"{test_name.replace('/', '-')}-{name}-{unique_id}"
        subscription = subscriber.create_subscription(
            name=subscriber.subscription_path(project, full_name),
            topic=topic_path,
        )

        print(f"pubsub_subscription created: {subscription.name}")
        subscription_paths.append(subscription.name)
        return subscription.name

    yield create_subscription

    for path in subscription_paths:
        subscriber.delete_subscription(subscription=path)
        print(f"pubsub_subscription deleted: {path}")
147+
148+
149+
def pubsub_publish(topic_path: str, messages: list[str]) -> None:
    """Publish every message to the topic and block until all are sent."""
    from google.cloud import pubsub

    publisher = pubsub.PublisherClient()
    pending = [publisher.publish(topic_path, msg.encode("utf-8")) for msg in messages]
    for future in pending:
        future.result()  # wait synchronously
    print(f"pubsub_publish {len(messages)} message(s) to {topic_path}:")
    for msg in messages:
        print(f"- {repr(msg)}")
158+
159+
160+
@retry.Retry(retry.if_exception_type(ValueError), timeout=TIMEOUT_SEC)
def pubsub_wait_for_messages(subscription_path: str) -> list[str]:
    """Pull messages from a subscription, retrying until at least one arrives.

    Raises ValueError while the subscription is empty; the retry decorator
    catches it and re-polls until TIMEOUT_SEC elapses.

    Returns:
        The decoded message bodies (up to 10 per pull), acked so they are
        not redelivered.
    """
    from google.cloud import pubsub

    subscriber = pubsub.SubscriberClient()
    with subscriber:
        response = subscriber.pull(subscription=subscription_path, max_messages=10)
        messages = [m.message.data.decode("utf-8") for m in response.received_messages]
        if not messages:
            raise ValueError("pubsub_wait_for_messages no messages received")

        print(f"pubsub_receive got {len(messages)} message(s)")
        for msg in messages:
            print(f"- {repr(msg)}")

        # Ack everything we pulled so it is not redelivered.
        ack_ids = [m.ack_id for m in response.received_messages]
        subscriber.acknowledge(subscription=subscription_path, ack_ids=ack_ids)
        # Was an f-string with no placeholders (lint F541); plain literal now.
        print("pubsub_receive ack messages")
        return messages
179+
180+
181+
def dataflow_job_url(project: str, location: str, job_id: str) -> str:
    """Build the Cloud Console URL for a Dataflow job."""
    base = "https://console.cloud.google.com/dataflow/jobs"
    return f"{base}/{location}/{job_id}?project={project}"
183+
184+
185+
@retry.Retry(retry.if_exception_type(LookupError), timeout=TIMEOUT_SEC)
def dataflow_find_job_by_name(project: str, location: str, job_name: str) -> str:
    """Look up a Dataflow job ID by its name, retrying until the job appears."""
    from google.cloud import dataflow_v1beta3 as dataflow

    # https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.JobsV1Beta3Client#google_cloud_dataflow_v1beta3_services_jobs_v1_beta3_JobsV1Beta3Client_list_jobs
    client = dataflow.JobsV1Beta3Client()
    jobs = client.list_jobs(
        dataflow.ListJobsRequest(project_id=project, location=location)
    )
    for job in jobs:
        if job.name == job_name:
            return job.id
    # LookupError triggers a retry via the decorator above.
    raise LookupError(f"dataflow_find_job_by_name job name not found: {job_name}")
199+
200+
201+
@retry.Retry(retry.if_exception_type(ValueError), timeout=TIMEOUT_SEC)
def dataflow_wait_until_running(project: str, location: str, job_id: str) -> str:
    """Poll a Dataflow job until it is RUNNING; raise immediately if it FAILED.

    ValueError (job not yet running) is retried by the decorator;
    RuntimeError (job failed) is not, so failures surface right away.
    """
    from google.cloud import dataflow_v1beta3 as dataflow
    from google.cloud.dataflow_v1beta3.types import JobView, JobState

    # https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.JobsV1Beta3Client#google_cloud_dataflow_v1beta3_services_jobs_v1_beta3_JobsV1Beta3Client_get_job
    client = dataflow.JobsV1Beta3Client()
    job = client.get_job(
        dataflow.GetJobRequest(
            project_id=project,
            location=location,
            job_id=job_id,
            view=JobView.JOB_VIEW_SUMMARY,
        )
    )

    url = dataflow_job_url(project, location, job_id)
    state = job.current_state
    if state == JobState.JOB_STATE_FAILED:
        raise RuntimeError(f"Dataflow job failed unexpectedly\n{url}")
    if state != JobState.JOB_STATE_RUNNING:
        raise ValueError(f"Dataflow job is not running, state: {state.name}\n{url}")
    return state.name
223+
224+
225+
def dataflow_num_workers(project: str, location: str, job_id: str) -> int:
    """Return the latest worker count reported by the job's autoscaling events, or 0 if none yet."""
    from google.cloud import dataflow_v1beta3 as dataflow
    from google.cloud.dataflow_v1beta3.types import JobMessageImportance

    # https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.messages_v1_beta3.MessagesV1Beta3Client#google_cloud_dataflow_v1beta3_services_messages_v1_beta3_MessagesV1Beta3Client_list_job_messages
    dataflow_client = dataflow.MessagesV1Beta3Client()
    request = dataflow.ListJobMessagesRequest(
        project_id=project,
        location=location,
        job_id=job_id,
        minimum_importance=JobMessageImportance.JOB_MESSAGE_BASIC,
    )

    # NOTE(review): `_response` is a private attribute of the returned pager,
    # used here to reach `autoscaling_events` -- may break on client library
    # upgrades; verify against the installed google-cloud-dataflow-client.
    response = dataflow_client.list_job_messages(request)._response
    # Presumably autoscaling events arrive in chronological order, so the
    # last one holds the current worker count -- TODO confirm.
    num_workers = [event.current_num_workers for event in response.autoscaling_events]
    if num_workers:
        return num_workers[-1]
    return 0
243+
244+
245+
def dataflow_cancel_job(project: str, location: str, job_id: str) -> None:
    """Request cancellation of a running Dataflow job and print the response."""
    from google.cloud import dataflow_v1beta3 as dataflow
    from google.cloud.dataflow_v1beta3.types import Job, JobState

    # https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.JobsV1Beta3Client#google_cloud_dataflow_v1beta3_services_jobs_v1_beta3_JobsV1Beta3Client_update_job
    client = dataflow.JobsV1Beta3Client()
    cancel_request = dataflow.UpdateJobRequest(
        project_id=project,
        location=location,
        job_id=job_id,
        job=Job(requested_state=JobState.JOB_STATE_CANCELLED),
    )
    print(client.update_job(request=cancel_request))
259+
260+
261+
# Retrying on AssertionError turns the single `assert` below into a polling
# loop: the condition is re-evaluated until it holds or TIMEOUT_SEC elapses.
@retry.Retry(retry.if_exception_type(AssertionError), timeout=TIMEOUT_SEC)
def wait_until(condition: Callable[[], bool], message: str) -> None:
    """Block until `condition()` is truthy, failing with `message` on timeout."""
    assert condition(), message
264+
265+
266+
def run_cmd(*cmd: str) -> subprocess.CompletedProcess:
    """Run a command, echo its output, and time it.

    Args:
        cmd: The command and its arguments as separate strings
            (list form, no shell).

    Returns:
        The CompletedProcess on success.

    Raises:
        RuntimeError: if the command exits non-zero; the message includes
            the command's stderr so logs show the real cause.
    """
    try:
        print(f"run_cmd: {cmd}")
        start = datetime.now()
        p = subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        print(p.stderr.decode("utf-8").strip())
        print(p.stdout.decode("utf-8").strip())
        # total_seconds() is correct even past 24h (timedelta.seconds wraps);
        # divmod replaces the manual minutes/seconds arithmetic.
        elapsed = int((datetime.now() - start).total_seconds())
        minutes, seconds = divmod(elapsed, 60)
        print(f"-- run_cmd `{cmd[0]}` finished in {minutes}m {seconds}s")
        return p
    except subprocess.CalledProcessError as e:
        # Include the error message from the failed command.
        print(e.stderr.decode("utf-8"))
        print(e.stdout.decode("utf-8"))
        raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e
288+
289+
290+
# ---- FOR BACKWARDS COMPATIBILITY ONLY, prefer fixture-style ---- #
291+
32292
# Default options.
33293
UUID = uuid.uuid4().hex[0:6]
34294
PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
35295
REGION = "us-central1"
36296

37-
TIMEOUT_SEC = 30 * 60 # 30 minutes in seconds
38297
POLL_INTERVAL_SEC = 60 # 1 minute in seconds
39298
LIST_PAGE_SIZE = 100
299+
TIMEOUT_SEC = 30 * 60 # 30 minutes in seconds
40300

41301
HYPHEN_NAME_RE = re.compile(r"[^\w\d-]+")
42302
UNDERSCORE_NAME_RE = re.compile(r"[^\w\d_]+")
@@ -73,6 +333,11 @@ def wait_until(
73333

74334
@staticmethod
75335
def storage_bucket(name: str) -> str:
336+
if bucket_name := os.environ.get("GOOGLE_CLOUD_BUCKET"):
337+
logging.warning(f"Using bucket from GOOGLE_CLOUD_BUCKET: {bucket_name}")
338+
yield bucket_name
339+
return # don't delete
340+
76341
from google.cloud import storage
77342

78343
storage_client = storage.Client()
@@ -100,6 +365,7 @@ def bigquery_dataset(
100365
project: str = PROJECT,
101366
location: str = REGION,
102367
) -> str:
368+
from google.api_core.exceptions import NotFound
103369
from google.cloud import bigquery
104370

105371
bigquery_client = bigquery.Client()
@@ -148,7 +414,7 @@ def bigquery_table_exists(
148414
return False
149415

150416
@staticmethod
151-
def bigquery_query(query: str, region: str = REGION) -> Iterable[dict[str, Any]]:
417+
def bigquery_query(query: str, region: str = REGION) -> Iterator[dict[str, Any]]:
152418
from google.cloud import bigquery
153419

154420
bigquery_client = bigquery.Client()
@@ -332,7 +598,7 @@ def dataflow_job_url(
332598
@staticmethod
333599
def dataflow_jobs_list(
334600
project: str = PROJECT, page_size: int = 30
335-
) -> Iterable[dict]:
601+
) -> Iterator[dict]:
336602
from googleapiclient.discovery import build
337603

338604
dataflow = build("dataflow", "v1b3")
@@ -390,7 +656,7 @@ def dataflow_jobs_wait(
390656
project: str = PROJECT,
391657
region: str = REGION,
392658
target_states: set[str] = {"JOB_STATE_DONE"},
393-
timeout_sec: str = TIMEOUT_SEC,
659+
timeout_sec: int = TIMEOUT_SEC,
394660
poll_interval_sec: int = POLL_INTERVAL_SEC,
395661
) -> str | None:
396662
"""For a list of all the valid states:

dataflow/run-inference/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
state_dict.pt

0 commit comments

Comments
 (0)