Commit 61317f5

[CB] Ensure parallel decoding test passes using FA (#43277)
Ensure parallel decoding matches using FA
1 parent 1efe1a6 commit 61317f5

1 file changed

tests/generation/test_continuous_batching.py (3 additions & 3 deletions)
@@ -38,6 +38,7 @@
 from transformers.testing_utils import (
     Expectations,
     require_deterministic_for_xpu,
+    require_flash_attn,
     require_torch_accelerator,
     slow,
     torch_device,
@@ -601,10 +602,9 @@ def test_block_sharing_with_hybrid_model(self) -> None:
 
         return self._test_block_sharing(model_id, num_layer_groups, input_msg, expected_generated_tokens)
 
-    # The test always passes on H100 with torch 2.9, but only passed case 0 on A100 with torch 2.6 and fails on A100
-    # with torch 2.9. This might be due to a GPU diff, so test might be flaky on the CI which runs on A10.
     @parameterized.expand([True, False])
     @require_torch_accelerator
+    @require_flash_attn  # otherwise the test can fail because attention bias has a very slight impact on SDPA and eager
     def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
         model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
         tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
@@ -616,7 +616,7 @@ def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
         input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]
 
         # Generation with continuous batching
-        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
+        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
         model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 30
         model.generation_config.do_sample = False
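
For context, a minimal sketch (not part of this commit) of the behavior the change guards: load the model with the flash_attention_2 implementation and greedily decode the same prompt twice in one batch, expecting identical outputs. The prompt string and the use of plain generate() in place of the test's continuous batching path are assumptions for illustration; running it requires a CUDA device and the flash-attn package.

# Illustrative sketch only: two copies of the same prompt decoded greedily with
# FlashAttention-2 should produce identical text.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

# FlashAttention-2 requires half precision, hence the explicit dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda").eval()
model.generation_config.max_new_tokens = 30
model.generation_config.do_sample = False  # greedy, so parallel copies should match

prompt = "What is the capital of France?"  # hypothetical prompt, not taken from the test
inputs = tokenizer([prompt, prompt], return_tensors="pt", padding=True).to("cuda")
outputs = model.generate(**inputs, pad_token_id=tokenizer.pad_token_id)
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert texts[0] == texts[1], "parallel decodes diverged"
print(texts[0])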
