From 6d8170e52899e57f98acf85078a09fcb0c6746aa Mon Sep 17 00:00:00 2001 From: Dan <31395415+cakedan@users.noreply.github.com> Date: Thu, 9 Oct 2025 01:12:19 -0700 Subject: [PATCH 1/3] Update pipeline_qwenimage_edit_plus.py --- .../pipelines/qwenimage/pipeline_qwenimage_edit_plus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..952c43b54d82 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -677,7 +677,7 @@ def __call__( condition_width, condition_height = calculate_dimensions( CONDITION_IMAGE_SIZE, image_width / image_height ) - vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + vae_width, vae_height = width, height condition_image_sizes.append((condition_width, condition_height)) vae_image_sizes.append((vae_width, vae_height)) condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) From 6b77b72fd35a4ba4c0011bf09f14731d468154e9 Mon Sep 17 00:00:00 2001 From: Dan <31395415+cakedan@users.noreply.github.com> Date: Fri, 16 Jan 2026 20:15:30 -0800 Subject: [PATCH 2/3] Revert " Fix QwenImage txt_seq_lens handling (#12702)" This reverts commit dad5cb55e6ade24fb397525eb023ad4eba37019d. --- docs/source/en/api/pipelines/qwenimage.md | 38 +--- .../train_dreambooth_lora_qwen_image.py | 2 + src/diffusers/models/attention_dispatch.py | 78 +------ .../controlnets/controlnet_qwenimage.py | 71 ++---- .../transformers/transformer_qwenimage.py | 203 ++++-------------- .../qwenimage/before_denoise.py | 41 ++++ .../modular_pipelines/qwenimage/denoise.py | 11 +- .../pipelines/qwenimage/pipeline_qwenimage.py | 7 + .../pipeline_qwenimage_controlnet.py | 3 + .../pipeline_qwenimage_controlnet_inpaint.py | 3 + .../qwenimage/pipeline_qwenimage_edit.py | 7 + .../pipeline_qwenimage_edit_inpaint.py | 7 + .../qwenimage/pipeline_qwenimage_edit_plus.py | 14 +- .../qwenimage/pipeline_qwenimage_img2img.py | 7 + .../qwenimage/pipeline_qwenimage_inpaint.py | 7 + .../qwenimage/pipeline_qwenimage_layered.py | 8 +- .../test_models_transformer_qwenimage.py | 178 +-------------- 17 files changed, 172 insertions(+), 513 deletions(-) diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index ee3dd3b28e4d..1dcf3f944fc0 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -108,46 +108,12 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` -## Performance - -### torch.compile - -Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): - -```python -import torch -from diffusers import QwenImagePipeline - -pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") -pipe.transformer = torch.compile(pipe.transformer) - -# First call triggers compilation (~7s overhead) -# Subsequent calls run at ~2.4x faster -image = pipe("a cat", num_inference_steps=50).images[0] -``` - -### Batched Inference with Variable-Length Prompts - -When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. - -```python -# CFG with different prompt lengths works correctly -image = pipe( - prompt="A cat", - negative_prompt="blurry, low quality, distorted", - true_cfg_scale=3.5, - num_inference_steps=50, -).images[0] -``` - -For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). - ## QwenImagePipeline [[autodoc]] QwenImagePipeline diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index ea9b137b0acd..53b01bf0cfc8 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1513,12 +1513,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[3], width=model_input.shape[4], ) + print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 61c478b03c4f..0c7c2708adb9 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2128,43 +2128,6 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return out -def _prepare_additive_attn_mask( - attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True -) -> torch.Tensor: - """ - Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. - - This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. - - Args: - attn_mask: 2D tensor [batch_size, seq_len_k] - - Boolean: True means attend, False means mask out - - Additive: 0.0 means attend, -inf means mask out - target_dtype: The dtype to convert the mask to (usually query.dtype) - reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting - - Returns: - Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if - reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. - """ - # Check if the mask is boolean or already additive - if attn_mask.dtype == torch.bool: - # Convert boolean to additive: True -> 0.0, False -> -inf - attn_mask = torch.where(attn_mask, 0.0, float("-inf")) - # Convert to target dtype - attn_mask = attn_mask.to(dtype=target_dtype) - else: - # Already additive mask - just ensure correct dtype - attn_mask = attn_mask.to(dtype=target_dtype) - - # Optionally reshape to 4D for broadcasting in attention mechanisms - if reshape_4d: - batch_size, seq_len_k = attn_mask.shape - attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) - - return attn_mask - - @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -2184,19 +2147,6 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") - - # Reshape 2D mask to 4D for SDPA - # SDPA accepts both boolean masks (torch.bool) and additive masks (float) - if ( - attn_mask is not None - and attn_mask.ndim == 2 - and attn_mask.shape[0] == query.shape[0] - and attn_mask.shape[1] == key.shape[1] - ): - # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] - # SDPA handles both boolean and additive masks correctly - attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) - if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -2763,34 +2713,10 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - # Convert 2D mask to 4D for xformers - # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask) - # xformers requires 4D additive masks [batch, heads, seq_q, seq_k] - # Need memory alignment - create larger tensor and slice for alignment - original_seq_len = attn_mask.size(1) - aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 - - # Create aligned 4D tensor and slice to ensure proper memory layout - aligned_mask = torch.zeros( - (batch_size, num_heads_q, seq_len_q, aligned_seq_len), - dtype=query.dtype, - device=query.device, - ) - # Convert to 4D additive mask (handles both boolean and additive inputs) - mask_additive = _prepare_additive_attn_mask( - attn_mask, target_dtype=query.dtype - ) # [batch, 1, 1, seq_len_k] - # Broadcast to [batch, heads, seq_q, seq_len_k] - aligned_mask[:, :, :, :original_seq_len] = mask_additive - # Mask out the padding (already -inf from zeros -> where with default) - aligned_mask[:, :, :, original_seq_len:] = float("-inf") - - # Slice to actual size with proper alignment - attn_mask = aligned_mask[:, :, :, :seq_len_kv] + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - elif attn_mask.ndim == 4: - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index fa374285eec1..86971271788f 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -31,7 +31,6 @@ QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, - compute_text_seq_len_from_mask, ) @@ -137,7 +136,7 @@ def forward( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`QwenImageControlNetModel`] forward method. + The [`FluxTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -148,39 +147,24 @@ def forward( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): - Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. - Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern - (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. - img_shapes (`List[Tuple[int, int, int]]`, *optional*): - Image shapes for RoPE computation. - txt_seq_lens (`List[int]`, *optional*): - **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence - length. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. Returns: - If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where - the first element is the controlnet block samples. + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. """ - # Handle deprecated txt_seq_lens parameter - if txt_seq_lens is not None: - deprecate( - "txt_seq_lens", - "0.39.0", - "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " - "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " - "and `encoder_hidden_states_mask`.", - standard_warn=False, - ) - if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -202,47 +186,32 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( - encoder_hidden_states, encoder_hidden_states_mask - ) - - image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Construct joint attention mask once to avoid reconstructing in every block - block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} - if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] - batch_size, image_seq_len = hidden_states.shape[:2] - image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) - block_attention_kwargs["attention_mask"] = joint_attention_mask - block_samples = () - for block in self.transformer_blocks: + for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) + encoder_hidden_states_mask, temb, image_rotary_emb, - block_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) + encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=block_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -298,15 +267,6 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: - if txt_seq_lens is not None: - deprecate( - "txt_seq_lens", - "0.39.0", - "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " - "removed in version 0.39.0. The text sequence length is now automatically inferred from " - "`encoder_hidden_states` and `encoder_hidden_states_mask`.", - standard_warn=False, - ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: @@ -321,6 +281,7 @@ def forward( encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..06006430d1cd 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -142,32 +142,6 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) -def compute_text_seq_len_from_mask( - encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] -) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. - """ - batch_size, text_seq_len = encoder_hidden_states.shape[:2] - if encoder_hidden_states_mask is None: - return text_seq_len, None, None - - if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): - raise ValueError( - f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " - f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." - ) - - if encoder_hidden_states_mask.dtype != torch.bool: - encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) - - position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) - active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) - has_active = encoder_hidden_states_mask.any(dim=1) - per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) - return text_seq_len, per_sample_len, encoder_hidden_states_mask - - class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, use_additional_t_cond=False): super().__init__() @@ -233,50 +207,21 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_lens: Optional[List[int]] = None, - device: torch.device = None, - max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, + txt_seq_lens: List[int], + device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`List[int]`, *optional*, **Deprecated**): - Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. - device: (`torch.device`, *optional*): + txt_seq_lens (`List[int]`): + A list of integers of length batch_size representing the length of each text prompt. + device: (`torch.device`): The device on which to perform the RoPE computation. - max_txt_seq_len (`int` or `torch.Tensor`, *optional*): - The maximum text sequence length for RoPE computation. This should match the encoder hidden states - sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ - # Handle deprecated txt_seq_lens parameter - if txt_seq_lens is not None: - deprecate( - "txt_seq_lens", - "0.39.0", - "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " - "Please use `max_txt_seq_len` instead. " - "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", - standard_warn=False, - ) - if max_txt_seq_len is None: - # Use max of txt_seq_lens for backward compatibility - max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens - - if max_txt_seq_len is None: - raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") - - # Validate batch inference with variable-sized images - if isinstance(video_fhw, list) and len(video_fhw) > 1: - # Check if all instances have the same size - first_fhw = video_fhw[0] - if not all(fhw == first_fhw for fhw in video_fhw): - logger.warning( - "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " - "All images in the batch should have the same dimensions (frame, height, width). " - f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " - "for RoPE computation, which may lead to incorrect results for other images in the batch." - ) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -288,7 +233,8 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx, device) + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) vid_freqs.append(video_freq) if self.scale_rope: @@ -296,23 +242,17 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs( - self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None - ) -> torch.Tensor: + def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs - - freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -364,35 +304,14 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def forward( - self, - video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - max_txt_seq_len: Union[int, torch.Tensor], - device: torch.device = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, video_fhw, txt_seq_lens, device): """ - Args: - video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): - A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer - structures. - max_txt_seq_len (`int` or `torch.Tensor`): - The maximum text sequence length for RoPE computation. This should match the encoder hidden states - sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). - device: (`torch.device`, *optional*): - The device on which to perform the RoPE computation. + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text """ - # Validate batch inference with variable-sized images - # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers - if isinstance(video_fhw, list) and len(video_fhw) > 1: - # Check if this is batch inference (list of layer lists/tuples) - first_entry = video_fhw[0] - if not all(entry == first_entry for entry in video_fhw): - logger.warning( - "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " - "All images in the batch should have the same layer structure. " - f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} " - "for RoPE computation, which may lead to incorrect results for other images in the batch." - ) + if self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -405,10 +324,11 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: - video_freq = self._compute_video_freqs(frame, height, width, idx, device) + video_freq = self._compute_video_freqs(frame, height, width, idx) else: ### For the condition image, we set the layer index to -1 - video_freq = self._compute_condition_freqs(frame, height, width, device) + video_freq = self._compute_condition_freqs(frame, height, width) + video_freq = video_freq.to(device) vid_freqs.append(video_freq) if self.scale_rope: @@ -417,21 +337,17 @@ def forward( max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): + def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs - - freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -447,13 +363,10 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) - def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): + def _compute_condition_freqs(self, frame, height, width): seq_lens = frame * height * width - pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs - neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs - - freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -541,6 +454,7 @@ def __call__( joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) + # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, joint_key, @@ -851,25 +765,14 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): - Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. - Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern - (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. - img_shapes (`List[Tuple[int, int, int]]`, *optional*): - Image shapes for RoPE computation. - txt_seq_lens (`List[int]`, *optional*, **Deprecated**): - Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be - used to compute RoPE sequence length. - guidance (`torch.Tensor`, *optional*): - Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - controlnet_block_samples (*optional*): - ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -878,15 +781,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if txt_seq_lens is not None: - deprecate( - "txt_seq_lens", - "0.39.0", - "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " - "Please use `encoder_hidden_states_mask` instead. " - "The mask-based approach is more flexible and supports variable-length sequences.", - standard_warn=False, - ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -919,11 +813,6 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( - encoder_hidden_states, encoder_hidden_states_mask - ) - if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -933,17 +822,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) - - # Construct joint attention mask once to avoid reconstructing in every block - # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility - block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} - if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] - batch_size, image_seq_len = hidden_states.shape[:2] - image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) - block_attention_kwargs["attention_mask"] = joint_attention_mask + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -951,10 +830,10 @@ def forward( block, hidden_states, encoder_hidden_states, - None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) + encoder_hidden_states_mask, temb, image_rotary_emb, - block_attention_kwargs, + attention_kwargs, modulate_index, ) @@ -962,10 +841,10 @@ def forward( encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) + encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=block_attention_kwargs, + joint_attention_kwargs=attention_kwargs, modulate_index=modulate_index, ) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index d9c8cbb01d18..e14164229c7f 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -682,6 +682,18 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -696,6 +708,14 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ) ] ] * block_state.batch_size + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) self.set_block_state(state, block_state) @@ -730,6 +750,18 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -751,6 +783,15 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ] ] * block_state.batch_size + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index d6bcb4a94f80..eb1e5a341c68 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -155,7 +155,7 @@ def inputs(self) -> List[InputParam]: kwargs_type="denoiser_input_fields", description=( "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds." + "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." ), ), ] @@ -182,6 +182,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, + txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -253,6 +254,10 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), + "txt_seq_lens": ( + getattr(block_state, "txt_seq_lens", None), + getattr(block_state, "negative_txt_seq_lens", None), + ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) @@ -353,6 +358,10 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), + "txt_seq_lens": ( + getattr(block_state, "txt_seq_lens", None), + getattr(block_state, "negative_txt_seq_lens", None), + ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bc3ce84e1019..33dc2039b986 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -672,6 +672,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -690,6 +695,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -703,6 +709,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index ce6fc974a56e..5111096d93c1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -909,6 +909,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -919,6 +920,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -933,6 +935,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 77d78a5ca7a1..102a813ab582 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -852,6 +852,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -862,6 +863,7 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, + txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -876,6 +878,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index dd723460a59e..ed37b238c8c9 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -793,6 +793,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -816,6 +821,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -830,6 +836,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index cf467203a9d2..d54d1881fa4e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -1008,6 +1008,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1030,6 +1035,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1044,6 +1050,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 811b683e84b2..952c43b54d82 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -663,13 +663,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # QwenImageEditPlusPipeline does not currently support batch_size > 1 - if batch_size > 1: - raise ValueError( - f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " - "Please process prompts one at a time." - ) - device = self._execution_device # 3. Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): @@ -784,6 +777,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -807,6 +805,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -821,6 +820,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index e0b41b8b8799..cb4c5d8016bb 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -775,6 +775,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -792,6 +797,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -805,6 +811,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 83f02539b1ba..1915c27eb2bb 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -944,6 +944,11 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -961,6 +966,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -974,6 +980,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 53d2c169ee63..7bb12c26baa4 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -781,6 +781,10 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) # 6. Denoising loop self.scheduler.set_begin_index(0) @@ -805,6 +809,7 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -820,6 +825,7 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, + txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -879,7 +885,7 @@ def __call__( latents = latents[:, :, 1:] # remove the first frame as it is the orgin input - latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) + latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 384954dfbad7..b24fa90503ef 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,10 +15,10 @@ import unittest +import pytest import torch from diffusers import QwenImageTransformer2DModel -from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -68,6 +68,7 @@ def prepare_dummy_input(self, height=4, width=4): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, + "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -90,180 +91,6 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - def test_infers_text_seq_len_from_mask(self): - """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) - encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid - - rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( - inputs["encoder_hidden_states"], encoder_hidden_states_mask - ) - - # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) - self.assertIsInstance(rope_text_seq_len, int) - - # Verify per_sample_len is computed correctly (max valid position + 1 = 2) - self.assertIsInstance(per_sample_len, torch.Tensor) - self.assertEqual(int(per_sample_len.max().item()), 2) - - # Verify mask is normalized to bool dtype - self.assertTrue(normalized_mask.dtype == torch.bool) - self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values - - # Verify rope_text_seq_len is at least the sequence length - self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) - - # Test 2: Verify model runs successfully with inferred values - inputs["encoder_hidden_states_mask"] = normalized_mask - with torch.no_grad(): - output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) - - # Test 3: Different mask pattern (padding at beginning) - encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding - encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid - - rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( - inputs["encoder_hidden_states"], encoder_hidden_states_mask2 - ) - - # Max valid position is 6 (last token), so per_sample_len should be 7 - self.assertEqual(int(per_sample_len2.max().item()), 7) - self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values - - # Test 4: No mask provided (None case) - rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( - inputs["encoder_hidden_states"], None - ) - self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(rope_text_seq_len_none, int) - self.assertIsNone(per_sample_len_none) - self.assertIsNone(normalized_mask_none) - - def test_non_contiguous_attention_mask(self): - """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. - encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - # Pattern: [True, False, True, False, True, False, False] - encoder_hidden_states_mask[:, 1] = 0 - encoder_hidden_states_mask[:, 3] = 0 - encoder_hidden_states_mask[:, 5:] = 0 - - inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( - inputs["encoder_hidden_states"], encoder_hidden_states_mask - ) - self.assertEqual(int(per_sample_len.max().item()), 5) - self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(inferred_rope_len, int) - self.assertTrue(normalized_mask.dtype == torch.bool) - - inputs["encoder_hidden_states_mask"] = normalized_mask - - with torch.no_grad(): - output = model(**inputs) - - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) - - def test_txt_seq_lens_deprecation(self): - """Test that passing txt_seq_lens raises a deprecation warning.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - # Prepare inputs with txt_seq_lens (deprecated parameter) - txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] - - # Remove encoder_hidden_states_mask to use the deprecated path - inputs_with_deprecated = inputs.copy() - inputs_with_deprecated.pop("encoder_hidden_states_mask") - inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens - - # Test that deprecation warning is raised - with self.assertWarns(FutureWarning) as warning_context: - with torch.no_grad(): - output = model(**inputs_with_deprecated) - - # Verify the warning message mentions the deprecation - warning_message = str(warning_context.warning) - self.assertIn("txt_seq_lens", warning_message) - self.assertIn("deprecated", warning_message) - self.assertIn("encoder_hidden_states_mask", warning_message) - - # Verify the model still works correctly despite the deprecation - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) - - def test_layered_model_with_mask(self): - """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" - # Create layered model config - init_dict = { - "patch_size": 2, - "in_channels": 16, - "out_channels": 4, - "num_layers": 2, - "attention_head_dim": 16, - "num_attention_heads": 3, - "joint_attention_dim": 16, - "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) - "use_layer3d_rope": True, # Enable layered RoPE - "use_additional_t_cond": True, # Enable additional time conditioning - } - - model = self.model_class(**init_dict).to(torch_device) - - # Verify the model uses QwenEmbedLayer3DRope - from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope - - self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) - - # Test single generation with layered structure - batch_size = 1 - text_seq_len = 7 - img_h, img_w = 4, 4 - layers = 4 - - # For layered model: (layers + 1) because we have N layers + 1 combined image - hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) - encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) - - # Create mask with some padding - encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) - encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens - - timestep = torch.tensor([1.0]).to(torch_device) - - # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) - addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) - - # Layer structure: 4 layers + 1 condition image - img_shapes = [ - [ - (1, img_h, img_w), # layer 0 - (1, img_h, img_w), # layer 1 - (1, img_h, img_w), # layer 2 - (1, img_h, img_w), # layer 3 - (1, img_h, img_w), # condition image (last one gets special treatment) - ] - ] - - with torch.no_grad(): - output = model( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - timestep=timestep, - img_shapes=img_shapes, - additional_t_cond=addition_t_cond, - ) - - self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) - class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel @@ -274,5 +101,6 @@ def prepare_init_args_and_inputs_for_common(self): def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) + @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() From 50b4a8dff1d7831f39c901621b16deb4c8896bb7 Mon Sep 17 00:00:00 2001 From: Dan <31395415+cakedan@users.noreply.github.com> Date: Fri, 16 Jan 2026 20:17:23 -0800 Subject: [PATCH 3/3] Reapply " Fix QwenImage txt_seq_lens handling (#12702)" This reverts commit 6b77b72fd35a4ba4c0011bf09f14731d468154e9. --- docs/source/en/api/pipelines/qwenimage.md | 38 +++- .../train_dreambooth_lora_qwen_image.py | 2 - src/diffusers/models/attention_dispatch.py | 78 ++++++- .../controlnets/controlnet_qwenimage.py | 71 ++++-- .../transformers/transformer_qwenimage.py | 203 ++++++++++++++---- .../qwenimage/before_denoise.py | 41 ---- .../modular_pipelines/qwenimage/denoise.py | 11 +- .../pipelines/qwenimage/pipeline_qwenimage.py | 7 - .../pipeline_qwenimage_controlnet.py | 3 - .../pipeline_qwenimage_controlnet_inpaint.py | 3 - .../qwenimage/pipeline_qwenimage_edit.py | 7 - .../pipeline_qwenimage_edit_inpaint.py | 7 - .../qwenimage/pipeline_qwenimage_edit_plus.py | 14 +- .../qwenimage/pipeline_qwenimage_img2img.py | 7 - .../qwenimage/pipeline_qwenimage_inpaint.py | 7 - .../qwenimage/pipeline_qwenimage_layered.py | 8 +- .../test_models_transformer_qwenimage.py | 178 ++++++++++++++- 17 files changed, 513 insertions(+), 172 deletions(-) diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index 1dcf3f944fc0..ee3dd3b28e4d 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` +## Performance + +### torch.compile + +Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): + +```python +import torch +from diffusers import QwenImagePipeline + +pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") +pipe.transformer = torch.compile(pipe.transformer) + +# First call triggers compilation (~7s overhead) +# Subsequent calls run at ~2.4x faster +image = pipe("a cat", num_inference_steps=50).images[0] +``` + +### Batched Inference with Variable-Length Prompts + +When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. + +```python +# CFG with different prompt lengths works correctly +image = pipe( + prompt="A cat", + negative_prompt="blurry, low quality, distorted", + true_cfg_scale=3.5, + num_inference_steps=50, +).images[0] +``` + +For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). + ## QwenImagePipeline [[autodoc]] QwenImagePipeline diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 53b01bf0cfc8..ea9b137b0acd 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[3], width=model_input.shape[4], ) - print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 0c7c2708adb9..61c478b03c4f 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2128,6 +2128,43 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return out +def _prepare_additive_attn_mask( + attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True +) -> torch.Tensor: + """ + Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. + + This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. + + Args: + attn_mask: 2D tensor [batch_size, seq_len_k] + - Boolean: True means attend, False means mask out + - Additive: 0.0 means attend, -inf means mask out + target_dtype: The dtype to convert the mask to (usually query.dtype) + reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting + + Returns: + Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if + reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. + """ + # Check if the mask is boolean or already additive + if attn_mask.dtype == torch.bool: + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + # Convert to target dtype + attn_mask = attn_mask.to(dtype=target_dtype) + else: + # Already additive mask - just ensure correct dtype + attn_mask = attn_mask.to(dtype=target_dtype) + + # Optionally reshape to 4D for broadcasting in attention mechanisms + if reshape_4d: + batch_size, seq_len_k = attn_mask.shape + attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) + + return attn_mask + + @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -2147,6 +2184,19 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") + + # Reshape 2D mask to 4D for SDPA + # SDPA accepts both boolean masks (torch.bool) and additive masks (float) + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] + # SDPA handles both boolean and additive masks correctly + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) + if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -2713,10 +2763,34 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + # Convert 2D mask to 4D for xformers + # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask) + # xformers requires 4D additive masks [batch, heads, seq_q, seq_k] + # Need memory alignment - create larger tensor and slice for alignment + original_seq_len = attn_mask.size(1) + aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 + + # Create aligned 4D tensor and slice to ensure proper memory layout + aligned_mask = torch.zeros( + (batch_size, num_heads_q, seq_len_q, aligned_seq_len), + dtype=query.dtype, + device=query.device, + ) + # Convert to 4D additive mask (handles both boolean and additive inputs) + mask_additive = _prepare_additive_attn_mask( + attn_mask, target_dtype=query.dtype + ) # [batch, 1, 1, seq_len_k] + # Broadcast to [batch, heads, seq_q, seq_len_k] + aligned_mask[:, :, :, :original_seq_len] = mask_additive + # Mask out the padding (already -inf from zeros -> where with default) + aligned_mask[:, :, :, original_seq_len:] = float("-inf") + + # Slice to actual size with proper alignment + attn_mask = aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + elif attn_mask.ndim == 4: + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 86971271788f..fa374285eec1 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -31,6 +31,7 @@ QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, + compute_text_seq_len_from_mask, ) @@ -136,7 +137,7 @@ def forward( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`QwenImageControlNetModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -147,24 +148,39 @@ def forward( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*): + **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence + length. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where + the first element is the controlnet block samples. """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " + "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " + "and `encoder_hidden_states_mask`.", + standard_warn=False, + ) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -186,32 +202,47 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Construct joint attention mask once to avoid reconstructing in every block + block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask + block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, + block_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -267,6 +298,15 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " + "removed in version 0.39.0. The text sequence length is now automatically inferred from " + "`encoder_hidden_states` and `encoder_hidden_states_mask`.", + standard_warn=False, + ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: @@ -281,7 +321,6 @@ def forward( encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 06006430d1cd..cf11d8e01fb4 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ -142,6 +142,32 @@ def apply_rotary_emb_qwen( return x_out.type_as(x) +def compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor] +) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask. + """ + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." + ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len)) + return text_seq_len, per_sample_len, encoder_hidden_states_mask + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, use_additional_t_cond=False): super().__init__() @@ -207,21 +233,50 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - txt_seq_lens: List[int], - device: torch.device, + txt_seq_lens: Optional[List[int]] = None, + device: torch.device = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`List[int]`): - A list of integers of length batch_size representing the length of each text prompt. - device: (`torch.device`): + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. + device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `max_txt_seq_len` instead. " + "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", + standard_warn=False, + ) + if max_txt_seq_len is None: + # Use max of txt_seq_lens for backward compatibility + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + + # Validate batch inference with variable-sized images + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if all instances have the same size + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -233,8 +288,7 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx) - video_freq = video_freq.to(device) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -242,17 +296,23 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -304,14 +364,35 @@ def rope_params(self, index, dim, theta=10000): freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def forward(self, video_fhw, txt_seq_lens, device): + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + max_txt_seq_len: Union[int, torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: - txt_length: [bs] a list of 1 integers representing the length of the text + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer + structures. + max_txt_seq_len (`int` or `torch.Tensor`): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Validate batch inference with variable-sized images + # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if this is batch inference (list of layer lists/tuples) + first_entry = video_fhw[0] + if not all(entry == first_entry for entry in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " + "All images in the batch should have the same layer structure. " + f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -324,11 +405,10 @@ def forward(self, video_fhw, txt_seq_lens, device): for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: - video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) else: ### For the condition image, we set the layer index to -1 - video_freq = self._compute_condition_freqs(frame, height, width) - video_freq = video_freq.to(device) + video_freq = self._compute_condition_freqs(frame, height, width, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -337,17 +417,21 @@ def forward(self, video_fhw, txt_seq_lens, device): max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) - def _compute_video_freqs(self, frame, height, width, idx=0): + def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -363,10 +447,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0): return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) - def _compute_condition_freqs(self, frame, height, width): + def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -454,7 +541,6 @@ def __call__( joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, joint_key, @@ -765,14 +851,25 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): - Mask of the input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be + used to compute RoPE sequence length. + guidance (`torch.Tensor`, *optional*): + Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_block_samples (*optional*): + ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -781,6 +878,15 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `encoder_hidden_states_mask` instead. " + "The mask-based approach is more flexible and supports variable-length sequences.", + standard_warn=False, + ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -813,6 +919,11 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 @@ -822,7 +933,17 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + + # Construct joint attention mask once to avoid reconstructing in every block + # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility + block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -830,10 +951,10 @@ def forward( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, - attention_kwargs, + block_attention_kwargs, modulate_index, ) @@ -841,10 +962,10 @@ def forward( encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, modulate_index=modulate_index, ) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index e14164229c7f..d9c8cbb01d18 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -682,18 +682,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -708,14 +696,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ) ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) self.set_block_state(state, block_state) @@ -750,18 +730,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), - OutputParam( - name="txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the prompt embeds, used for RoPE calculation", - ), - OutputParam( - name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", - type_hint=List[int], - description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", - ), ] def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -783,15 +751,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - ] ] * block_state.batch_size - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) - self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index eb1e5a341c68..d6bcb4a94f80 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -155,7 +155,7 @@ def inputs(self) -> List[InputParam]: kwargs_type="denoiser_input_fields", description=( "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." + "It should contain prompt_embeds/negative_prompt_embeds." ), ), ] @@ -182,7 +182,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, - txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -254,10 +253,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) @@ -358,10 +353,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b986..bc3ce84e1019 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -672,11 +672,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -695,7 +690,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -709,7 +703,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93c1..ce6fc974a56e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -909,7 +909,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -920,7 +919,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +933,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab582..77d78a5ca7a1 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -852,7 +852,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -863,7 +862,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +876,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8c9..dd723460a59e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -793,11 +793,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +816,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -836,7 +830,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa4e..cf467203a9d2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -1008,11 +1008,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1035,7 +1030,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1050,7 +1044,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 952c43b54d82..811b683e84b2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -663,6 +663,13 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + # QwenImageEditPlusPipeline does not currently support batch_size > 1 + if batch_size > 1: + raise ValueError( + f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " + "Please process prompts one at a time." + ) + device = self._execution_device # 3. Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): @@ -777,11 +784,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -805,7 +807,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -820,7 +821,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016bb..e0b41b8b8799 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -775,11 +775,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -797,7 +792,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -811,7 +805,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2bb..83f02539b1ba 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -944,11 +944,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -966,7 +961,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -980,7 +974,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 7bb12c26baa4..53d2c169ee63 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -781,10 +781,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) # 6. Denoising loop self.scheduler.set_begin_index(0) @@ -809,7 +805,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -825,7 +820,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, additional_t_cond=is_rgb, return_dict=False, @@ -885,7 +879,7 @@ def __call__( latents = latents[:, :, 1:] # remove the first frame as it is the orgin input - latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w) + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503ef..384954dfbad7 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,10 +15,10 @@ import unittest -import pytest import torch from diffusers import QwenImageTransformer2DModel +from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -68,7 +68,6 @@ def prepare_dummy_input(self, height=4, width=4): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, - "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -91,6 +90,180 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_infers_text_seq_len_from_mask(self): + """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + + rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + + # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, int) + + # Verify per_sample_len is computed correctly (max valid position + 1 = 2) + self.assertIsInstance(per_sample_len, torch.Tensor) + self.assertEqual(int(per_sample_len.max().item()), 2) + + # Verify mask is normalized to bool dtype + self.assertTrue(normalized_mask.dtype == torch.bool) + self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values + + # Verify rope_text_seq_len is at least the sequence length + self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + + # Test 2: Verify model runs successfully with inferred values + inputs["encoder_hidden_states_mask"] = normalized_mask + with torch.no_grad(): + output = model(**inputs) + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Different mask pattern (padding at beginning) + encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding + encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + + rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask2 + ) + + # Max valid position is 6 (last token), so per_sample_len should be 7 + self.assertEqual(int(per_sample_len2.max().item()), 7) + self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + + # Test 4: No mask provided (None case) + rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], None + ) + self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(rope_text_seq_len_none, int) + self.assertIsNone(per_sample_len_none) + self.assertIsNone(normalized_mask_none) + + def test_non_contiguous_attention_mask(self): + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + # Pattern: [True, False, True, False, True, False, False] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + + inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + self.assertEqual(int(per_sample_len.max().item()), 5) + self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(inferred_rope_len, int) + self.assertTrue(normalized_mask.dtype == torch.bool) + + inputs["encoder_hidden_states_mask"] = normalized_mask + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_txt_seq_lens_deprecation(self): + """Test that passing txt_seq_lens raises a deprecation warning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Prepare inputs with txt_seq_lens (deprecated parameter) + txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] + + # Remove encoder_hidden_states_mask to use the deprecated path + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens + + # Test that deprecation warning is raised + with self.assertWarns(FutureWarning) as warning_context: + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + # Verify the warning message mentions the deprecation + warning_message = str(warning_context.warning) + self.assertIn("txt_seq_lens", warning_message) + self.assertIn("deprecated", warning_message) + self.assertIn("encoder_hidden_states_mask", warning_message) + + # Verify the model still works correctly despite the deprecation + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_layered_model_with_mask(self): + """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" + # Create layered model config + init_dict = { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 3, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) + "use_layer3d_rope": True, # Enable layered RoPE + "use_additional_t_cond": True, # Enable additional time conditioning + } + + model = self.model_class(**init_dict).to(torch_device) + + # Verify the model uses QwenEmbedLayer3DRope + from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope + + self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + + # Test single generation with layered structure + batch_size = 1 + text_seq_len = 7 + img_h, img_w = 4, 4 + layers = 4 + + # For layered model: (layers + 1) because we have N layers + 1 combined image + hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) + encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) + + # Create mask with some padding + encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) + encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + + timestep = torch.tensor([1.0]).to(torch_device) + + # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) + addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) + + # Layer structure: 4 layers + 1 condition image + img_shapes = [ + [ + (1, img_h, img_w), # layer 0 + (1, img_h, img_w), # layer 1 + (1, img_h, img_w), # layer 2 + (1, img_h, img_w), # layer 3 + (1, img_h, img_w), # condition image (last one gets special treatment) + ] + ] + + with torch.no_grad(): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + timestep=timestep, + img_shapes=img_shapes, + additional_t_cond=addition_t_cond, + ) + + self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel @@ -101,6 +274,5 @@ def prepare_init_args_and_inputs_for_common(self): def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) - @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break()