Skip to content

Commit d48108d

Browse files
authored
Merge pull request #37874 from aIbrahiim/fix-30644-inference-python-benchmark
Fix vLLM Gemma benchmark and PyTorch language modeling tests
2 parents d51177b + fee107a commit d48108d

4 files changed

Lines changed: 4 additions & 4 deletions

File tree

.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_VLLM_Gemma_Batch.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@
3232
--metrics_table=gemma_vllm_batch
3333
--influx_measurement=gemma_vllm_batch
3434
--model_gcs_path=gs://apache-beam-ml/models/gemma-2b-it
35+
--requirements_file=apache_beam/ml/inference/vllm_tests_requirements.txt
3536
--dataflow_service_options=worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver
3637
--experiments=use_runner_v2

sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def tokenize_sentence(
5252
text_and_mask: tuple[str, str],
5353
bert_tokenizer: BertTokenizer) -> tuple[str, dict[str, torch.Tensor]]:
5454
text, masked_text = text_and_mask
55-
tokenized_sentence = bert_tokenizer.encode_plus(
56-
masked_text, return_tensors="pt")
55+
tokenized_sentence = bert_tokenizer(masked_text, return_tensors="pt")
5756

5857
# Workaround to manually remove batch dim until we have the feature to
5958
# add optional batching flag.

sdks/python/apache_beam/examples/inference/vllm_gemma_batch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def run(argv=None, save_main_session=True, test_pipeline=None):
103103

104104
gem = opts.view_as(GemmaVLLMOptions)
105105
opts.view_as(SetupOptions).save_main_session = save_main_session
106-
107106
logging.info("Pipeline starting with model path: %s", gem.model_gcs_path)
108107
handler = GcsVLLMCompletionsModelHandler(
109108
model_name=gem.model_gcs_path,

sdks/python/apache_beam/ml/inference/vllm_tests_requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
torch>=1.7.1
1818
torchvision>=0.8.2
1919
pillow>=8.0.0
20-
transformers>=4.18.0
20+
transformers==4.57.1
21+
sentencepiece==0.2.1
2122
google-cloud-monitoring>=2.27.0
2223
openai>=1.52.2

0 commit comments

Comments (0)