diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 10a814b8c..c1f1ed772 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -332,6 +332,7 @@ flow_shift: 3.0
 # Skips the unconditional forward pass on ~35% of steps via residual compensation.
 # See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
 use_cfg_cache: False
+use_kv_cache: False
 use_magcache: False
 magcache_thresh: 0.12
 magcache_K: 2
diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml
index fde9efe8d..416d89ae8 100644
--- a/src/maxdiffusion/configs/base_wan_1_3b.yml
+++ b/src/maxdiffusion/configs/base_wan_1_3b.yml
@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC 
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -286,6 +286,7 @@ flow_shift: 3.0
 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
 use_cfg_cache: False
+use_kv_cache: False
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml
index 0a013285b..5d0df31c7 100644
--- a/src/maxdiffusion/configs/base_wan_27b.yml
+++ b/src/maxdiffusion/configs/base_wan_27b.yml
@@ -307,6 +307,7 @@ boundary_ratio: 0.875
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
 # when predicted output change (based on accumulated latent/timestep drift) is small
 use_sen_cache: False
diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml
index 686f66280..f0ef8300e 100644
--- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -291,6 +291,7 @@ flow_shift: 5.0
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
 use_magcache: False
diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml
index 3eac96ccd..3694d76f1 100644
--- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml
+++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -303,6 +303,7 @@ boundary_ratio: 0.875
 # Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
+use_kv_cache: False
 # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
 use_sen_cache: False
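The new flag defaults to `False` in every config, so existing runs are unaffected. A minimal usage sketch, assuming a WAN pipeline instance has already been constructed the usual way (pipeline construction is outside this diff; only the `use_kv_cache` parameter comes from it):

```python
# Hypothetical invocation; `pipeline` is assumed to be an already-loaded
# WAN 2.1/2.2 pipeline object. The new keyword is the only addition.
video = pipeline(
    prompt="a corgi surfing at sunset",
    height=480,
    width=832,
    num_frames=81,
    use_kv_cache=True,  # added by this change; defaults to False
)
```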
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 8c96a299b..e02f8ab46 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -15,7 +15,7 @@
 import contextlib
 import functools
 import math
-from typing import Optional, Callable, Tuple
+from typing import Optional, Callable, Tuple, Dict
 import flax.linen as nn
 from flax import nnx
 import jax
@@ -198,7 +198,8 @@ def convert_to_tokamax_splash_config(
     residual_checkpoint_name: str | None = None,
     attn_logits_soft_cap: float | None = None,
     fuse_reciprocal: bool = True,
-    use_base2_exp: bool = False,
+    use_base2_exp: bool = True,
+    use_experimental_scheduler: bool = True,
     max_logit_const: float | None = None,
     interpret: bool = False,
     dq_reduction_steps: int | None = None,
@@ -221,6 +222,7 @@
       attn_logits_soft_cap=attn_logits_soft_cap,
       fuse_reciprocal=fuse_reciprocal,
       use_base2_exp=use_base2_exp,
+      use_experimental_scheduler=use_experimental_scheduler,
       max_logit_const=max_logit_const,
       interpret=interpret,
       dq_reduction_steps=dq_reduction_steps,
@@ -1132,6 +1134,7 @@ def __call__(
       encoder_attention_mask: Optional[jax.Array] = None,
       deterministic: bool = True,
       rngs: nnx.Rngs = None,
+      cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
   ) -> jax.Array:
     axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
@@ -1146,16 +1149,22 @@
     if not is_i2v_cross_attention:
       with jax.named_scope("query_proj"):
         query_proj = self.query(hidden_states)
-      with jax.named_scope("key_proj"):
-        key_proj = self.key(encoder_hidden_states)
-      with jax.named_scope("value_proj"):
-        value_proj = self.value(encoder_hidden_states)
-
+
       if self.qk_norm:
         with self.conditional_named_scope("attn_q_norm"):
           query_proj = self.norm_q(query_proj)
-        with self.conditional_named_scope("attn_k_norm"):
-          key_proj = self.norm_k(key_proj)
+
+      if not is_self_attention and cached_kv is not None and "text" in cached_kv:
+        key_proj, value_proj = cached_kv["text"]
+      else:
+        with jax.named_scope("key_proj"):
+          key_proj = self.key(encoder_hidden_states)
+        with jax.named_scope("value_proj"):
+          value_proj = self.value(encoder_hidden_states)
+
+        if self.qk_norm:
+          with self.conditional_named_scope("attn_k_norm"):
+            key_proj = self.norm_k(key_proj)

       if rotary_emb is not None:
         with self.conditional_named_scope("attn_rope"):
@@ -1170,7 +1179,7 @@
         value_proj = checkpoint_name(value_proj, "value_proj")

       with jax.named_scope("apply_attention"):
-        attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+        attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask)
     else:
       # NEW PATH for I2V CROSS-ATTENTION
@@ -1197,9 +1206,11 @@
       # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384
       if encoder_attention_mask is not None:
         encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len]
+        encoder_attention_mask_text = encoder_attention_mask[:, padded_img_len:]
       else:
         # Fallback: no mask means treat all as valid (for dot product attention)
         encoder_attention_mask_img = None
+        encoder_attention_mask_text = None
     else:
       # If no image_seq_len is specified, treat all as text
       encoder_hidden_states_img = None
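The cache-hit branch above relies on a contract: the cache stores keys *after* `norm_k`, so a hit skips the projections and the key norm together. A minimal sketch of that contract with hypothetical shapes and stand-in callables (none of these names are from the file beyond the diff itself):

```python
import jax.numpy as jnp

# Toy cached_kv pytree; shapes are illustrative only.
cached_kv = {
    "text": (
        jnp.zeros((1, 512, 128)),  # key_proj, already passed through norm_k
        jnp.zeros((1, 512, 128)),  # value_proj
    ),
}

def cross_attention_kv(cached_kv, encoder_hidden_states, key, value, norm_k):
  """Mirrors the branch above: reuse cached K/V when present."""
  if cached_kv is not None and "text" in cached_kv:
    return cached_kv["text"]  # post-norm K and V, computed once per prompt
  key_proj = norm_k(key(encoder_hidden_states))
  value_proj = value(encoder_hidden_states)
  return key_proj, value_proj
```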
- with self.conditional_named_scope("norm_add_k"): - key_proj_img = self.norm_added_k(key_proj_img) - with self.conditional_named_scope("add_proj_v"): - value_proj_img = self.add_v_proj(encoder_hidden_states_img) + if cached_kv is not None and "image" in cached_kv: + key_proj_img, value_proj_img = cached_kv["image"] + else: + with self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) + query_proj_img = query_proj_raw # Check norm_added_k too # Checkpointing @@ -1241,7 +1259,7 @@ def __call__( # Attention - tensors are (B, S, D) with self.conditional_named_scope("cross_attn_text_apply"): - attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) + attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask_text) with self.conditional_named_scope("cross_attn_img_apply"): # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens attn_output_img = self.attention_op.apply_attention( @@ -1256,7 +1274,7 @@ def __call__( value_proj_text = checkpoint_name(value_proj_text, "value_proj_text") with self.conditional_named_scope("cross_attn_text_apply"): - attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) + attn_output = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text, attention_mask=encoder_attention_mask) attn_output = attn_output.astype(dtype=dtype) attn_output = checkpoint_name(attn_output, "attn_output") @@ -1267,6 +1285,64 @@ def __call__( hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return hidden_states + def compute_kv( + self, + encoder_hidden_states: jax.Array, + encoder_attention_mask: Optional[jax.Array] = None, + ) -> Dict[str, Tuple[jax.Array, jax.Array]]: + is_i2v_cross_attention = self.added_kv_proj_dim is not None + + if not is_i2v_cross_attention: + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) + + return {"text": (key_proj, value_proj)} + else: + # Image embeddings are padded to multiples of 128 for TPU flash attention + alignment = 128 + if self.image_seq_len is not None: + image_seq_len_actual = self.image_seq_len + else: + image_seq_len_actual = 257 + padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment # 257 -> 384 + + if encoder_attention_mask is None: + padded_img_len = image_seq_len_actual + + encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :] + encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :] + + # Text K/V + with self.conditional_named_scope("proj_key"): + key_proj_text = self.key(encoder_hidden_states_text) + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj_text = self.norm_k(key_proj_text) + with self.conditional_named_scope("proj_value"): + value_proj_text = self.value(encoder_hidden_states_text) + + # Image K/V (only if image embeddings are present) + if encoder_hidden_states_img is not None: + with 
self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) + + return { + "text": (key_proj_text, value_proj_text), + "image": (key_proj_img, value_proj_img) + } + else: + return {"text": (key_proj_text, value_proj_text)} + class FlaxFluxAttention(nn.Module): query_dim: int diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index edb450454..51b85379c 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -152,17 +152,25 @@ def __init__( ) def __call__( - self, timestep: jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None + self, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + skip_embeddings: bool = False, ): timestep = self.timesteps_proj(timestep) temb = self.time_embedder(timestep) with jax.named_scope("time_proj"): timestep_proj = self.time_proj(self.act_fn(temb)) - - encoder_hidden_states = self.text_embedder(encoder_hidden_states) - encoder_attention_mask = None - if encoder_hidden_states_image is not None: - encoder_hidden_states_image, encoder_attention_mask = self.image_embedder(encoder_hidden_states_image) + + if not skip_embeddings: + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + encoder_attention_mask = None + if encoder_hidden_states_image is not None: + encoder_hidden_states_image, encoder_attention_mask = self.image_embedder(encoder_hidden_states_image) + else: + encoder_attention_mask = None + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask @@ -375,6 +383,7 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, encoder_attention_mask: Optional[jax.Array] = None, + cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, ): with self.conditional_named_scope("transformer_block"): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( @@ -414,6 +423,7 @@ def __call__( deterministic=deterministic, rngs=rngs, encoder_attention_mask=encoder_attention_mask, + cached_kv=cached_kv, ) with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output @@ -432,6 +442,13 @@ def __call__( ) return hidden_states + def compute_kv( + self, + encoder_hidden_states: jax.Array, + encoder_attention_mask: Optional[jax.Array] = None, + ) -> Dict[str, Tuple[jax.Array, jax.Array]]: + return self.attn2.compute_kv(encoder_hidden_states, encoder_attention_mask) + class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -584,6 +601,58 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def compute_kv_cache( + self, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + timestep: Optional[jax.Array] = None, + text_mask: Optional[jax.Array] = None, + ) -> Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Optional[jax.Array]]: + if timestep is None: + batch_size = encoder_hidden_states.shape[0] + timestep = 
@@ -584,6 +601,58 @@
   def conditional_named_scope(self, name: str):
     """Return a JAX named scope if enabled, otherwise a null context."""
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

+  def compute_kv_cache(
+      self,
+      encoder_hidden_states: jax.Array,
+      encoder_hidden_states_image: Optional[jax.Array] = None,
+      timestep: Optional[jax.Array] = None,
+      text_mask: Optional[jax.Array] = None,
+  ) -> Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Optional[jax.Array]]:
+    if timestep is None:
+      batch_size = encoder_hidden_states.shape[0]
+      timestep = jnp.zeros((batch_size,), dtype=jnp.int32)
+
+    with self.conditional_named_scope("condition_embedder"):
+      (
+          temb,
+          timestep_proj,
+          encoder_hidden_states,
+          encoder_hidden_states_image,
+          encoder_attention_mask,
+      ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
+
+    if encoder_hidden_states_image is not None:
+      encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
+      if encoder_attention_mask is not None:
+        if text_mask is None:
+          text_mask = jnp.ones(
+              (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
+              dtype=jnp.int32,
+          )
+        encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
+    else:
+      if encoder_attention_mask is None:
+        encoder_attention_mask = text_mask
+
+    if self.scan_layers:
+
+      @nnx.vmap(in_axes=(0, None, None), out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
+      def _compute_kv(block, enc_states, enc_mask):
+        return block.compute_kv(enc_states, enc_mask)
+
+      kv_cache = _compute_kv(self.blocks, encoder_hidden_states, encoder_attention_mask)
+    else:
+      kv_cache_list = []
+      for block in self.blocks:
+        kv_cache_list.append(block.compute_kv(encoder_hidden_states, encoder_attention_mask))
+      keys = kv_cache_list[0].keys()
+      kv_cache = {}
+      for k in keys:
+        k_list = [d[k][0] for d in kv_cache_list]
+        v_list = [d[k][1] for d in kv_cache_list]
+        kv_cache[k] = (jnp.stack(k_list, axis=0), jnp.stack(v_list, axis=0))
+
+    return kv_cache, encoder_attention_mask
+
   @jax.named_scope("WanModel")
   def __call__(
       self,
@@ -598,6 +667,10 @@ def __call__(
       skip_blocks: Optional[jax.Array] = None,
       cached_residual: Optional[jax.Array] = None,
       return_residual: bool = False,
+      kv_cache: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None,
+      rotary_emb: Optional[jax.Array] = None,
+      encoder_attention_mask: Optional[jax.Array] = None,
+      text_mask: Optional[jax.Array] = None,
  ) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]:
    hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
    batch_size, _, num_frames, height, width = hidden_states.shape
@@ -608,7 +681,8 @@ def __call__(
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))

     with self.conditional_named_scope("rotary_embedding"):
-      rotary_emb = self.rope(hidden_states)
+      if rotary_emb is None:
+        rotary_emb = self.rope(hidden_states)
     with self.conditional_named_scope("patch_embedding"):
       hidden_states = self.patch_embedding(hidden_states)
       hidden_states = jax.lax.collapse(hidden_states, 1, -1)
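In the non-scan path, the per-block dicts are restacked so every cache entry carries a leading layer axis. A minimal sketch of the resulting pytree with hypothetical sizes, including the per-layer slicing that the block loop below performs:

```python
import jax
import jax.numpy as jnp

num_layers, batch, seq, dim = 4, 2, 512, 128
# Layout after stacking: [num_layers, batch, seq_len, inner_dim]
kv_cache = {
    "text": (
        jnp.zeros((num_layers, batch, seq, dim)),  # keys, one slab per block
        jnp.zeros((num_layers, batch, seq, dim)),  # values
    ),
}

# Slice out layer 0 for the non-scan block loop; this mirrors the
# jax.tree_map(lambda x: x[i], kv_cache) call in WanModel.__call__ below.
layer0_kv = jax.tree_map(lambda x: x[0], kv_cache)
assert layer0_kv["text"][0].shape == (batch, seq, dim)
```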
@@ -616,27 +690,43 @@
     (
         temb,
         timestep_proj,
-        encoder_hidden_states,
-        encoder_hidden_states_image,
-        encoder_attention_mask,
-    ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
+        encoder_hidden_states_out,
+        encoder_hidden_states_image_out,
+        encoder_attention_mask_out,
+    ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image, skip_embeddings=(kv_cache is not None))
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

+    # Keep the caller-provided mask when running from a KV cache; otherwise
+    # fall back to the freshly embedded mask, then to the raw text mask.
+    if kv_cache is None or encoder_attention_mask is None:
+      encoder_attention_mask = encoder_attention_mask_out
+    if encoder_attention_mask is None:
+      encoder_attention_mask = text_mask
+
     if encoder_hidden_states_image is not None:
-      encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1)
-      if encoder_attention_mask is not None:
-        text_mask = jnp.ones(
-            (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]),
-            dtype=jnp.int32,
-        )
+      encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image_out, encoder_hidden_states_out], axis=1)
+      if kv_cache is None and encoder_attention_mask is not None:
+        if text_mask is None:
+          text_mask = jnp.ones(
+              (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image_out.shape[1]),
+              dtype=jnp.int32,
+          )
         encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
       encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)
+    else:
+      encoder_hidden_states = encoder_hidden_states_out.astype(hidden_states.dtype)

     def _run_all_blocks(h):
       if self.scan_layers:

-        def scan_fn(carry, block):
+        def scan_fn(carry, block_input):
           hidden_states_carry, rngs_carry = carry
+          if kv_cache is not None:
+            block, layer_kv_cache = block_input
+          else:
+            block = block_input
+            layer_kv_cache = None
+
           hidden_states = block(
               hidden_states_carry,
               encoder_hidden_states,
@@ -645,6 +735,7 @@
               deterministic,
               rngs_carry,
               encoder_attention_mask,
+              cached_kv=layer_kv_cache,
           )
           new_carry = (hidden_states, rngs_carry)
           return new_carry, None
@@ -653,19 +744,28 @@
             scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
         )
         initial_carry = (h, rngs)
+
+        if kv_cache is not None:
+          scan_input = (self.blocks, kv_cache)
+        else:
+          scan_input = self.blocks
+
         final_carry, _ = nnx.scan(
             rematted_block_forward,
             length=self.num_layers,
             in_axes=(nnx.Carry, 0),
             out_axes=(nnx.Carry, 0),
-        )(initial_carry, self.blocks)
+        )(initial_carry, scan_input)
         h_out, _ = final_carry
       else:
         h_out = h
-        for block in self.blocks:
+        for i, block in enumerate(self.blocks):
+          layer_kv_cache = None
+          if kv_cache is not None:
+            layer_kv_cache = jax.tree_map(lambda x: x[i], kv_cache)

-          def layer_forward(hidden_states):
+          def layer_forward(hidden_states, l_kv):
             return block(
                 hidden_states,
                 encoder_hidden_states,
@@ -674,6 +774,7 @@
                 deterministic,
                 rngs,
                 encoder_attention_mask=encoder_attention_mask,
+                cached_kv=l_kv,
             )

           rematted_layer_forward = self.gradient_checkpoint.apply(
@@ -682,7 +783,7 @@
               self.names_which_can_be_offloaded,
               prevent_cse=not self.scan_layers,
           )
-          h_out = rematted_layer_forward(h_out)
+          h_out = rematted_layer_forward(h_out, layer_kv_cache)

       return h_out

     hidden_states_before_blocks = hidden_states
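When `scan_layers` is on, the stacked cache rides along the scan's mapped axis, so each scan step sees its own layer's slice. A toy `jax.lax.scan` analog of that pairing (not the `nnx.scan` call used above), with made-up shapes:

```python
import jax
import jax.numpy as jnp

num_layers, batch, seq, dim = 4, 2, 16, 8
stacked_k = jnp.zeros((num_layers, batch, seq, dim))  # leading layer axis
stacked_v = jnp.zeros((num_layers, batch, seq, dim))

def scan_fn(h, layer_kv):
  k, v = layer_kv  # this step's layer slice, shape [batch, seq, dim]
  # A real block would run attention with (k, v); here we only mix values
  # to show the carry/slice plumbing.
  h = h + k.sum() + v.sum()
  return h, None

h0 = jnp.zeros((batch, seq, dim))
# lax.scan maps over the leading axis of (stacked_k, stacked_v), exactly how
# nnx.scan's in_axes=0 pairs each block with its cache slice above.
h_final, _ = jax.lax.scan(scan_fn, h0, (stacked_k, stacked_v))
```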
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index f94f7cdf8..a8026d2eb 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -473,7 +473,11 @@ def _get_t5_prompt_embeds(
     prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
     prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

-    return prompt_embeds
+    mask = mask.repeat(1, num_videos_per_prompt)
+    mask = mask.view(batch_size * num_videos_per_prompt, seq_len)
+    mask = jnp.array(mask.detach().numpy(), dtype=jnp.int32)
+
+    return prompt_embeds, mask

   def encode_prompt(
       self,
@@ -483,28 +487,36 @@ def encode_prompt(
       max_sequence_length: int = 226,
       prompt_embeds: jax.Array = None,
       negative_prompt_embeds: jax.Array = None,
+      prompt_mask: jax.Array = None,
+      negative_prompt_mask: jax.Array = None,
   ):
     prompt = [prompt] if isinstance(prompt, str) else prompt

     if prompt_embeds is None:
-      prompt_embeds = self._get_t5_prompt_embeds(
+      prompt_embeds, prompt_mask = self._get_t5_prompt_embeds(
           prompt=prompt,
           num_videos_per_prompt=num_videos_per_prompt,
           max_sequence_length=max_sequence_length,
       )
       prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
+    else:
+      if prompt_mask is None:
+        prompt_mask = jnp.ones((prompt_embeds.shape[0], prompt_embeds.shape[1]), dtype=jnp.int32)

     if negative_prompt_embeds is None:
       batch_size = len(prompt_embeds)
       negative_prompt = negative_prompt or ""
       negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

-      negative_prompt_embeds = self._get_t5_prompt_embeds(
+      negative_prompt_embeds, negative_prompt_mask = self._get_t5_prompt_embeds(
           prompt=negative_prompt,
           num_videos_per_prompt=num_videos_per_prompt,
           max_sequence_length=max_sequence_length,
       )
       negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
+    else:
+      if negative_prompt_mask is None:
+        negative_prompt_mask = jnp.ones((negative_prompt_embeds.shape[0], negative_prompt_embeds.shape[1]), dtype=jnp.int32)

-    return prompt_embeds, negative_prompt_embeds
+    return prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask

   def prepare_latents(
       self,
@@ -647,7 +659,7 @@ def _prepare_model_inputs_i2v(
     effective_batch_size = batch_size * num_videos_per_prompt

     # 1. Encode Prompts
-    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+    prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
         prompt=prompt,
         negative_prompt=negative_prompt,
         num_videos_per_prompt=num_videos_per_prompt,
@@ -691,8 +703,10 @@ def _prepare_model_inputs_i2v(
     prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
     negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
     image_embeds = jax.device_put(image_embeds, data_sharding)
+    prompt_mask = jax.device_put(prompt_mask, data_sharding)
+    negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)

-    return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
+    return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size, prompt_mask, negative_prompt_mask

   def _prepare_model_inputs(
       self,
@@ -724,7 +738,7 @@ def _prepare_model_inputs(
     batch_size = len(prompt)

     with jax.named_scope("Encode-Prompt"):
-      prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+      prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
           prompt=prompt,
           negative_prompt=negative_prompt,
           max_sequence_length=max_sequence_length,
@@ -752,12 +766,14 @@ def _prepare_model_inputs(
     latents = jax.device_put(latents, data_sharding)
     prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
     negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
+    prompt_mask = jax.device_put(prompt_mask, data_sharding)
+    negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)

     scheduler_state = self.scheduler.set_timesteps(
         self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
     )

-    return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
+    return latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames

   @abstractmethod
   def __call__(self, **kwargs):
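`_get_t5_prompt_embeds` now returns the tokenizer's attention mask alongside the embeddings. A hedged sketch of where such a mask typically comes from; the checkpoint name and variable names here are assumptions for illustration, not this file's exact code:

```python
from transformers import AutoTokenizer
import jax.numpy as jnp

# Assumed text-encoder checkpoint; WAN uses a UMT5-XXL encoder.
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
batch = tokenizer(
    ["a corgi surfing at sunset"],
    padding="max_length",
    max_length=226,
    truncation=True,
    return_tensors="np",
)
# 1 for real tokens, 0 for padding; this is the shape of data that flows
# through the pipelines as prompt_mask after the jnp conversion.
prompt_mask = jnp.asarray(batch["attention_mask"], dtype=jnp.int32)
```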
@@ -779,6 +795,10 @@ def transformer_forward_pass(
     skip_blocks=None,
     cached_residual=None,
     return_residual=False,
+    kv_cache=None,
+    rotary_emb=None,
+    encoder_attention_mask=None,
+    text_mask=None,
 ):
   wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
   outputs = wan_transformer(
@@ -789,6 +809,10 @@
       skip_blocks=skip_blocks,
       cached_residual=cached_residual,
       return_residual=return_residual,
+      kv_cache=kv_cache,
+      rotary_emb=rotary_emb,
+      encoder_attention_mask=encoder_attention_mask,
+      text_mask=text_mask,
   )

   if return_residual:
@@ -819,6 +843,10 @@ def transformer_forward_pass_full_cfg(
     prompt_embeds_combined: jnp.array,
     guidance_scale: float,
     encoder_hidden_states_image=None,
+    kv_cache=None,
+    rotary_emb=None,
+    encoder_attention_mask=None,
+    text_mask=None,
 ):
   """Full CFG forward pass.
@@ -837,6 +865,10 @@
       skip_blocks=False,
       cached_residual=None,
       return_residual=False,
+      kv_cache=kv_cache,
+      rotary_emb=rotary_emb,
+      encoder_attention_mask=encoder_attention_mask,
+      text_mask=text_mask,
   )
   noise_cond = noise_pred[:bsz]
   noise_uncond = noise_pred[bsz:]
@@ -858,6 +890,10 @@ def transformer_forward_pass_cfg_cache(
     w1: float = 1.0,
     w2: float = 1.0,
     encoder_hidden_states_image=None,
+    kv_cache=None,
+    rotary_emb=None,
+    encoder_attention_mask=None,
+    text_mask=None,
 ):
   """CFG-Cache forward pass with FFT frequency-domain compensation.
@@ -883,6 +919,10 @@
       timestep=timestep_cond,
       encoder_hidden_states=prompt_cond_embeds,
       encoder_hidden_states_image=encoder_hidden_states_image,
+      kv_cache=kv_cache,
+      rotary_emb=rotary_emb,
+      encoder_attention_mask=encoder_attention_mask,
+      text_mask=text_mask,
   )

   # FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W]
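On CFG-cache steps only the conditional half of the doubled batch is run, so callers slice the cache's batch axis (axis 1, after the layer axis) before handing it to `transformer_forward_pass_cfg_cache`. A sketch of that slicing with hypothetical shapes:

```python
import jax
import jax.numpy as jnp

num_layers, bsz, seq, dim = 4, 1, 512, 128
# Cache built from the CFG-doubled batch: [cond | uncond] along axis 1.
kv_cache = {"text": (jnp.zeros((num_layers, 2 * bsz, seq, dim)),
                     jnp.zeros((num_layers, 2 * bsz, seq, dim)))}

# Keep only the conditional half for a cond-only cache step; this mirrors
# the jax.tree_map(lambda x: x[:, :bsz], kv_cache) calls in the pipelines.
kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache)
assert kv_cache_cond["text"][0].shape == (num_layers, bsz, seq, dim)
```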
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
index 589ab6076..fe0e12b2d 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
@@ -95,6 +95,7 @@ def __call__(
     magcache_thresh: Optional[float] = None,
     magcache_K: Optional[int] = None,
     retention_ratio: Optional[float] = None,
+    use_kv_cache: bool = False,
 ):
   config = getattr(self, "config", None)
   if magcache_thresh is None:
@@ -110,7 +111,7 @@ def __call__(
         "CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0."
     )

-  latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
+  latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames = self._prepare_model_inputs(
       prompt,
       negative_prompt,
       height,
@@ -140,6 +141,7 @@ def __call__(
       retention_ratio=retention_ratio,
       height=height,
       mag_ratios_base=getattr(config, "mag_ratios_base", None),
+      use_kv_cache=use_kv_cache,
   )

   with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -150,6 +152,8 @@ def __call__(
         latents=latents,
         prompt_embeds=prompt_embeds,
         negative_prompt_embeds=negative_prompt_embeds,
+        prompt_mask=prompt_mask,
+        negative_prompt_mask=negative_prompt_mask,
     )
   latents = self._denormalize_latents(latents)
   return self._decode_latents_to_video(latents)
@@ -162,6 +166,8 @@ def run_inference_2_1(
     latents: jnp.array,
     prompt_embeds: jnp.array,
     negative_prompt_embeds: jnp.array,
+    prompt_mask: jnp.array,
+    negative_prompt_mask: jnp.array,
     guidance_scale: float,
     num_inference_steps: int,
     scheduler: FlaxUniPCMultistepScheduler,
@@ -173,6 +179,7 @@ def run_inference_2_1(
     retention_ratio: float = 0.2,
     height: int = 480,
     mag_ratios_base: Optional[List[float]] = None,
+    use_kv_cache: bool = False,
 ):
   """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
@@ -213,8 +220,12 @@
   # Pre-split embeds once, outside the loop.
   prompt_cond_embeds = prompt_embeds
   prompt_embeds_combined = None
+  prompt_mask_combined = None
   if do_cfg:
     prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
+    prompt_mask_combined = jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0)
+  else:
+    prompt_mask_combined = prompt_mask

   # Pre-compute cache schedule and phase-dependent weights.
   # t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
@@ -244,6 +255,20 @@
   cached_noise_cond = None
   cached_noise_uncond = None

+  transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state)
+
+  # Compute RoPE once, since it depends only on the latent shape.
+  dummy_hidden_states = jnp.zeros((latents.shape[0], latents.shape[2], latents.shape[3], latents.shape[4], latents.shape[1]))
+  rotary_emb = transformer_obj.rope(dummy_hidden_states)
+
+  kv_cache = None
+  encoder_attention_mask = None
+
+  if use_kv_cache:
+    kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds_combined if do_cfg else prompt_cond_embeds, text_mask=prompt_mask_combined)
+  else:
+    encoder_attention_mask = prompt_mask_combined
+
   if use_magcache and do_cfg:
     magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
     accumulated_state = magcache_init[:6]
@@ -273,6 +298,9 @@
         skip_blocks=bool(skip_blocks),
         cached_residual=cached_residual,
         return_residual=True,
+        kv_cache=kv_cache,
+        rotary_emb=rotary_emb,
+        encoder_attention_mask=encoder_attention_mask,
     )

     if not skip_blocks:
@@ -284,6 +312,8 @@
     if is_cache_step:
       w1, w2 = step_w1w2[step]
       timestep = jnp.broadcast_to(t, bsz)
+      kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None
+      encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None
       noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
           graphdef,
           sharded_state,
@@ -296,6 +326,9 @@
           guidance_scale=guidance_scale,
           w1=jnp.float32(w1),
           w2=jnp.float32(w2),
+          kv_cache=kv_cache_cond,
+          rotary_emb=rotary_emb,
+          encoder_attention_mask=encoder_attention_mask_cond,
       )

     elif do_cfg:
@@ -309,6 +342,9 @@
           timestep,
           prompt_embeds_combined,
           guidance_scale=guidance_scale,
+          kv_cache=kv_cache,
+          rotary_emb=rotary_emb,
+          encoder_attention_mask=encoder_attention_mask,
       )

     else:
@@ -322,6 +358,9 @@
           prompt_cond_embeds,
           do_classifier_free_guidance=False,
           guidance_scale=guidance_scale,
+          kv_cache=kv_cache,
+          rotary_emb=rotary_emb,
+          encoder_attention_mask=encoder_attention_mask,
       )

     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
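Precomputing RoPE from a dummy tensor is sound because `rope` only reads its input's shape. A small sketch of the shape permutation used to build the dummy (latents are BCFHW; the transformer works channels-last), with illustrative 480p-ish dimensions:

```python
import jax.numpy as jnp

# latents: [B, C, F, H, W]; rope() expects [B, F, H, W, C].
latents = jnp.zeros((1, 16, 21, 60, 104))  # assumed example shape
dummy = jnp.zeros((latents.shape[0], latents.shape[2], latents.shape[3],
                   latents.shape[4], latents.shape[1]))
assert dummy.shape == (1, 21, 60, 104, 16)
# rope(dummy) yields identical frequencies at every denoising step, so
# hoisting it out of the loop trades one cheap zeros() for ~T rope calls.
```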
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
index f6a3d9370..cd70dab2b 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
@@ -112,6 +112,7 @@ def __call__(
     vae_only: bool = False,
     use_cfg_cache: bool = False,
     use_sen_cache: bool = False,
+    use_kv_cache: bool = False,
 ):
   if use_cfg_cache and use_sen_cache:
     raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
@@ -130,7 +131,7 @@ def __call__(
         "SenCache requires classifier-free guidance to be enabled for both transformer phases."
     )

-  latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
+  latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames = self._prepare_model_inputs(
       prompt,
       negative_prompt,
       height,
@@ -161,6 +162,7 @@ def __call__(
       use_cfg_cache=use_cfg_cache,
       use_sen_cache=use_sen_cache,
       height=height,
+      use_kv_cache=use_kv_cache,
   )

   with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -174,6 +176,8 @@ def __call__(
         latents=latents,
         prompt_embeds=prompt_embeds,
         negative_prompt_embeds=negative_prompt_embeds,
+        prompt_mask=prompt_mask,
+        negative_prompt_mask=negative_prompt_mask,
     )
   latents = self._denormalize_latents(latents)
   return self._decode_latents_to_video(latents)
@@ -189,6 +193,8 @@ def run_inference_2_2(
     latents: jnp.array,
     prompt_embeds: jnp.array,
     negative_prompt_embeds: jnp.array,
+    prompt_mask: jnp.array,
+    negative_prompt_mask: jnp.array,
     guidance_scale_low: float,
     guidance_scale_high: float,
     boundary: int,
@@ -198,6 +204,7 @@ def run_inference_2_2(
     use_cfg_cache: bool = False,
     use_sen_cache: bool = False,
     height: int = 480,
+    use_kv_cache: bool = False,
 ):
   """Denoising loop for WAN 2.2 T2V with optional caching acceleration.
@@ -217,6 +224,33 @@
   do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
   bsz = latents.shape[0]

+  prompt_embeds_combined = (
+      jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
+  )
+  prompt_mask_combined = (
+      jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0) if do_classifier_free_guidance else prompt_mask
+  )
+
+  low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest)
+
+  # Compute RoPE once, since it depends only on the latent shape.
+  dummy_hidden_states = jnp.zeros((latents.shape[0], latents.shape[2], latents.shape[3], latents.shape[4], latents.shape[1]))
+  rotary_emb = low_transformer.rope(dummy_hidden_states)
+
+  kv_cache_low = None
+  encoder_attention_mask_low = None
+  kv_cache_high = None
+  encoder_attention_mask_high = None
+
+  if use_kv_cache:
+    kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(prompt_embeds_combined, text_mask=prompt_mask_combined)
+
+    high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest)
+    kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined, text_mask=prompt_mask_combined)
+  else:
+    encoder_attention_mask_low = prompt_mask_combined
+    encoder_attention_mask_high = prompt_mask_combined
+
   # ── SenCache path (arXiv:2602.24208) ──
   if use_sen_cache and do_classifier_free_guidance:
     timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
@@ -240,8 +274,6 @@
     # uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches.
     num_train_timesteps = float(scheduler.config.num_train_timesteps)

-    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
-
     # SenCache state
     ref_noise_pred = None  # y^r: cached denoiser output
     ref_latent = None  # x^r: latent at last cache refresh
@@ -259,9 +291,13 @@
       if step_uses_high[step]:
         graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
         guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
       else:
         graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
         guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low

       # Force full compute: warmup, first 30%, last 10%, or transformer boundary
       is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
@@ -280,6 +316,9 @@
             timestep,
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )
         ref_noise_pred = noise_pred
         ref_latent = latents
@@ -316,6 +355,9 @@
             timestep,
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )
         ref_noise_pred = noise_pred
         ref_latent = latents
@@ -352,7 +394,6 @@
     # Pre-split embeds once
     prompt_cond_embeds = prompt_embeds
-    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

     # Determine the first low-noise step (boundary transition).
     # In Wan 2.2 the boundary IS the structural→detail transition, so
@@ -400,14 +441,20 @@
       if step_uses_high[step]:
         graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
         guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
       else:
         graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
         guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low

       if is_cache_step:
         # ── Cache step: cond-only forward + FFT frequency compensation ──
         w1, w2 = step_w1w2[step]
         timestep = jnp.broadcast_to(t, bsz)
+        kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None
+        encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None
         noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
             graphdef,
             state,
@@ -420,6 +467,9 @@
             guidance_scale=guidance_scale,
             w1=jnp.float32(w1),
             w2=jnp.float32(w2),
+            kv_cache=kv_cache_cond,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask_cond,
         )
       else:
         # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
@@ -433,6 +483,9 @@
             timestep,
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )

       latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
@@ -445,9 +498,7 @@
     timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
     step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]

-    prompt_embeds_combined = (
-        jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
-    )
+
     for step in range(num_inference_steps):
       t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
@@ -455,9 +506,13 @@
       if step_uses_high[step]:
         graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
         guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
       else:
         graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
         guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low

       if do_classifier_free_guidance:
         latents_doubled = jnp.concatenate([latents] * 2)
@@ -470,6 +525,9 @@
             timestep,
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )
       else:
         timestep = jnp.broadcast_to(t, bsz)
@@ -482,6 +540,9 @@
             prompt_embeds,
             do_classifier_free_guidance,
             guidance_scale,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )

       latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py
index b98aa2961..df475a350 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py
@@ -153,6 +153,7 @@ def __call__(
     magcache_thresh: Optional[float] = None,
     magcache_K: Optional[int] = None,
     retention_ratio: Optional[float] = None,
+    use_kv_cache: bool = False,
 ):
   config = getattr(self, "config", None)
   if magcache_thresh is None:
@@ -176,7 +177,7 @@ def __call__(
     max_logging.log(f"Adjusted num_frames to: {num_frames}")
   num_frames = max(num_frames, 1)

-  prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
+  prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size, prompt_mask, negative_prompt_mask = self._prepare_model_inputs_i2v(
       prompt,
       image,
       negative_prompt,
@@ -233,6 +234,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
   prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
   negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
   image_embeds = jax.device_put(image_embeds, data_sharding)
+  prompt_mask = jax.device_put(prompt_mask, data_sharding)
+  negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)
   if first_frame_mask is not None:
     first_frame_mask = jax.device_put(first_frame_mask, data_sharding)
@@ -250,6 +253,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
       retention_ratio=retention_ratio,
       height=height,
       mag_ratios_base=self.config.mag_ratios_base_720p if height >= 720 else self.config.mag_ratios_base_480p,
+      use_kv_cache=use_kv_cache,
   )

   with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -259,6 +263,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
         prompt_embeds=prompt_embeds,
         negative_prompt_embeds=negative_prompt_embeds,
         image_embeds=image_embeds,
+        prompt_mask=prompt_mask,
+        negative_prompt_mask=negative_prompt_mask,
         scheduler_state=scheduler_state,
     )
   latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
@@ -278,6 +284,8 @@ def run_inference_2_1_i2v(
     prompt_embeds: jnp.array,
     negative_prompt_embeds: jnp.array,
     image_embeds: jnp.array,
+    prompt_mask: jnp.array,
+    negative_prompt_mask: jnp.array,
     guidance_scale: float,
     num_inference_steps: int,
     scheduler: FlaxUniPCMultistepScheduler,
@@ -288,6 +296,7 @@ def run_inference_2_1_i2v(
     retention_ratio: float = 0.2,
     height: int = 480,
     mag_ratios_base: Optional[List[float]] = None,
+    use_kv_cache: bool = False,
 ):
   do_cfg = guidance_scale > 1.0
@@ -302,10 +311,24 @@
     prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
     image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
     condition_combined = jnp.concatenate([condition] * 2)
+    prompt_mask_combined = jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0)
   else:
     prompt_embeds_combined = prompt_embeds
     image_embeds_combined = image_embeds
     condition_combined = condition
+    prompt_mask_combined = prompt_mask
+
+  transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state)
+
+  # Compute RoPE once, since it depends only on the latent shape.
+  dummy_hidden_states = jnp.zeros((latents.shape[0], latents.shape[2], latents.shape[3], latents.shape[4], latents.shape[1]))
+  rotary_emb = transformer_obj.rope(dummy_hidden_states)
+
+  kv_cache = None
+  encoder_attention_mask = None
+
+  if use_kv_cache:
+    kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds_combined, image_embeds_combined, text_mask=prompt_mask_combined)

   for step in range(num_inference_steps):
     t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
@@ -337,6 +360,10 @@
         skip_blocks=bool(skip_blocks) if use_magcache and do_cfg else None,
         cached_residual=cached_residual if use_magcache and do_cfg else None,
         return_residual=True if use_magcache and do_cfg else False,
+        kv_cache=kv_cache,
+        rotary_emb=rotary_emb,
+        encoder_attention_mask=encoder_attention_mask,
+        text_mask=prompt_mask_combined,
     )
     if use_magcache and do_cfg:
       noise_pred, _, residual_x_cur = outputs
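For I2V, `compute_kv_cache` also receives the image embeddings, so the returned cache holds separate `"text"` and `"image"` entries: the image K/V come from the `add_k_proj`/`add_v_proj` heads, with the image tokens padded from 257 to 384 for TPU flash attention. A shape sketch under illustrative (not the real config's) dimensions:

```python
import jax.numpy as jnp

num_layers, batch, dim = 30, 2, 1536  # illustrative sizes only
kv_cache = {
    # K/V over the text tokens, from the regular key/value projections.
    "text": (jnp.zeros((num_layers, batch, 512, dim)),
             jnp.zeros((num_layers, batch, 512, dim))),
    # K/V over 257 image tokens padded to 384 (multiple of 128 for the
    # TPU flash-attention kernel), from add_k_proj / add_v_proj.
    "image": (jnp.zeros((num_layers, batch, 384, dim)),
              jnp.zeros((num_layers, batch, 384, dim))),
}
```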
Enable only one.") @@ -199,7 +200,7 @@ def __call__( max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) - prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( + prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size, prompt_mask, negative_prompt_mask = self._prepare_model_inputs_i2v( prompt, image, negative_prompt, @@ -258,6 +259,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): condition = jax.device_put(condition, data_sharding) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + prompt_mask = jax.device_put(prompt_mask, data_sharding) + negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding) # WAN 2.2 I2V doesn't use image_embeds (it's None), but we still need to pass it to the function if image_embeds is not None: image_embeds = jax.device_put(image_embeds, data_sharding) @@ -277,6 +280,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): use_cfg_cache=use_cfg_cache, use_sen_cache=use_sen_cache, height=height, + use_kv_cache=use_kv_cache, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -291,6 +295,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): condition=condition, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + prompt_mask=prompt_mask, + negative_prompt_mask=negative_prompt_mask, scheduler_state=scheduler_state, ) latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) @@ -312,6 +318,8 @@ def run_inference_2_2_i2v( condition: jnp.array, prompt_embeds: jnp.array, negative_prompt_embeds: jnp.array, + prompt_mask: jnp.array, + negative_prompt_mask: jnp.array, image_embeds: jnp.array, guidance_scale_low: float, guidance_scale_high: float, @@ -322,10 +330,40 @@ def run_inference_2_2_i2v( use_cfg_cache: bool = False, use_sen_cache: bool = False, height: int = 480, + use_kv_cache: bool = False, ): do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + prompt_embeds_combined = ( + jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds + ) + prompt_mask_combined = ( + jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0) if do_classifier_free_guidance else prompt_mask + ) + + if image_embeds is not None: + image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) if do_classifier_free_guidance else image_embeds + else: + image_embeds_combined = None + + low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros((latents.shape[0], latents.shape[2], latents.shape[3], latents.shape[4], latents.shape[1])) + rotary_emb = low_transformer.rope(dummy_hidden_states) + + kv_cache_low = None + encoder_attention_mask_low = None + kv_cache_high = None + encoder_attention_mask_high = None + + if use_kv_cache: + kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(prompt_embeds_combined, image_embeds_combined, text_mask=prompt_mask_combined) + + high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) + kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined, image_embeds_combined, 
+
   # ── SenCache path (arXiv:2602.24208) ──
   if use_sen_cache and do_classifier_free_guidance:
     timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
@@ -343,7 +381,6 @@
     nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
     num_train_timesteps = float(scheduler.config.num_train_timesteps)

-    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
     if image_embeds is not None:
       image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
     else:
@@ -366,9 +403,13 @@
       if step_uses_high[step]:
         graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
         guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
       else:
         graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
         guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low

       is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
       force_compute = (
@@ -389,6 +430,10 @@
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
             encoder_hidden_states_image=image_embeds_combined,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
+            text_mask=prompt_mask_combined,
         )
         noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
         ref_noise_pred = noise_pred
@@ -425,6 +470,10 @@
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
             encoder_hidden_states_image=image_embeds_combined,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
+            text_mask=prompt_mask_combined,
         )
         noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
         ref_noise_pred = noise_pred
@@ -461,14 +510,6 @@
     # Pre-split embeds
     prompt_cond_embeds = prompt_embeds
-    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
-
-    if image_embeds is not None:
-      image_embeds_cond = image_embeds
-      image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
-    else:
-      image_embeds_cond = None
-      image_embeds_combined = None

     # Keep condition in both single and doubled forms
     condition_cond = condition
@@ -514,9 +555,13 @@
       if step_uses_high[step]:
         graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
         guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
       else:
         graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
         guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low

       if is_cache_step:
         # ── Cache step: cond-only forward + FFT frequency compensation ──
@@ -525,6 +570,8 @@
         latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1)
         latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
         timestep = jnp.broadcast_to(t, bsz)
+        kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None
+        encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None
         noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
             graphdef,
             state,
@@ -537,7 +584,11 @@
             guidance_scale=guidance_scale,
             w1=jnp.float32(w1),
             w2=jnp.float32(w2),
-            encoder_hidden_states_image=image_embeds_cond,
+            encoder_hidden_states_image=image_embeds,
+            kv_cache=kv_cache_cond,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask_cond,
+            text_mask=prompt_mask_combined[:bsz] if prompt_mask_combined is not None else None,
         )
       else:
         # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
@@ -554,6 +605,10 @@
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
             encoder_hidden_states_image=image_embeds_combined,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
+            text_mask=prompt_mask_combined,
         )
         noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))  # BCFHW -> BFHWC
@@ -562,7 +617,7 @@

   # ── Original non-cache path ──
   def high_noise_branch(operands):
-    latents_input, ts_input, pe_input, ie_input = operands
+    latents_input, ts_input, pe_input, ie_input, kv_cache_high, _, r_emb, mask_high, _, text_mask = operands
     latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
     noise_pred, latents_out = transformer_forward_pass(
         high_noise_graphdef,
@@ -574,11 +629,15 @@
         do_classifier_free_guidance=do_classifier_free_guidance,
         guidance_scale=guidance_scale_high,
         encoder_hidden_states_image=ie_input,
+        kv_cache=kv_cache_high,
+        rotary_emb=r_emb,
+        encoder_attention_mask=mask_high,
+        text_mask=text_mask,
     )
     return noise_pred, latents_out

   def low_noise_branch(operands):
-    latents_input, ts_input, pe_input, ie_input = operands
+    latents_input, ts_input, pe_input, ie_input, _, kv_cache_low, r_emb, _, mask_low, text_mask = operands
     latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
     noise_pred, latents_out = transformer_forward_pass(
         low_noise_graphdef,
@@ -590,6 +649,10 @@
         do_classifier_free_guidance=do_classifier_free_guidance,
         guidance_scale=guidance_scale_low,
         encoder_hidden_states_image=ie_input,
+        kv_cache=kv_cache_low,
+        rotary_emb=r_emb,
+        encoder_attention_mask=mask_low,
+        text_mask=text_mask,
     )
     return noise_pred, latents_out
@@ -610,7 +673,10 @@
     use_high_noise = jnp.greater_equal(t, boundary)

     noise_pred, _ = jax.lax.cond(
-        use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds)
+        use_high_noise,
+        high_noise_branch,
+        low_noise_branch,
+        (latent_model_input, timestep, prompt_embeds, image_embeds, kv_cache_high, kv_cache_low, rotary_emb, encoder_attention_mask_high, encoder_attention_mask_low, prompt_mask_combined),
     )
     noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
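Both phase caches ride in the `jax.lax.cond` operand tuple because `cond` requires the same operand structure for both branches; each branch then ignores the other phase's cache via `_` in its unpack. A toy sketch of that constraint:

```python
import jax
import jax.numpy as jnp

cache_a = jnp.ones((2, 4))   # stand-in for the high-noise cache
cache_b = jnp.zeros((2, 4))  # stand-in for the low-noise cache

def use_a(operands):
  x, a, _ = operands  # ignore the other branch's cache
  return x + a

def use_b(operands):
  x, _, b = operands
  return x + b

x = jnp.ones((2, 4))
# Both caches are passed to both branches; lax.cond traces each branch with
# the full tuple, so the pytree structures, shapes, and dtypes must match.
out = jax.lax.cond(True, use_a, use_b, (x, cache_a, cache_b))
```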