From 24624a05a415f2ecd7ce962782d3c60732d67a9f Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Mon, 15 Jun 2026 23:54:53 +0000 Subject: [PATCH] chore(dataflow/gemma): update dependencies and format code - Update tensorflow base image to 2.20.0-gpu and beam sdk to 3.11/2.74.0 - Update apache_beam, keras, keras_nlp, and protobuf dependencies - Update test dependencies including google-cloud-aiplatform, storage, and pytest - Format custom_model_gemma.py and e2e_test.py - Update ignored python versions in noxfile_config.py --- dataflow/gemma/Dockerfile | 4 +-- dataflow/gemma/custom_model_gemma.py | 22 +++++++++------ dataflow/gemma/e2e_test.py | 41 ++++++++++++++-------------- dataflow/gemma/noxfile_config.py | 8 ++---- dataflow/gemma/requirements-test.txt | 10 +++---- dataflow/gemma/requirements.txt | 9 +++--- 6 files changed, 49 insertions(+), 45 deletions(-) diff --git a/dataflow/gemma/Dockerfile b/dataflow/gemma/Dockerfile index b3472a56955..ebe142aa64d 100644 --- a/dataflow/gemma/Dockerfile +++ b/dataflow/gemma/Dockerfile @@ -15,7 +15,7 @@ # This uses Ubuntu with Python 3.11 # You can check the Python version for a given tensorflow # container at https://hub.docker.com/r/tensorflow/tensorflow/tags -ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu +ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.20.0-gpu FROM ${SERVING_BUILD_IMAGE} @@ -29,7 +29,7 @@ RUN pip install --upgrade --no-cache-dir pip \ && pip install --no-cache-dir -r requirements.txt # Copy files from official SDK image, including script/dependencies. -COPY --from=apache/beam_python3.14_sdk:2.73.0 /opt/apache/beam /opt/apache/beam +COPY --from=apache/beam_python3.11_sdk:2.74.0 /opt/apache/beam /opt/apache/beam # Copy the model directory downloaded from Kaggle and the pipeline code. COPY gemma_2b gemma_2B diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index fbf0b975057..456a9680e67 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -35,7 +35,7 @@ def __init__( self, model_name: str = "gemma_2B", ): - """ Implementation of the ModelHandler interface for Gemma using text as input. + """Implementation of the ModelHandler interface for Gemma using text as input. Example Usage:: @@ -48,7 +48,7 @@ def __init__( self._env_vars = {} def share_model_across_processes(self) -> bool: - """ Indicates if the model should be loaded once-per-VM rather than + """Indicates if the model should be loaded once-per-VM rather than once-per-worker-process on a VM. Because Gemma is a large language model, this will always return True to avoid OOM errors. """ @@ -62,7 +62,7 @@ def run_inference( self, batch: Sequence[str], model: GemmaCausalLM, - inference_args: Optional[dict[str, Any]] = None + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable[PredictionResult]: """Runs inferences on a batch of text strings. @@ -85,7 +85,8 @@ def run_inference( class FormatOutput(beam.DoFn): def process(self, element, *args, **kwargs): yield "Input: {input}, Output: {output}".format( - input=element.example, output=element.inference) + input=element.example, output=element.inference + ) if __name__ == "__main__": @@ -119,13 +120,16 @@ def process(self, element, *args, **kwargs): pipeline = beam.Pipeline(options=beam_options) _ = ( - pipeline | "Read Topic" >> - beam.io.ReadFromPubSub(subscription=args.messages_subscription) + pipeline + | "Read Topic" + >> beam.io.ReadFromPubSub(subscription=args.messages_subscription) | "Parse" >> beam.Map(lambda x: x.decode("utf-8")) - | "RunInference-Gemma" >> RunInference( + | "RunInference-Gemma" + >> RunInference( GemmaModelHandler(args.model_path) ) # Send the prompts to the model and get responses. | "Format Output" >> beam.ParDo(FormatOutput()) # Format the output. - | "Publish Result" >> - beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)) + | "Publish Result" + >> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic) + ) pipeline.run() diff --git a/dataflow/gemma/e2e_test.py b/dataflow/gemma/e2e_test.py index 6f65fb15959..43bc439b0b6 100644 --- a/dataflow/gemma/e2e_test.py +++ b/dataflow/gemma/e2e_test.py @@ -39,6 +39,7 @@ NOTE: For the tests to find the conftest in the testing infrastructure, add the PYTHONPATH to the "env" in your noxfile_config.py file. """ + from collections.abc import Callable, Iterator import conftest # python-docs-samples/dataflow/conftest.py @@ -70,8 +71,9 @@ def messages_topic(pubsub_topic: Callable[[str], str]) -> str: @pytest.fixture(scope="session") -def messages_subscription(pubsub_subscription: Callable[[str, str], str], - messages_topic: str) -> str: +def messages_subscription( + pubsub_subscription: Callable[[str, str], str], messages_topic: str +) -> str: return pubsub_subscription("messages", messages_topic) @@ -81,20 +83,21 @@ def responses_topic(pubsub_topic: Callable[[str], str]) -> str: @pytest.fixture(scope="session") -def responses_subscription(pubsub_subscription: Callable[[str, str], str], - responses_topic: str) -> str: +def responses_subscription( + pubsub_subscription: Callable[[str, str], str], responses_topic: str +) -> str: return pubsub_subscription("responses", responses_topic) @pytest.fixture(scope="session") def dataflow_job( - project: str, - bucket_name: str, - location: str, - unique_name: str, - container_image: str, - messages_subscription: str, - responses_topic: str, + project: str, + bucket_name: str, + location: str, + unique_name: str, + container_image: str, + messages_subscription: str, + responses_topic: str, ) -> Iterator[str]: # Launch the streaming Dataflow pipeline. conftest.run_cmd( @@ -127,20 +130,18 @@ def dataflow_job( @pytest.mark.timeout(3600) def test_pipeline_dataflow( - project: str, - location: str, - dataflow_job: str, - messages_topic: str, - responses_subscription: str, + project: str, + location: str, + dataflow_job: str, + messages_topic: str, + responses_subscription: str, ) -> None: print(f"Waiting for the Dataflow workers to start: {dataflow_job}") conftest.wait_until( - lambda: conftest.dataflow_num_workers(project, location, dataflow_job) - > 0, + lambda: conftest.dataflow_num_workers(project, location, dataflow_job) > 0, "workers are running", ) - num_workers = conftest.dataflow_num_workers(project, location, - dataflow_job) + num_workers = conftest.dataflow_num_workers(project, location, dataflow_job) print(f"Dataflow job num_workers: {num_workers}") messages = ["This is a test for a Python sample."] diff --git a/dataflow/gemma/noxfile_config.py b/dataflow/gemma/noxfile_config.py index 35321dbbdea..6641345788e 100644 --- a/dataflow/gemma/noxfile_config.py +++ b/dataflow/gemma/noxfile_config.py @@ -18,9 +18,7 @@ # You can opt out from the test for specific Python versions. # The Python version used is defined by the Dockerfile and the job # submission enviornment must match. - # Note: Docker-based sample, testing only against version specified in Dockerfile (3.14) - "ignored_versions": ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"], - "envs": { - "PYTHONPATH": ".." - }, + # Note: Docker-based sample, testing only against version specified in Dockerfile (3.11) + "ignored_versions": ["3.8", "3.9", "3.10"], + "envs": {"PYTHONPATH": ".."}, } diff --git a/dataflow/gemma/requirements-test.txt b/dataflow/gemma/requirements-test.txt index 238d774fdde..37f1bad2342 100644 --- a/dataflow/gemma/requirements-test.txt +++ b/dataflow/gemma/requirements-test.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform==1.49.0 -google-cloud-dataflow-client==0.8.10 -google-cloud-storage==2.16.0 -pytest==9.0.3; python_version >= "3.10" -pytest-timeout==2.3.1 \ No newline at end of file +google-cloud-aiplatform==1.157.0 +google-cloud-dataflow-client==0.14.0 +google-cloud-storage==3.12.0 +pytest==9.0.3 +pytest-timeout==2.4.0 \ No newline at end of file diff --git a/dataflow/gemma/requirements.txt b/dataflow/gemma/requirements.txt index 76fc60632ee..b2d60a3eced 100644 --- a/dataflow/gemma/requirements.txt +++ b/dataflow/gemma/requirements.txt @@ -1,4 +1,5 @@ -apache_beam[gcp]==2.54.0 -protobuf==4.25.0 -keras_nlp==0.8.2 -keras==3.0.5 \ No newline at end of file +protobuf==6.33.6 +apache_beam[gcp]==2.74.0 +keras==3.14.1 +keras_nlp==0.29.1 +pyOpenSSL==25.3.0 \ No newline at end of file