From ba309408559c548cdd01eea1e589a179203c8d51 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Mon, 15 Jun 2026 23:54:53 +0000 Subject: [PATCH 01/10] chore(dataflow/gemma): update dependencies and format code - Update tensorflow base image to 2.20.0-gpu and beam sdk to 3.11/2.74.0 - Update apache_beam, keras, keras_nlp, and protobuf dependencies - Update test dependencies including google-cloud-aiplatform, storage, and pytest - Format custom_model_gemma.py and e2e_test.py - Update ignored python versions in noxfile_config.py --- dataflow/gemma/Dockerfile | 4 +-- dataflow/gemma/custom_model_gemma.py | 22 +++++++++------ dataflow/gemma/e2e_test.py | 41 ++++++++++++++-------------- dataflow/gemma/noxfile_config.py | 8 ++---- dataflow/gemma/requirements-test.txt | 10 +++---- dataflow/gemma/requirements.txt | 9 +++--- 6 files changed, 49 insertions(+), 45 deletions(-) diff --git a/dataflow/gemma/Dockerfile b/dataflow/gemma/Dockerfile index b3472a56955..ebe142aa64d 100644 --- a/dataflow/gemma/Dockerfile +++ b/dataflow/gemma/Dockerfile @@ -15,7 +15,7 @@ # This uses Ubuntu with Python 3.11 # You can check the Python version for a given tensorflow # container at https://hub.docker.com/r/tensorflow/tensorflow/tags -ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu +ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.20.0-gpu FROM ${SERVING_BUILD_IMAGE} @@ -29,7 +29,7 @@ RUN pip install --upgrade --no-cache-dir pip \ && pip install --no-cache-dir -r requirements.txt # Copy files from official SDK image, including script/dependencies. -COPY --from=apache/beam_python3.14_sdk:2.73.0 /opt/apache/beam /opt/apache/beam +COPY --from=apache/beam_python3.11_sdk:2.74.0 /opt/apache/beam /opt/apache/beam # Copy the model directory downloaded from Kaggle and the pipeline code. COPY gemma_2b gemma_2B diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index fbf0b975057..456a9680e67 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -35,7 +35,7 @@ def __init__( self, model_name: str = "gemma_2B", ): - """ Implementation of the ModelHandler interface for Gemma using text as input. + """Implementation of the ModelHandler interface for Gemma using text as input. Example Usage:: @@ -48,7 +48,7 @@ def __init__( self._env_vars = {} def share_model_across_processes(self) -> bool: - """ Indicates if the model should be loaded once-per-VM rather than + """Indicates if the model should be loaded once-per-VM rather than once-per-worker-process on a VM. Because Gemma is a large language model, this will always return True to avoid OOM errors. """ @@ -62,7 +62,7 @@ def run_inference( self, batch: Sequence[str], model: GemmaCausalLM, - inference_args: Optional[dict[str, Any]] = None + inference_args: Optional[dict[str, Any]] = None, ) -> Iterable[PredictionResult]: """Runs inferences on a batch of text strings. @@ -85,7 +85,8 @@ def run_inference( class FormatOutput(beam.DoFn): def process(self, element, *args, **kwargs): yield "Input: {input}, Output: {output}".format( - input=element.example, output=element.inference) + input=element.example, output=element.inference + ) if __name__ == "__main__": @@ -119,13 +120,16 @@ def process(self, element, *args, **kwargs): pipeline = beam.Pipeline(options=beam_options) _ = ( - pipeline | "Read Topic" >> - beam.io.ReadFromPubSub(subscription=args.messages_subscription) + pipeline + | "Read Topic" + >> beam.io.ReadFromPubSub(subscription=args.messages_subscription) | "Parse" >> beam.Map(lambda x: x.decode("utf-8")) - | "RunInference-Gemma" >> RunInference( + | "RunInference-Gemma" + >> RunInference( GemmaModelHandler(args.model_path) ) # Send the prompts to the model and get responses. | "Format Output" >> beam.ParDo(FormatOutput()) # Format the output. - | "Publish Result" >> - beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)) + | "Publish Result" + >> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic) + ) pipeline.run() diff --git a/dataflow/gemma/e2e_test.py b/dataflow/gemma/e2e_test.py index 6f65fb15959..43bc439b0b6 100644 --- a/dataflow/gemma/e2e_test.py +++ b/dataflow/gemma/e2e_test.py @@ -39,6 +39,7 @@ NOTE: For the tests to find the conftest in the testing infrastructure, add the PYTHONPATH to the "env" in your noxfile_config.py file. """ + from collections.abc import Callable, Iterator import conftest # python-docs-samples/dataflow/conftest.py @@ -70,8 +71,9 @@ def messages_topic(pubsub_topic: Callable[[str], str]) -> str: @pytest.fixture(scope="session") -def messages_subscription(pubsub_subscription: Callable[[str, str], str], - messages_topic: str) -> str: +def messages_subscription( + pubsub_subscription: Callable[[str, str], str], messages_topic: str +) -> str: return pubsub_subscription("messages", messages_topic) @@ -81,20 +83,21 @@ def responses_topic(pubsub_topic: Callable[[str], str]) -> str: @pytest.fixture(scope="session") -def responses_subscription(pubsub_subscription: Callable[[str, str], str], - responses_topic: str) -> str: +def responses_subscription( + pubsub_subscription: Callable[[str, str], str], responses_topic: str +) -> str: return pubsub_subscription("responses", responses_topic) @pytest.fixture(scope="session") def dataflow_job( - project: str, - bucket_name: str, - location: str, - unique_name: str, - container_image: str, - messages_subscription: str, - responses_topic: str, + project: str, + bucket_name: str, + location: str, + unique_name: str, + container_image: str, + messages_subscription: str, + responses_topic: str, ) -> Iterator[str]: # Launch the streaming Dataflow pipeline. conftest.run_cmd( @@ -127,20 +130,18 @@ def dataflow_job( @pytest.mark.timeout(3600) def test_pipeline_dataflow( - project: str, - location: str, - dataflow_job: str, - messages_topic: str, - responses_subscription: str, + project: str, + location: str, + dataflow_job: str, + messages_topic: str, + responses_subscription: str, ) -> None: print(f"Waiting for the Dataflow workers to start: {dataflow_job}") conftest.wait_until( - lambda: conftest.dataflow_num_workers(project, location, dataflow_job) - > 0, + lambda: conftest.dataflow_num_workers(project, location, dataflow_job) > 0, "workers are running", ) - num_workers = conftest.dataflow_num_workers(project, location, - dataflow_job) + num_workers = conftest.dataflow_num_workers(project, location, dataflow_job) print(f"Dataflow job num_workers: {num_workers}") messages = ["This is a test for a Python sample."] diff --git a/dataflow/gemma/noxfile_config.py b/dataflow/gemma/noxfile_config.py index 35321dbbdea..6641345788e 100644 --- a/dataflow/gemma/noxfile_config.py +++ b/dataflow/gemma/noxfile_config.py @@ -18,9 +18,7 @@ # You can opt out from the test for specific Python versions. # The Python version used is defined by the Dockerfile and the job # submission enviornment must match. - # Note: Docker-based sample, testing only against version specified in Dockerfile (3.14) - "ignored_versions": ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"], - "envs": { - "PYTHONPATH": ".." - }, + # Note: Docker-based sample, testing only against version specified in Dockerfile (3.11) + "ignored_versions": ["3.8", "3.9", "3.10"], + "envs": {"PYTHONPATH": ".."}, } diff --git a/dataflow/gemma/requirements-test.txt b/dataflow/gemma/requirements-test.txt index 238d774fdde..37f1bad2342 100644 --- a/dataflow/gemma/requirements-test.txt +++ b/dataflow/gemma/requirements-test.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform==1.49.0 -google-cloud-dataflow-client==0.8.10 -google-cloud-storage==2.16.0 -pytest==9.0.3; python_version >= "3.10" -pytest-timeout==2.3.1 \ No newline at end of file +google-cloud-aiplatform==1.157.0 +google-cloud-dataflow-client==0.14.0 +google-cloud-storage==3.12.0 +pytest==9.0.3 +pytest-timeout==2.4.0 \ No newline at end of file diff --git a/dataflow/gemma/requirements.txt b/dataflow/gemma/requirements.txt index 76fc60632ee..b2d60a3eced 100644 --- a/dataflow/gemma/requirements.txt +++ b/dataflow/gemma/requirements.txt @@ -1,4 +1,5 @@ -apache_beam[gcp]==2.54.0 -protobuf==4.25.0 -keras_nlp==0.8.2 -keras==3.0.5 \ No newline at end of file +protobuf==6.33.6 +apache_beam[gcp]==2.74.0 +keras==3.14.1 +keras_nlp==0.29.1 +pyOpenSSL==25.3.0 \ No newline at end of file From 7534737fcfdceb4aff1c2477386a75bc7eedde21 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 21:46:12 +0000 Subject: [PATCH 02/10] Finished setting config for gemma sample. it's now working --- dataflow/gemma/e2e_test.py | 10 +++++++--- dataflow/gemma/noxfile_config.py | 2 +- dataflow/gemma/requirements-test.txt | 4 ++-- dataflow/gemma/requirements.txt | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dataflow/gemma/e2e_test.py b/dataflow/gemma/e2e_test.py index 43bc439b0b6..72a1bc3e131 100644 --- a/dataflow/gemma/e2e_test.py +++ b/dataflow/gemma/e2e_test.py @@ -34,7 +34,7 @@ OPTION B: Run tests with nox pip install nox - nox -s py-3.10 + nox -s py-3.11 NOTE: For the tests to find the conftest in the testing infrastructure, add the PYTHONPATH to the "env" in your noxfile_config.py file. @@ -47,8 +47,11 @@ import pytest -DATAFLOW_MACHINE_TYPE = "g2-standard-4" -GEMMA_GCS = "gs://perm-dataflow-gemma-example-testdata/gemma_2b" +DATAFLOW_MACHINE_TYPE = "g2-standard-8" +# TODO If testing locally, point this to a bucket you control +# and download the gemma assets +# GEMMA_GCS = "gs://perm-dataflow-gemma-example-testdata/gemma_2b" +GEMMA_GCS = "gs://test-bucket-for-gemma/assets_here/gemma_2b" NAME = "dataflow/gemma/streaming" @@ -111,6 +114,7 @@ def dataflow_job( f"--temp_location=gs://{bucket_name}/temp", f"--region={location}", f"--machine_type={DATAFLOW_MACHINE_TYPE}", + "--disk_size_gb=100", f"--sdk_container_image=gcr.io/{project}/{container_image}", "--dataflow_service_options=worker_accelerator=type:nvidia-l4;count:1;install-nvidia-driver:5xx", "--requirements_cache=skip", diff --git a/dataflow/gemma/noxfile_config.py b/dataflow/gemma/noxfile_config.py index 6641345788e..1706c31b186 100644 --- a/dataflow/gemma/noxfile_config.py +++ b/dataflow/gemma/noxfile_config.py @@ -19,6 +19,6 @@ # The Python version used is defined by the Dockerfile and the job # submission enviornment must match. # Note: Docker-based sample, testing only against version specified in Dockerfile (3.11) - "ignored_versions": ["3.8", "3.9", "3.10"], + "ignored_versions": ["3.8", "3.9", "3.10", "3.12", "3.13", "3.14"], "envs": {"PYTHONPATH": ".."}, } diff --git a/dataflow/gemma/requirements-test.txt b/dataflow/gemma/requirements-test.txt index 37f1bad2342..91653a8a423 100644 --- a/dataflow/gemma/requirements-test.txt +++ b/dataflow/gemma/requirements-test.txt @@ -1,5 +1,5 @@ -google-cloud-aiplatform==1.157.0 +google-cloud-aiplatform==1.158.0 google-cloud-dataflow-client==0.14.0 google-cloud-storage==3.12.0 pytest==9.0.3 -pytest-timeout==2.4.0 \ No newline at end of file +pytest-timeout==2.4.0 diff --git a/dataflow/gemma/requirements.txt b/dataflow/gemma/requirements.txt index b2d60a3eced..a88a18920dd 100644 --- a/dataflow/gemma/requirements.txt +++ b/dataflow/gemma/requirements.txt @@ -2,4 +2,4 @@ protobuf==6.33.6 apache_beam[gcp]==2.74.0 keras==3.14.1 keras_nlp==0.29.1 -pyOpenSSL==25.3.0 \ No newline at end of file +pyOpenSSL==25.3.0 From 69d67206ace467b93e1cf9c35f3f76491cd52ad2 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 21:51:22 +0000 Subject: [PATCH 03/10] used isort to fix import order. --- dataflow/gemma/custom_model_gemma.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 456a9680e67..5f38d2e091e 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -12,21 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable, Sequence - import logging - -from typing import Any -from typing import Optional +from collections.abc import Iterable, Sequence +from typing import Any, Optional import apache_beam as beam +import keras_nlp from apache_beam.ml.inference import utils -from apache_beam.ml.inference.base import ModelHandler -from apache_beam.ml.inference.base import PredictionResult -from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.base import (ModelHandler, PredictionResult, + RunInference) from apache_beam.options.pipeline_options import PipelineOptions - -import keras_nlp from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM From 504120adb73e6cc651589acd944c70a764f2be45 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 21:54:50 +0000 Subject: [PATCH 04/10] import order again --- dataflow/gemma/custom_model_gemma.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 5f38d2e091e..9dffc5ef310 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from collections.abc import Iterable, Sequence -from typing import Any, Optional +from typing import Any +from typing import Optional import apache_beam as beam -import keras_nlp from apache_beam.ml.inference import utils -from apache_beam.ml.inference.base import (ModelHandler, PredictionResult, - RunInference) +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference from apache_beam.options.pipeline_options import PipelineOptions + +import logging + +import keras_nlp from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM From 2277601a74766a7ed9f3d5433b9a010a3a9a7077 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 21:57:25 +0000 Subject: [PATCH 05/10] more import order stuff --- dataflow/gemma/custom_model_gemma.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 9dffc5ef310..e1c1225bccf 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -12,20 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from collections.abc import Iterable, Sequence from typing import Any from typing import Optional import apache_beam as beam +import keras_nlp from apache_beam.ml.inference import utils from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.options.pipeline_options import PipelineOptions - -import logging - -import keras_nlp from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM From 6c2771993a2d56129204b1c0c76ac381aeb06915 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 22:02:42 +0000 Subject: [PATCH 06/10] even more linter stuff. --- dataflow/gemma/custom_model_gemma.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index e1c1225bccf..b424b94c3c6 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from collections.abc import Iterable, Sequence + +import logging + from typing import Any from typing import Optional import apache_beam as beam -import keras_nlp from apache_beam.ml.inference import utils from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.options.pipeline_options import PipelineOptions +import keras_nlp from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM From 2a177ee901dbc19c57a58b07454c962698c58885 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 22:05:37 +0000 Subject: [PATCH 07/10] linter --- dataflow/gemma/custom_model_gemma.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index b424b94c3c6..9bdd7fc443e 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -13,12 +13,11 @@ # limitations under the License. from collections.abc import Iterable, Sequence - -import logging - from typing import Any from typing import Optional +import logging + import apache_beam as beam from apache_beam.ml.inference import utils from apache_beam.ml.inference.base import ModelHandler From 7e30d08ac0fc32cb3f5d7b05187f8e66a3534b52 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 23:20:20 +0000 Subject: [PATCH 08/10] linting --- dataflow/gemma/custom_model_gemma.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 9bdd7fc443e..8e023d71849 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -16,9 +16,10 @@ from typing import Any from typing import Optional +import apache_beam as beam + import logging -import apache_beam as beam from apache_beam.ml.inference import utils from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult From 9cd6455806338507d5430d8c45586cfe969d54b7 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 23:25:31 +0000 Subject: [PATCH 09/10] more linting, and refactore message on testing locally. --- dataflow/gemma/custom_model_gemma.py | 5 +++-- dataflow/gemma/e2e_test.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index 8e023d71849..e32c0c76a0c 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -18,13 +18,14 @@ import apache_beam as beam -import logging - from apache_beam.ml.inference import utils from apache_beam.ml.inference.base import ModelHandler from apache_beam.ml.inference.base import PredictionResult from apache_beam.ml.inference.base import RunInference from apache_beam.options.pipeline_options import PipelineOptions + +import logging + import keras_nlp from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM diff --git a/dataflow/gemma/e2e_test.py b/dataflow/gemma/e2e_test.py index 72a1bc3e131..82bb500e0ff 100644 --- a/dataflow/gemma/e2e_test.py +++ b/dataflow/gemma/e2e_test.py @@ -48,10 +48,10 @@ import pytest DATAFLOW_MACHINE_TYPE = "g2-standard-8" -# TODO If testing locally, point this to a bucket you control -# and download the gemma assets -# GEMMA_GCS = "gs://perm-dataflow-gemma-example-testdata/gemma_2b" -GEMMA_GCS = "gs://test-bucket-for-gemma/assets_here/gemma_2b" +# NOTE: For local testing, ensure the 'gemma_2b_en' directory is uploaded +# to a GCS bucket you manage. Update the constant below to point to +# the root path of this uploaded directory (e.g., 'gs://your-bucket-name/path/to/gemma_2b_en'). +GEMMA_GCS = "gs://perm-dataflow-gemma-example-testdata/gemma_2b" NAME = "dataflow/gemma/streaming" From f5880f3b9b7f569a2e8749b9c53718028a939990 Mon Sep 17 00:00:00 2001 From: David del Real Sifuentes Date: Tue, 23 Jun 2026 23:30:51 +0000 Subject: [PATCH 10/10] more linting --- dataflow/gemma/custom_model_gemma.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py index e32c0c76a0c..d2b556a0247 100644 --- a/dataflow/gemma/custom_model_gemma.py +++ b/dataflow/gemma/custom_model_gemma.py @@ -24,11 +24,12 @@ from apache_beam.ml.inference.base import RunInference from apache_beam.options.pipeline_options import PipelineOptions -import logging - import keras_nlp + from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM +import logging + class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]): def __init__(