mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
4 Commits
modular-do
...
lora-old-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7b95976b3 | ||
|
|
20d5f88dd0 | ||
|
|
a86f0b00cd | ||
|
|
449353298c |
13
.github/workflows/pr_tests.yml
vendored
13
.github/workflows/pr_tests.yml
vendored
@@ -34,11 +34,6 @@ jobs:
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: LoRA
|
||||
framework: lora
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_lora
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
@@ -94,14 +89,6 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests
|
||||
if: ${{ matrix.config.framework == 'lora' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not Dependency" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
@@ -26,7 +25,7 @@ from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from .. import __version__
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
@@ -34,7 +33,6 @@ from ..utils import (
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
@@ -51,10 +49,9 @@ from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -106,6 +103,9 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
@@ -397,16 +397,17 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
if all(key.startswith("unet.unet") for key in keys):
|
||||
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
|
||||
deprecate("unet.unet keys", "0.27", deprecation_message)
|
||||
|
||||
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
@@ -427,9 +428,7 @@ class LoraLoaderMixin:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
|
||||
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
|
||||
@@ -518,6 +517,11 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -539,34 +543,21 @@ class LoraLoaderMixin:
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
else:
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -576,84 +567,25 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft import LoraConfig
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
|
||||
)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
else:
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
|
||||
for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(
|
||||
text_encoder_lora_state_dict, strict=False
|
||||
)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@@ -689,6 +621,8 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
@@ -705,8 +639,6 @@ class LoraLoaderMixin:
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
@@ -754,118 +686,20 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if USE_PEFT_BACKEND:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
|
||||
remove_method = recurse_remove_peft_layers
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
|
||||
# In case text encoder have no Lora attached
|
||||
if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
if USE_PEFT_BACKEND:
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
attn_module.k_proj.lora_linear_layer = None
|
||||
attn_module.v_proj.lora_linear_layer = None
|
||||
attn_module.out_proj.lora_linear_layer = None
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_linear_layer = None
|
||||
mlp_module.fc2.lora_linear_layer = None
|
||||
|
||||
@classmethod
|
||||
def _modify_text_encoder(
|
||||
cls,
|
||||
text_encoder,
|
||||
lora_scale=1,
|
||||
network_alphas=None,
|
||||
rank: Union[Dict[str, int], int] = 4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
|
||||
|
||||
lora_parameters.extend(model.lora_linear_layer.parameters())
|
||||
return model
|
||||
|
||||
# First, remove any monkey-patch that might have been applied before
|
||||
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
|
||||
|
||||
lora_parameters = []
|
||||
network_alphas = {} if network_alphas is None else network_alphas
|
||||
is_network_alphas_populated = len(network_alphas) > 0
|
||||
|
||||
for name, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
|
||||
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
|
||||
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
|
||||
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
|
||||
|
||||
if isinstance(rank, dict):
|
||||
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
|
||||
else:
|
||||
current_rank = rank
|
||||
|
||||
attn_module.q_proj = create_patched_linear_lora(
|
||||
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.k_proj = create_patched_linear_lora(
|
||||
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.v_proj = create_patched_linear_lora(
|
||||
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.out_proj = create_patched_linear_lora(
|
||||
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
|
||||
if patch_mlp:
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
|
||||
|
||||
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
|
||||
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
|
||||
|
||||
mlp_module.fc1 = create_patched_linear_lora(
|
||||
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
|
||||
)
|
||||
mlp_module.fc2 = create_patched_linear_lora(
|
||||
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
|
||||
)
|
||||
|
||||
if is_network_alphas_populated and len(network_alphas) > 0:
|
||||
raise ValueError(
|
||||
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
|
||||
)
|
||||
|
||||
return lora_parameters
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
@@ -1039,6 +873,8 @@ class LoraLoaderMixin:
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
if self.num_fused_loras > 1:
|
||||
@@ -1050,52 +886,26 @@ class LoraLoaderMixin:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
|
||||
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
|
||||
"backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
|
||||
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1120,40 +930,18 @@ class LoraLoaderMixin:
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if unfuse_unet:
|
||||
if not USE_PEFT_BACKEND:
|
||||
unet.unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1434,6 +1222,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
@@ -1538,17 +1329,13 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if USE_PEFT_BACKEND:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import logging
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_transformers_available
|
||||
|
||||
|
||||
@@ -82,6 +82,9 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
|
||||
class PatchedLoraProjection(torch.nn.Module):
|
||||
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
|
||||
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__()
|
||||
from ..models.lora import LoRALinearLayer
|
||||
|
||||
@@ -293,10 +296,16 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
|
||||
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("set_lora_layer", "1.0.0", deprecation_message)
|
||||
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
@@ -371,10 +380,15 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
|
||||
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
||||
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("set_lora_layer", "1.0.0", deprecation_message)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,64 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class PEFTLoRALoading(unittest.TestCase):
|
||||
def get_dummy_inputs(self):
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"generator": torch.manual_seed(0),
|
||||
}
|
||||
return pipeline_inputs
|
||||
|
||||
def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
# For details on how the LoRA was obtained, refer to:
|
||||
# https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
Reference in New Issue
Block a user