 from transformers.testing_utils import (
     Expectations,
     require_deterministic_for_xpu,
+    require_flash_attn,
     require_torch_accelerator,
     slow,
     torch_device,
@@ -601,10 +602,9 @@ def test_block_sharing_with_hybrid_model(self) -> None:
 
         return self._test_block_sharing(model_id, num_layer_groups, input_msg, expected_generated_tokens)
 
-    # The test always passes on H100 with torch 2.9, but only passed case 0 on A100 with torch 2.6 and fails on A100
-    # with torch 2.9. This might be due to a GPU diff, so test might be flaky on the CI which runs on A10.
     @parameterized.expand([True, False])
     @require_torch_accelerator
+    @require_flash_attn  # otherwise the test can fail because attention bias has a very slight impact on SDPA and eager
     def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
         model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
         tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
@@ -616,7 +616,7 @@ def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
         input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]
 
         # Generation with continuous batching
-        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
+        model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
         model = model.to(torch_device).eval()
         model.generation_config.max_new_tokens = 30
         model.generation_config.do_sample = False