Compare commits

...

49 Commits

Author SHA1 Message Date
Sourab Mangrulkar
ece3b02f13 Merge branch 'main' into peftpart-1 2023-09-22 12:36:20 +05:30
Sourab Mangrulkar
0985d17ea9 peft integration features for text encoder
1. support multiple rank/alpha values
2. support multiple active adapters
3. support disabling and enabling adapters
2023-09-22 12:35:01 +05:30
younesbelkada
920333ffaa Merge remote-tracking branch 'upstream/main' into peftpart-1 2023-09-20 12:17:26 +00:00
younesbelkada
71650d403d try to fix merge conflicts 2023-09-20 12:16:32 +00:00
younesbelkada
5e6f343d16 revert 2023-09-20 12:14:52 +00:00
younesbelkada
724b52bc56 add deprecate 2023-09-20 12:12:54 +00:00
younesbelkada
bd46ae9db7 more docstring 2023-09-20 11:57:43 +00:00
younesbelkada
e072655765 added docstrings 2023-09-20 11:54:52 +00:00
younesbelkada
b72ef23dfc conv2d support for recurse remove 2023-09-20 11:48:51 +00:00
Younes Belkada
b412adc158 Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-09-20 13:45:09 +02:00
younesbelkada
325462dcd2 add comment 2023-09-20 11:44:40 +00:00
Younes Belkada
cb484056c8 Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-09-20 13:38:20 +02:00
Younes Belkada
e836b145e8 Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-09-20 13:36:09 +02:00
younesbelkada
d01a29273e style 2023-09-19 15:08:46 +00:00
younesbelkada
3d7c567f90 better fix 2023-09-19 15:06:00 +00:00
younesbelkada
27e3da69dc fix 2023-09-19 15:03:40 +00:00
younesbelkada
ea05959c6a fix weird error with python 3.8 2023-09-19 15:03:15 +00:00
younesbelkada
40a489457d fix CI 2023-09-19 14:52:21 +00:00
younesbelkada
c90f85d3f0 fix examples 2023-09-19 14:31:53 +00:00
younesbelkada
9cb8563b1d use base_version 2023-09-19 14:03:00 +00:00
younesbelkada
74e33a9376 oops 2023-09-19 13:54:31 +00:00
younesbelkada
b83fcbaf86 use class method instead 2023-09-19 13:53:57 +00:00
younesbelkada
dc83fa0ec7 remove unneeded methods 2023-09-19 13:43:10 +00:00
younesbelkada
3ba2d4eb05 Merge remote-tracking branch 'upstream/main' into peftpart-1 2023-09-19 13:39:39 +00:00
younesbelkada
f8e87f6220 add conversion utils 2023-09-19 13:37:59 +00:00
younesbelkada
f8909061ee move tests 2023-09-19 13:00:25 +00:00
younesbelkada
6f1adcd65d nit 2023-09-19 12:52:56 +00:00
younesbelkada
9d650c9032 Merge branch 'peftpart-1' of https://github.com/younesbelkada/diffusers into peftpart-1 2023-09-19 12:44:30 +00:00
younesbelkada
ecbc7144f1 Merge remote-tracking branch 'upstream/main' into peftpart-1 2023-09-19 12:20:48 +00:00
Younes Belkada
78a01d5151 Merge branch 'main' into peftpart-1 2023-09-19 14:12:14 +02:00
younesbelkada
78a860d276 adjustments on adjust_lora_scale_text_encoder 2023-09-19 12:10:23 +00:00
younesbelkada
1d13f40548 keep old modules for BC 2023-09-19 12:04:33 +00:00
younesbelkada
4162ddfdba replace with recurse_replace_peft_layers 2023-09-19 11:12:43 +00:00
Younes Belkada
c4295c9432 Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-09-19 13:10:28 +02:00
younesbelkada
0c62ef3daf Merge remote-tracking branch 'upstream/main' into peftpart-1 2023-09-18 15:57:48 +00:00
younesbelkada
40a60286b4 fix fuse text encoder 2023-09-18 15:19:05 +00:00
younesbelkada
ec87c196f3 style 2023-09-18 14:29:08 +00:00
younesbelkada
d56a14db7b protect torch import 2023-09-18 14:07:31 +00:00
Younes Belkada
c06c40bad6 Merge branch 'main' into peftpart-1 2023-09-18 16:04:31 +02:00
younesbelkada
14db139116 few todos 2023-09-18 14:03:26 +00:00
Younes Belkada
7918851640 Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-09-18 16:00:24 +02:00
younesbelkada
691368b060 v1 rzfactor CI 2023-09-18 13:51:40 +00:00
younesbelkada
cdbe7391a8 more changes 2023-09-15 14:48:49 +00:00
younesbelkada
961e776298 oops 2023-09-15 13:16:14 +00:00
younesbelkada
5a150b2059 add in setup 2023-09-15 13:14:12 +00:00
younesbelkada
01f6d1d88c style 2023-09-15 13:13:59 +00:00
younesbelkada
2a6e5358a0 up 2023-09-15 13:11:30 +00:00
younesbelkada
c17634c39e up 2023-09-15 13:03:17 +00:00
younesbelkada
ba24f2a5ce more fixes 2023-09-15 12:58:22 +00:00
46 changed files with 803 additions and 164 deletions
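Taken together, the hunks below add a PEFT-backed LoRA code path for the text encoder(s): multiple rank/alpha values, named adapters, and the ability to disable/enable them. A minimal, hedged usage sketch of the new pipeline-level API (checkpoint and LoRA repo ids are placeholders; `set_adapter`, `disable_lora` and `enable_lora` require the PEFT backend and, per the To Do comments in the diff, currently act on the text encoders only):

import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint and LoRA repositories, used only to illustrate the new API.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# New: LoRA weights can now be registered under an explicit adapter name.
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")
pipe.load_lora_weights("some-user/detail-lora", adapter_name="detail")

# New: activate several adapters at once with per-text-encoder weights.
pipe.set_adapter(["style", "detail"], te_weights=[1.0, 0.8])

# New: temporarily switch all LoRA layers off and back on.
pipe.disable_lora()
pipe.enable_lora()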

View File

@@ -128,6 +128,7 @@ _deps = [
"torchvision",
"transformers>=4.25.1",
"urllib3<=2.0.0",
"peft>=0.5.0"
]
# this is a lookup table with items like:

View File

@@ -41,4 +41,5 @@ deps = {
"torchvision": "torchvision",
"transformers": "transformers>=4.25.1",
"urllib3": "urllib3<=2.0.0",
"peft": "peft>=0.5.0",
}

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
from collections import defaultdict
@@ -23,6 +24,7 @@ import requests
import safetensors
import torch
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
@@ -30,11 +32,20 @@ from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
get_adapter_name,
get_rank_and_alpha_pattern,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .utils.import_utils import BACKENDS_MAPPING
@@ -61,6 +72,21 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
# Below should be `True` if the current versions of `peft` and `transformers` are compatible with the
# PEFT backend. The PEFT backend will automatically be used if the correct versions of the libraries are
# available.
# For PEFT the version has to be greater than 0.5 and for transformers it has to be greater than 4.33.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
LORA_DEPRECATION_MESSAGE = "You are using an old version of the LoRA backend. This will be deprecated in the next releases in favor of PEFT; make sure to install the latest PEFT and transformers packages in the future."
class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
super().__init__()
@@ -1077,8 +1103,11 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
@@ -1122,6 +1151,7 @@ class LoraLoaderMixin:
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self,
adapter_name=adapter_name,
)
@classmethod
@@ -1478,6 +1508,7 @@ class LoraLoaderMixin:
lora_scale=1.0,
low_cpu_mem_usage=None,
_pipeline=None,
adapter_name=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1500,6 +1531,9 @@ class LoraLoaderMixin:
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
@@ -1520,55 +1554,35 @@ class LoraLoaderMixin:
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
if cls.use_peft_backend:
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
#
# Previously, the old LoRA layers were stored on the state dict at the
# same level as the attention block i.e.
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
#
# There is no actual module at that point; the layers were monkey-patched onto the
# existing module. We want to be able to load them via their actual state dict.
# They're in `PatchedLoraProjection.lora_linear_layer` now.
for name, _ in text_encoder_attn_modules(text_encoder):
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
rank_key = f"{name}.out_proj.lora_B.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
else:
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
if network_alphas is not None:
alpha_keys = [
@@ -1578,56 +1592,90 @@ class LoraLoaderMixin:
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if cls.use_peft_backend:
from peft import LoraConfig
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
rank, network_alphas, text_encoder_lora_state_dict
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
lora_config = LoraConfig(
r=r,
target_modules=target_modules,
lora_alpha=lora_alpha,
rank_pattern=rank_pattern,
alpha_pattern=alpha_pattern,
)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
# inject LoRA layers and load the state dict
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
else:
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(
text_encoder_lora_state_dict, strict=False
)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
@@ -1645,10 +1693,20 @@ class LoraLoaderMixin:
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
if self.use_peft_backend:
remove_method = recurse_remove_peft_layers
else:
remove_method = self._remove_text_encoder_monkey_patch_classmethod
if hasattr(self, "text_encoder"):
remove_method(self.text_encoder)
if hasattr(self, "text_encoder_2"):
remove_method(self.text_encoder_2)
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_linear_layer = None
@@ -1675,6 +1733,7 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -2049,24 +2108,38 @@ class LoraLoaderMixin:
if fuse_unet:
self.unet.fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
module.merge()
else:
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder)
fuse_text_encoder_lora(self.text_encoder, lora_scale)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2)
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
@@ -2088,18 +2161,29 @@ class LoraLoaderMixin:
if unfuse_unet:
self.unet.unfuse_lora()
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
if unfuse_text_encoder:
if hasattr(self, "text_encoder"):
@@ -2109,6 +2193,65 @@ class LoraLoaderMixin:
self.num_fused_loras -= 1
def set_adapter(
self,
adapter_names: Union[List[str], str],
unet_weights: List[float] = None,
te_weights: List[float] = None,
te2_weights: List[float] = None,
):
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
return weights
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
# To Do
# Handle the UNET
# Handle the Text Encoder
te_weights = process_weights(adapter_names, te_weights)
if hasattr(self, "text_encoder"):
set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights)
te2_weights = process_weights(adapter_names, te2_weights)
if hasattr(self, "text_encoder_2"):
set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights)
def disable_lora(self):
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
# To Do
# Disable unet adapters
# Disable text encoder adapters
if hasattr(self, "text_encoder"):
set_adapter_layers(self.text_encoder, enabled=False)
if hasattr(self, "text_encoder_2"):
set_adapter_layers(self.text_encoder_2, enabled=False)
def enable_lora(self):
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
# To Do
# Enable unet adapters
# Enable text encoder adapters
if hasattr(self, "text_encoder"):
set_adapter_layers(self.text_encoder, enabled=True)
if hasattr(self, "text_encoder_2"):
set_adapter_layers(self.text_encoder_2, enabled=True)
class FromSingleFileMixin:
"""
@@ -2810,5 +2953,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
)
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
if self.use_peft_backend:
recurse_remove_peft_layers(self.text_encoder)
recurse_remove_peft_layers(self.text_encoder_2)
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
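Under the PEFT backend, the fuse/unfuse hunks above route through `BaseTunerLayer`: fusing applies `scale_layer(lora_scale)` and then `merge()`, while unfusing calls `unmerge()`. A hedged continuation of the earlier sketch, assuming the mixin-level `fuse_lora` exposes the `lora_scale` argument used inside the hunk:

# Continues the placeholder `pipe` from the earlier sketch, with a LoRA adapter already loaded.
pipe.fuse_lora(lora_scale=0.7)  # PEFT path: scale_layer(0.7) then merge() on each BaseTunerLayer
image = pipe("a pixel-art castle", num_inference_steps=20).images[0]
pipe.unfuse_lora()              # PEFT path: unmerge() restores the unfused weights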

View File

@@ -19,24 +19,27 @@ import torch.nn.functional as F
from torch import nn
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
from ..utils import logging, scale_lora_layers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
class LoRALinearLayer(nn.Module):

View File

@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -315,7 +315,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -288,7 +288,7 @@ class StableDiffusionXLControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -326,7 +326,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -297,7 +297,7 @@ class StableDiffusionInpaintPipelineLegacy(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -211,7 +211,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -272,7 +272,7 @@ class StableDiffusionLDM3DPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -244,7 +244,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -221,7 +221,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -256,7 +256,7 @@ class StableDiffusionParadigmsPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -446,7 +446,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -244,7 +244,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -240,7 +240,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -346,7 +346,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -296,7 +296,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -264,7 +264,7 @@ class StableDiffusionXLPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -271,7 +271,7 @@ class StableDiffusionXLImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -420,7 +420,7 @@ class StableDiffusionXLInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -272,7 +272,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str):

View File

@@ -296,7 +296,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -288,7 +288,7 @@ class StableDiffusionXLAdapterPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt

View File

@@ -228,7 +228,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -290,7 +290,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1

View File

@@ -67,6 +67,7 @@ from .import_utils import (
is_note_seq_available,
is_omegaconf_available,
is_onnx_available,
is_peft_available,
is_scipy_available,
is_tensorboard_available,
is_torch_available,
@@ -82,7 +83,16 @@ from .import_utils import (
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
get_adapter_name,
get_rank_and_alpha_pattern,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
logger = get_logger(__name__)

View File

@@ -267,6 +267,14 @@ except importlib_metadata.PackageNotFoundError:
_invisible_watermark_available = False
_peft_available = importlib.util.find_spec("peft") is not None
try:
_peft_version = importlib_metadata.version("peft")
logger.debug(f"Successfully imported accelerate version {_peft_version}")
except importlib_metadata.PackageNotFoundError:
_peft_available = False
def is_torch_available():
return _torch_available
@@ -351,6 +359,10 @@ def is_invisible_watermark_available():
return _invisible_watermark_available
def is_peft_available():
return _peft_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

View File

@@ -0,0 +1,138 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PEFT utilities: Utilities related to peft library
"""
from .import_utils import is_torch_available
if is_torch_available():
import torch
def recurse_remove_peft_layers(model):
r"""
Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
"""
from peft.tuners.lora import LoraLayer
for name, module in model.named_children():
if len(list(module.children())) > 0:
## compound module, go inside it
recurse_remove_peft_layers(module)
module_replaced = False
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
module.weight.device
)
new_module.weight = module.weight
if module.bias is not None:
new_module.bias = module.bias
module_replaced = True
elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
new_module = torch.nn.Conv2d(
module.in_channels,
module.out_channels,
module.kernel_size,
module.stride,
module.padding,
module.dilation,
module.groups,
module.bias is not None,
).to(module.weight.device)
new_module.weight = module.weight
if module.bias is not None:
new_module.bias = module.bias
module_replaced = True
if module_replaced:
setattr(model, name, new_module)
del module
if torch.cuda.is_available():
torch.cuda.empty_cache()
return model
def scale_lora_layers(model, lora_weightage):
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.scale_layer(lora_weightage)
def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict):
rank_pattern = None
alpha_pattern = None
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the most frequently occurring rank
r = max(set(rank_dict.values()), key=list(rank_dict.values()).count)
# for modules whose rank differs from the most frequent rank, add them to `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
# get the most frequently occurring alpha
lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count)
# for modules whose alpha differs from the most frequent alpha, add them to `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
# layer names without the Diffusers-specific LoRA suffix
target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()}
return r, lora_alpha, rank_pattern, alpha_pattern, target_modules
def get_adapter_name(model):
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return f"default_{len(module.r)}"
return "default_0"
def set_adapter_layers(model, enabled=True):
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False if enabled else True
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer
# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
module.scale_layer(weight)
# set multiple active adapters
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_names
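To make the rank/alpha bookkeeping above concrete, a small hedged example of `get_rank_and_alpha_pattern` on a state dict with mixed ranks (all keys and rank values are illustrative; the function only inspects key names and integer ranks, so dummy values suffice):

from diffusers.utils import get_rank_and_alpha_pattern

rank_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_B.weight": 4,
    "text_model.encoder.layers.1.self_attn.out_proj.lora_B.weight": 4,
    "text_model.encoder.layers.2.self_attn.out_proj.lora_B.weight": 8,
}
# Only the key names of the PEFT state dict are used to build `target_modules`.
peft_state_dict = dict.fromkeys(
    k.replace("lora_B", suffix) for k in rank_dict for suffix in ("lora_A", "lora_B")
)

r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
    rank_dict, None, peft_state_dict
)
# r == 4 (the most common rank); lora_alpha == 4 (defaults to the first rank when no alphas are given)
# rank_pattern == {"text_model.encoder.layers.2.self_attn.out_proj": 8}
# alpha_pattern is None; target_modules is the set of layer names with the ".lora_*" suffix stripped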

View File

@@ -0,0 +1,180 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
State dict utilities: utility methods for converting state dicts easily
"""
import enum
class StateDictType(enum.Enum):
"""
The mode to use when converting state dicts.
"""
DIFFUSERS_OLD = "diffusers_old"
# KOHYA_SS = "kohya_ss" # TODO: implement this
PEFT = "peft"
DIFFUSERS = "diffusers"
DIFFUSERS_TO_PEFT = {
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
}
DIFFUSERS_OLD_TO_PEFT = {
".to_q_lora.up": ".q_proj.lora_B",
".to_q_lora.down": ".q_proj.lora_A",
".to_k_lora.up": ".k_proj.lora_B",
".to_k_lora.down": ".k_proj.lora_A",
".to_v_lora.up": ".v_proj.lora_B",
".to_v_lora.down": ".v_proj.lora_A",
".to_out_lora.up": ".out_proj.lora_B",
".to_out_lora.down": ".out_proj.lora_A",
}
PEFT_TO_DIFFUSERS = {
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
}
DIFFUSERS_OLD_TO_DIFFUSERS = {
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
}
PEFT_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}
DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
def convert_state_dict(state_dict, mapping):
r"""
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
mapping (`dict[str, str]`):
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
- key: the pattern to replace
- value: the pattern to replace with
Returns:
converted_state_dict (`dict`)
The converted state dict.
"""
converted_state_dict = {}
for k, v in state_dict.items():
if any(pattern in k for pattern in mapping.keys()):
for old, new in mapping.items():
k = k.replace(old, new)
converted_state_dict[k] = v
return converted_state_dict
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
r"""
Converts a state dict to the PEFT format. The state dict can be from the previous diffusers format (`DIFFUSERS_OLD`) or the
new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
"""
if original_type is None:
# Old diffusers to PEFT
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
elif any("lora_linear_layer" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS
else:
raise ValueError("Could not automatically infer state dict type")
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
return convert_state_dict(state_dict, mapping)
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
r"""
Converts a state dict to the new diffusers format. The state dict can be from the previous diffusers format
(`DIFFUSERS_OLD`), the PEFT format (`PEFT`), or the new diffusers format (`DIFFUSERS`). In the last case the method will
return the state dict as is.
The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
Args:
state_dict (`dict[str, torch.Tensor]`):
The state dict to convert.
original_type (`StateDictType`, *optional*):
The original type of the state dict, if not provided, the method will try to infer it automatically.
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- **adapter_name**: For example, in the case of PEFT, some keys will be prepended
with the adapter name and therefore need special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
but we add it here in case we don't want to rely on that method.
"""
peft_adapter_name = kwargs.pop("adapter_name", None)
peft_adapter_name = "." + peft_adapter_name if peft_adapter_name else ""
if original_type is None:
# Old diffusers to PEFT
if any("to_out_lora" in k for k in state_dict.keys()):
original_type = StateDictType.DIFFUSERS_OLD
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
original_type = StateDictType.PEFT
elif any("lora_linear_layer" in k for k in state_dict.keys()):
# nothing to do
return state_dict
else:
raise ValueError("Could not automatically infer state dict type")
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
raise ValueError(f"Original type {original_type} is not supported")
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
return convert_state_dict(state_dict, mapping)
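A short hedged illustration of the key renaming the converters above perform (dummy tensors; `StateDictType` is imported from the new module directly, assuming it lands at `diffusers/utils/state_dict_utils.py` as the `utils/__init__.py` import suggests):

import torch

from diffusers.utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
from diffusers.utils.state_dict_utils import StateDictType

# Old-style ("diffusers_old") keys, as produced by the monkey-patched text encoder layers.
old_sd = {
    "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight": torch.zeros(4, 32),
    "text_model.encoder.layers.0.self_attn.to_q_lora.up.weight": torch.zeros(32, 4),
    "text_model.encoder.layers.0.self_attn.to_out_lora.down.weight": torch.zeros(4, 32),
    "text_model.encoder.layers.0.self_attn.to_out_lora.up.weight": torch.zeros(32, 4),
}

peft_sd = convert_state_dict_to_peft(old_sd)
# "...to_q_lora.down" -> "...q_proj.lora_A", "...to_q_lora.up" -> "...q_proj.lora_B", and so on

diffusers_sd = convert_state_dict_to_diffusers(peft_sd, original_type=StateDictType.PEFT)
# "...q_proj.lora_A" -> "...q_proj.lora_linear_layer.down", "...lora_B" -> "...lora_linear_layer.up"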

View File

@@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
)
from diffusers.utils.testing_utils import floats_tensor
def create_unet_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
pipeline_components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
"unet_lora_attn_procs": unet_lora_attn_procs,
}
return pipeline_components, lora_components
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
def get_dummy_tokens(self):
max_seq_length = 77
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
prepared_inputs = {}
prepared_inputs["input_ids"] = inputs
return prepared_inputs