2020import torch
2121
2222from diffusers import UNet2DConditionModel
23- from diffusers .models .cross_attention import LoRACrossAttnProcessor
23+ from diffusers .models .cross_attention import CrossAttnProcessor , LoRACrossAttnProcessor
2424from diffusers .utils import (
2525 floats_tensor ,
2626 load_hf_numpy ,
4040torch .backends .cuda .matmul .allow_tf32 = False
4141
4242
def create_lora_layers(model):
    """Create a LoRA attention processor for every attention layer of ``model``.

    The processors are sized from ``model.config.block_out_channels`` according
    to the block each attention layer lives in (``mid_block``, ``up_blocks.<i>``
    or ``down_blocks.<i>``), moved to ``model.device``, and have 1 added to the
    ``up`` weights of their q/k/v/out LoRA layers to mock trained weights.

    Args:
        model: A model exposing ``attn_processors`` (a mapping keyed by
            processor name), ``config.block_out_channels``,
            ``config.cross_attention_dim`` and ``device``.

    Returns:
        dict: Processor name -> ``LoRACrossAttnProcessor``, suitable for
        ``model.set_attn_processor``.

    Raises:
        ValueError: If a processor name does not start with ``mid_block``,
            ``up_blocks`` or ``down_blocks``, so no hidden size can be derived.
    """
    block_out_channels = model.config.block_out_channels
    # Up blocks are indexed against the channel list in reverse order.
    reversed_out_channels = list(reversed(block_out_channels))

    lora_attn_procs = {}
    for name in model.attn_processors:
        # Names ending in "attn1.processor" get no cross-attention dimension.
        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim

        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Parse the whole index component so multi-digit block ids work.
            block_id = int(name.split(".")[1])
            hidden_size = reversed_out_channels[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            hidden_size = block_out_channels[block_id]
        else:
            # Without this guard, hidden_size would be unbound on the first
            # iteration or silently reused from the previous one.
            raise ValueError(
                f"Cannot infer hidden size for attention processor {name!r}: "
                "expected a name starting with 'mid_block', 'up_blocks' or 'down_blocks'."
            )

        processor = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
        processor = processor.to(model.device)

        # add 1 to weights to mock trained weights
        with torch.no_grad():
            processor.to_q_lora.up.weight += 1
            processor.to_k_lora.up.weight += 1
            processor.to_v_lora.up.weight += 1
            processor.to_out_lora.up.weight += 1

        lora_attn_procs[name] = processor

    return lora_attn_procs
69+
70+
4371class UNet2DConditionModelTests (ModelTesterMixin , unittest .TestCase ):
4472 model_class = UNet2DConditionModel
4573
@@ -336,30 +364,7 @@ def test_lora_save_load(self):
336364 with torch .no_grad ():
337365 old_sample = model (** inputs_dict ).sample
338366
339- lora_attn_procs = {}
340- for name in model .attn_processors .keys ():
341- cross_attention_dim = None if name .endswith ("attn1.processor" ) else model .config .cross_attention_dim
342- if name .startswith ("mid_block" ):
343- hidden_size = model .config .block_out_channels [- 1 ]
344- elif name .startswith ("up_blocks" ):
345- block_id = int (name [len ("up_blocks." )])
346- hidden_size = list (reversed (model .config .block_out_channels ))[block_id ]
347- elif name .startswith ("down_blocks" ):
348- block_id = int (name [len ("down_blocks." )])
349- hidden_size = model .config .block_out_channels [block_id ]
350-
351- lora_attn_procs [name ] = LoRACrossAttnProcessor (
352- hidden_size = hidden_size , cross_attention_dim = cross_attention_dim
353- )
354- lora_attn_procs [name ] = lora_attn_procs [name ].to (model .device )
355-
356- # add 1 to weights to mock trained weights
357- with torch .no_grad ():
358- lora_attn_procs [name ].to_q_lora .up .weight += 1
359- lora_attn_procs [name ].to_k_lora .up .weight += 1
360- lora_attn_procs [name ].to_v_lora .up .weight += 1
361- lora_attn_procs [name ].to_out_lora .up .weight += 1
362-
367+ lora_attn_procs = create_lora_layers (model )
363368 model .set_attn_processor (lora_attn_procs )
364369
365370 with torch .no_grad ():
@@ -380,6 +385,33 @@ def test_lora_save_load(self):
380385 # LoRA and no LoRA should NOT be the same
381386 assert (sample - old_sample ).abs ().max () > 1e-4
382387
def test_lora_on_off(self):
    """Check that LoRA processors run with scale 0.0 leave the model output unchanged.

    The output with LoRA processors installed and ``scale=0.0`` must match both
    the output before LoRA was installed and the output after the processors
    are replaced by a plain ``CrossAttnProcessor``.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    init_dict["attention_head_dim"] = (8, 16)

    # Seed before construction so the model is built deterministically.
    torch.manual_seed(0)
    model = self.model_class(**init_dict)
    model.to(torch_device)

    # Reference output before any LoRA processor is installed.
    with torch.no_grad():
        output_before_lora = model(**inputs_dict).sample

    # Install the mock-trained LoRA processors, then run with scale 0.0.
    model.set_attn_processor(create_lora_layers(model))
    with torch.no_grad():
        output_lora_scaled_off = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

    # Replace the LoRA processors with the plain attention processor.
    model.set_attn_processor(CrossAttnProcessor())
    with torch.no_grad():
        output_after_reset = model(**inputs_dict).sample

    diff_vs_reset = (output_lora_scaled_off - output_after_reset).abs().max()
    diff_vs_before = (output_lora_scaled_off - output_before_lora).abs().max()

    assert diff_vs_reset < 1e-4
    assert diff_vs_before < 1e-4
414+
383415
384416@slow
385417class UNet2DConditionModelIntegrationTests (unittest .TestCase ):
0 commit comments