Skip to content

Commit f653ded

Browse files
[LoRA] Make sure LoRA can be disabled after it's run (huggingface#2128)
1 parent e92d43f commit f653ded

2 files changed

Lines changed: 71 additions & 25 deletions

File tree

src/diffusers/models/cross_attention.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
import torch.nn.functional as F
1818
from torch import nn
1919

20+
from ..utils import logging
2021
from ..utils.import_utils import is_xformers_available
2122

2223

24+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25+
26+
2327
if is_xformers_available():
2428
import xformers
2529
import xformers.ops
@@ -151,6 +155,16 @@ def set_attention_slice(self, slice_size):
151155
self.set_processor(processor)
152156

153157
def set_processor(self, processor: "AttnProcessor"):
158+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
159+
# pop `processor` from `self._modules`
160+
if (
161+
hasattr(self, "processor")
162+
and isinstance(self.processor, torch.nn.Module)
163+
and not isinstance(processor, torch.nn.Module)
164+
):
165+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
166+
self._modules.pop("processor")
167+
154168
self.processor = processor
155169

156170
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):

tests/models/test_models_unet_2d_condition.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121

2222
from diffusers import UNet2DConditionModel
23-
from diffusers.models.cross_attention import LoRACrossAttnProcessor
23+
from diffusers.models.cross_attention import CrossAttnProcessor, LoRACrossAttnProcessor
2424
from diffusers.utils import (
2525
floats_tensor,
2626
load_hf_numpy,
@@ -40,6 +40,34 @@
4040
torch.backends.cuda.matmul.allow_tf32 = False
4141

4242

43+
def create_lora_layers(model):
44+
lora_attn_procs = {}
45+
for name in model.attn_processors.keys():
46+
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
47+
if name.startswith("mid_block"):
48+
hidden_size = model.config.block_out_channels[-1]
49+
elif name.startswith("up_blocks"):
50+
block_id = int(name[len("up_blocks.")])
51+
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
52+
elif name.startswith("down_blocks"):
53+
block_id = int(name[len("down_blocks.")])
54+
hidden_size = model.config.block_out_channels[block_id]
55+
56+
lora_attn_procs[name] = LoRACrossAttnProcessor(
57+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
58+
)
59+
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
60+
61+
# add 1 to weights to mock trained weights
62+
with torch.no_grad():
63+
lora_attn_procs[name].to_q_lora.up.weight += 1
64+
lora_attn_procs[name].to_k_lora.up.weight += 1
65+
lora_attn_procs[name].to_v_lora.up.weight += 1
66+
lora_attn_procs[name].to_out_lora.up.weight += 1
67+
68+
return lora_attn_procs
69+
70+
4371
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
4472
model_class = UNet2DConditionModel
4573

@@ -336,30 +364,7 @@ def test_lora_save_load(self):
336364
with torch.no_grad():
337365
old_sample = model(**inputs_dict).sample
338366

339-
lora_attn_procs = {}
340-
for name in model.attn_processors.keys():
341-
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
342-
if name.startswith("mid_block"):
343-
hidden_size = model.config.block_out_channels[-1]
344-
elif name.startswith("up_blocks"):
345-
block_id = int(name[len("up_blocks.")])
346-
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
347-
elif name.startswith("down_blocks"):
348-
block_id = int(name[len("down_blocks.")])
349-
hidden_size = model.config.block_out_channels[block_id]
350-
351-
lora_attn_procs[name] = LoRACrossAttnProcessor(
352-
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
353-
)
354-
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
355-
356-
# add 1 to weights to mock trained weights
357-
with torch.no_grad():
358-
lora_attn_procs[name].to_q_lora.up.weight += 1
359-
lora_attn_procs[name].to_k_lora.up.weight += 1
360-
lora_attn_procs[name].to_v_lora.up.weight += 1
361-
lora_attn_procs[name].to_out_lora.up.weight += 1
362-
367+
lora_attn_procs = create_lora_layers(model)
363368
model.set_attn_processor(lora_attn_procs)
364369

365370
with torch.no_grad():
@@ -380,6 +385,33 @@ def test_lora_save_load(self):
380385
# LoRA and no LoRA should NOT be the same
381386
assert (sample - old_sample).abs().max() > 1e-4
382387

388+
def test_lora_on_off(self):
389+
# enable deterministic behavior for gradient checkpointing
390+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
391+
392+
init_dict["attention_head_dim"] = (8, 16)
393+
394+
torch.manual_seed(0)
395+
model = self.model_class(**init_dict)
396+
model.to(torch_device)
397+
398+
with torch.no_grad():
399+
old_sample = model(**inputs_dict).sample
400+
401+
lora_attn_procs = create_lora_layers(model)
402+
model.set_attn_processor(lora_attn_procs)
403+
404+
with torch.no_grad():
405+
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
406+
407+
model.set_attn_processor(CrossAttnProcessor())
408+
409+
with torch.no_grad():
410+
new_sample = model(**inputs_dict).sample
411+
412+
assert (sample - new_sample).abs().max() < 1e-4
413+
assert (sample - old_sample).abs().max() < 1e-4
414+
383415

384416
@slow
385417
class UNet2DConditionModelIntegrationTests(unittest.TestCase):

0 commit comments

Comments
 (0)