Compare commits

..

5 Commits

Author SHA1 Message Date
Patrick von Platen
f5942649f5 Release: v.0.21.4-patch 2023-09-29 15:39:55 +00:00
Patrick von Platen
edea57749e [Lora] fix lora fuse unfuse (#5003)
* fix lora fuse unfuse

* add same changes to loaders.py

* add test

---------

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
2023-09-29 15:30:16 +00:00
Sayak Paul
c37c840b1b update 2023-09-27 22:13:19 +05:30
Sayak Paul
9858053bfe Release: v0.21.3 2023-09-27 22:10:50 +05:30
Patrick von Platen
6a3301fe34 resolve conflicts. 2023-09-27 22:09:04 +05:30
13 changed files with 91 additions and 41 deletions

View File

@@ -244,7 +244,7 @@ install_requires = [
setup(
name="diffusers",
version="0.21.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.21.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.21.2"
__version__ = "0.21.4"
from typing import TYPE_CHECKING

View File

@@ -119,7 +119,7 @@ class PatchedLoraProjection(nn.Module):
self.lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.regular_linear_layer.weight.data
@@ -2666,6 +2666,7 @@ class FromOriginalControlnetMixin:
return controlnet
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""

View File

@@ -304,19 +304,16 @@ class Attention(nn.Module):
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor"):
if (
hasattr(self, "processor")
and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
and self.to_q.lora_layer is not None
):
def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
"In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
)
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
# We need to remove all LoRA layers
# Don't forget to remove ALL `_remove_lora` from the codebase
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)

View File

@@ -196,7 +196,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -220,9 +222,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -244,7 +246,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:

View File

@@ -517,7 +517,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -541,9 +543,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -565,7 +567,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):

View File

@@ -139,7 +139,7 @@ class LoRACompatibleConv(nn.Conv2d):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
@@ -204,7 +204,7 @@ class LoRACompatibleLinear(nn.Linear):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data

View File

@@ -191,7 +191,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -215,9 +217,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -239,7 +241,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def forward(
self,

View File

@@ -613,7 +613,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -637,9 +639,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -660,7 +662,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""

View File

@@ -366,7 +366,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -390,9 +392,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -454,7 +456,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):

View File

@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):

View File

@@ -820,7 +820,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
@@ -844,9 +846,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -868,7 +870,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""

View File

@@ -43,7 +43,7 @@ from diffusers.models.attention_processor import (
LoRAAttnProcessor2_0,
XFormersAttnProcessor,
)
from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device
def create_unet_lora_layers(unet: nn.Module):
@@ -1464,3 +1464,41 @@ class LoraIntegrationTests(unittest.TestCase):
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
@nightly
def test_sequential_fuse_unfuse(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
# 1. round
pipe.load_lora_weights("Pclanglais/TintinIA")
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
image_slice = images[0, -3:, -3:, -1].flatten()
pipe.unfuse_lora()
# 2. round
pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style")
pipe.fuse_lora()
pipe.unfuse_lora()
# 3. round
pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
pipe.fuse_lora()
pipe.unfuse_lora()
# 4. back to 1st round
pipe.load_lora_weights("Pclanglais/TintinIA")
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
images_2 = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
image_slice_2 = images_2[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))