Compare commits

...

4 Commits

Author     SHA1        Message                                                Date
sayakpaul  e7b95976b3  safe remove old lora backend 👋                        2024-02-09 18:12:58 +05:30
sayakpaul  20d5f88dd0  Merge branch 'main' into lora-old-backend              2024-02-09 17:39:13 +05:30
sayakpaul  a86f0b00cd  uncomment necessary things.                            2024-02-07 15:46:36 +05:30
sayakpaul  449353298c  deprecate certain lora methods from the old backend.   2024-02-07 13:41:04 +05:30
5 changed files with 104 additions and 2573 deletions

View File

@@ -34,11 +34,6 @@ jobs:
             runner: docker-cpu
             image: diffusers/diffusers-pytorch-cpu
             report: torch_cpu_models_schedulers
-          - name: LoRA
-            framework: lora
-            runner: docker-cpu
-            image: diffusers/diffusers-pytorch-cpu
-            report: torch_cpu_lora
           - name: Fast Flax CPU tests
             framework: flax
             runner: docker-cpu
@@ -94,14 +89,6 @@ jobs:
             --make-reports=tests_${{ matrix.config.report }} \
             tests/models tests/schedulers tests/others
-      - name: Run fast PyTorch LoRA CPU tests
-        if: ${{ matrix.config.framework == 'lora' }}
-        run: |
-          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-            -s -v -k "not Flax and not Onnx and not Dependency" \
-            --make-reports=tests_${{ matrix.config.report }} \
-            tests/lora
       - name: Run fast Flax TPU tests
         if: ${{ matrix.config.framework == 'flax' }}
         run: |

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 import inspect
 import os
-from contextlib import nullcontext
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -26,7 +25,7 @@ from packaging import version
 from torch import nn

 from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -34,7 +33,6 @@ from ..utils import (
     convert_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
-    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -51,10 +49,9 @@ from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_
 if is_transformers_available():
     from transformers import PreTrainedModel

-    from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

 if is_accelerate_available():
-    from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

 logger = logging.get_logger(__name__)
@@ -106,6 +103,9 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -397,16 +397,17 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())

-        if all(key.startswith("unet.unet") for key in keys):
-            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
-            deprecate("unet.unet keys", "0.27", deprecation_message)
-
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
@@ -427,9 +428,7 @@ class LoraLoaderMixin:
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             logger.warn(warn_message)

-        if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
-            from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
+        if len(state_dict.keys()) > 0:
             if adapter_name in getattr(unet, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
@@ -518,6 +517,11 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig
+
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -539,34 +543,21 @@ class LoraLoaderMixin:
             rank = {}
             text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

-            if USE_PEFT_BACKEND:
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

             for name, _ in text_encoder_attn_modules(text_encoder):
                 rank_key = f"{name}.out_proj.lora_B.weight"
                 rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
             if patch_mlp:
                 for name, _ in text_encoder_mlp_modules(text_encoder):
                     rank_key_fc1 = f"{name}.fc1.lora_B.weight"
                     rank_key_fc2 = f"{name}.fc2.lora_B.weight"
                     rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
                     rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
-            else:
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
-                    rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
-
-                patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
-                if patch_mlp:
-                    for name, _ in text_encoder_mlp_modules(text_encoder):
-                        rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
-                        rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
-                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
-                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]

             if network_alphas is not None:
                 alpha_keys = [
@@ -576,84 +567,25 @@ class LoraLoaderMixin:
                     k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if USE_PEFT_BACKEND:
-                from peft import LoraConfig
-
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
-                )
-                lora_config = LoraConfig(**lora_config_kwargs)
-
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                )
-
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-            else:
-                cls._modify_text_encoder(
-                    text_encoder,
-                    lora_scale,
-                    network_alphas,
-                    rank=rank,
-                    patch_mlp=patch_mlp,
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                )
-
-                is_pipeline_offloaded = _pipeline is not None and any(
-                    isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
-                    for c in _pipeline.components.values()
-                )
-                if is_pipeline_offloaded and low_cpu_mem_usage:
-                    low_cpu_mem_usage = True
-                    logger.info(
-                        f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
-                    )
-
-                if low_cpu_mem_usage:
-                    device = next(iter(text_encoder_lora_state_dict.values())).device
-                    dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
-                    unexpected_keys = load_model_dict_into_meta(
-                        text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
-                    )
-                else:
-                    load_state_dict_results = text_encoder.load_state_dict(
-                        text_encoder_lora_state_dict, strict=False
-                    )
-                    unexpected_keys = load_state_dict_results.unexpected_keys
-
-                if len(unexpected_keys) != 0:
-                    raise ValueError(
-                        f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
-                    )
-
-                # <Unsafe code
-                # We can be sure that the following works as all we do is change the dtype and device of the text encoder
-                # Now we remove any existing hooks to
-                is_model_cpu_offload = False
-                is_sequential_cpu_offload = False
-                if _pipeline is not None:
-                    for _, component in _pipeline.components.items():
-                        if isinstance(component, torch.nn.Module):
-                            if hasattr(component, "_hf_hook"):
-                                is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                                is_sequential_cpu_offload = isinstance(
-                                    getattr(component, "_hf_hook"), AlignDevicesHook
-                                )
-                                logger.info(
-                                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                                )
-                                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+            )
+
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)

         text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
@@ -689,6 +621,8 @@ class LoraLoaderMixin:
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

         keys = list(state_dict.keys())
@@ -705,8 +639,6 @@ class LoraLoaderMixin:
             }

         if len(state_dict.keys()) > 0:
-            from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
             if adapter_name in getattr(transformer, "peft_config", {}):
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
@@ -754,118 +686,20 @@ class LoraLoaderMixin:
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

     def _remove_text_encoder_monkey_patch(self):
-        if USE_PEFT_BACKEND:
-            remove_method = recurse_remove_peft_layers
-        else:
-            remove_method = self._remove_text_encoder_monkey_patch_classmethod
+        remove_method = recurse_remove_peft_layers

         if hasattr(self, "text_encoder"):
             remove_method(self.text_encoder)
             # In case text encoder have no Lora attached
-            if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
+            if getattr(self.text_encoder, "peft_config", None) is not None:
                 del self.text_encoder.peft_config
                 self.text_encoder._hf_peft_config_loaded = None

         if hasattr(self, "text_encoder_2"):
             remove_method(self.text_encoder_2)
-            if USE_PEFT_BACKEND:
+            if getattr(self.text_encoder_2, "peft_config", None) is not None:
                 del self.text_encoder_2.peft_config
                 self.text_encoder_2._hf_peft_config_loaded = None

-    @classmethod
-    def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
-        deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
-
-        for _, attn_module in text_encoder_attn_modules(text_encoder):
-            if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                attn_module.q_proj.lora_linear_layer = None
-                attn_module.k_proj.lora_linear_layer = None
-                attn_module.v_proj.lora_linear_layer = None
-                attn_module.out_proj.lora_linear_layer = None
-
-        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-            if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                mlp_module.fc1.lora_linear_layer = None
-                mlp_module.fc2.lora_linear_layer = None
-
-    @classmethod
-    def _modify_text_encoder(
-        cls,
-        text_encoder,
-        lora_scale=1,
-        network_alphas=None,
-        rank: Union[Dict[str, int], int] = 4,
-        dtype=None,
-        patch_mlp=False,
-        low_cpu_mem_usage=False,
-    ):
-        r"""
-        Monkey-patches the forward passes of attention modules of the text encoder.
-        """
-        deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
-
-        def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
-            linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
-            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-            with ctx():
-                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
-
-            lora_parameters.extend(model.lora_linear_layer.parameters())
-            return model
-
-        # First, remove any monkey-patch that might have been applied before
-        cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
-
-        lora_parameters = []
-        network_alphas = {} if network_alphas is None else network_alphas
-        is_network_alphas_populated = len(network_alphas) > 0
-
-        for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
-            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
-            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
-            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
-
-            if isinstance(rank, dict):
-                current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
-            else:
-                current_rank = rank
-
-            attn_module.q_proj = create_patched_linear_lora(
-                attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
-            )
-            attn_module.k_proj = create_patched_linear_lora(
-                attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
-            )
-            attn_module.v_proj = create_patched_linear_lora(
-                attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
-            )
-            attn_module.out_proj = create_patched_linear_lora(
-                attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
-            )
-
-        if patch_mlp:
-            for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
-                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
-
-                current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
-                current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
-
-                mlp_module.fc1 = create_patched_linear_lora(
-                    mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
-                )
-                mlp_module.fc2 = create_patched_linear_lora(
-                    mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
-                )
-
-        if is_network_alphas_populated and len(network_alphas) > 0:
-            raise ValueError(
-                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
-            )
-
-        return lora_parameters
-
     @classmethod
     def save_lora_weights(
         cls,
@@ -1039,6 +873,8 @@ class LoraLoaderMixin:
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
        if fuse_unet or fuse_text_encoder:
            self.num_fused_loras += 1
            if self.num_fused_loras > 1:
@@ -1050,52 +886,26 @@ class LoraLoaderMixin:
            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)

-        if USE_PEFT_BACKEND:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
-                merge_kwargs = {"safe_merge": safe_fusing}
-
-                for module in text_encoder.modules():
-                    if isinstance(module, BaseTunerLayer):
-                        if lora_scale != 1.0:
-                            module.scale_layer(lora_scale)
-
-                        # For BC with previous PEFT versions, we need to check the signature
-                        # of the `merge` method to see if it supports the `adapter_names` argument.
-                        supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
-                        if "adapter_names" in supported_merge_kwargs:
-                            merge_kwargs["adapter_names"] = adapter_names
-                        elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
-                            raise ValueError(
-                                "The `adapter_names` argument is not supported with your PEFT version. "
-                                "Please upgrade to the latest version of PEFT. `pip install -U peft`"
-                            )
-
-                        module.merge(**merge_kwargs)
-        else:
-            deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
-
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
-                if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
-                    raise ValueError(
-                        "The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
-                        "backend to use this argument by installing latest PEFT and transformers."
-                        " `pip install -U peft transformers`"
-                    )
-
-                for _, attn_module in text_encoder_attn_modules(text_encoder):
-                    if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
-                        attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
-                        attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
-                        attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
-
-                for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-                    if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
-                        mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
+        def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+            merge_kwargs = {"safe_merge": safe_fusing}
+
+            for module in text_encoder.modules():
+                if isinstance(module, BaseTunerLayer):
+                    if lora_scale != 1.0:
+                        module.scale_layer(lora_scale)
+
+                    # For BC with previous PEFT versions, we need to check the signature
+                    # of the `merge` method to see if it supports the `adapter_names` argument.
+                    supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                    if "adapter_names" in supported_merge_kwargs:
+                        merge_kwargs["adapter_names"] = adapter_names
+                    elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                        raise ValueError(
+                            "The `adapter_names` argument is not supported with your PEFT version. "
+                            "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+                        )
+
+                    module.merge(**merge_kwargs)

        if fuse_text_encoder:
            if hasattr(self, "text_encoder"):
@@ -1120,40 +930,18 @@ class LoraLoaderMixin:
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        if unfuse_unet:
-            if not USE_PEFT_BACKEND:
-                unet.unfuse_lora()
-            else:
-                from peft.tuners.tuners_utils import BaseTunerLayer
-
-                for module in unet.modules():
-                    if isinstance(module, BaseTunerLayer):
-                        module.unmerge()
-
-        if USE_PEFT_BACKEND:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            def unfuse_text_encoder_lora(text_encoder):
-                for module in text_encoder.modules():
-                    if isinstance(module, BaseTunerLayer):
-                        module.unmerge()
-
-        else:
-            deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
-
-            def unfuse_text_encoder_lora(text_encoder):
-                for _, attn_module in text_encoder_attn_modules(text_encoder):
-                    if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                        attn_module.q_proj._unfuse_lora()
-                        attn_module.k_proj._unfuse_lora()
-                        attn_module.v_proj._unfuse_lora()
-                        attn_module.out_proj._unfuse_lora()
-
-                for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-                    if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                        mlp_module.fc1._unfuse_lora()
-                        mlp_module.fc2._unfuse_lora()
+            for module in unet.modules():
+                if isinstance(module, BaseTunerLayer):
+                    module.unmerge()
+
+        def unfuse_text_encoder_lora(text_encoder):
+            for module in text_encoder.modules():
+                if isinstance(module, BaseTunerLayer):
+                    module.unmerge()

        if unfuse_text_encoder:
            if hasattr(self, "text_encoder"):
@@ -1434,6 +1222,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.
@@ -1538,17 +1329,13 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
        )

    def _remove_text_encoder_monkey_patch(self):
-        if USE_PEFT_BACKEND:
-            recurse_remove_peft_layers(self.text_encoder)
-            # TODO: @younesbelkada handle this in transformers side
-            if getattr(self.text_encoder, "peft_config", None) is not None:
-                del self.text_encoder.peft_config
-                self.text_encoder._hf_peft_config_loaded = None
+        recurse_remove_peft_layers(self.text_encoder)
+        # TODO: @younesbelkada handle this in transformers side
+        if getattr(self.text_encoder, "peft_config", None) is not None:
+            del self.text_encoder.peft_config
+            self.text_encoder._hf_peft_config_loaded = None

        recurse_remove_peft_layers(self.text_encoder_2)
        if getattr(self.text_encoder_2, "peft_config", None) is not None:
            del self.text_encoder_2.peft_config
            self.text_encoder_2._hf_peft_config_loaded = None
-        else:
-            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
-            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
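The warn_message kept in this file spells out how to convert LoRA weights saved in the old format to the new `unet.`-prefixed format. A minimal sketch of that conversion, assuming the weights live in a safetensors file (the file names and the use of safetensors here are illustrative, not part of this diff):

# Sketch of the key-renaming described by the warn_message above.
# File names and the safetensors round-trip are assumptions for illustration only.
from safetensors.torch import load_file, save_file

old_state_dict = load_file("old_lora.safetensors")  # hypothetical old-format checkpoint
# Prefix every module key with "unet." so the new-format loader can route it to the UNet.
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
save_file(new_state_dict, "new_lora.safetensors")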

View File

@@ -27,7 +27,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ..utils import logging
+from ..utils import deprecate, logging
 from ..utils.import_utils import is_transformers_available
@@ -82,6 +82,9 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
 class PatchedLoraProjection(torch.nn.Module):
     def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+        deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
+
         super().__init__()

         from ..models.lora import LoRALinearLayer
@@ -293,10 +296,16 @@ class LoRACompatibleConv(nn.Conv2d):
     """

     def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
+        deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
+
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer

     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("set_lora_layer", "1.0.0", deprecation_message)
+
         self.lora_layer = lora_layer

     def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
@@ -371,10 +380,15 @@ class LoRACompatibleLinear(nn.Linear):
     """

     def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+        deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
+
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer

     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("set_lora_layer", "1.0.0", deprecation_message)
         self.lora_layer = lora_layer

     def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
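The deprecation messages added in this file all point users to the PEFT backend. A minimal sketch of the PEFT-backed workflow that the remaining code path assumes, with placeholder model and LoRA repository ids (not taken from this diff):

# Sketch of the PEFT-backed LoRA workflow the deprecation messages refer to.
# The pipeline and LoRA repo ids are placeholders for illustration only.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/diffusion-pipeline", torch_dtype=torch.float16)  # hypothetical id
pipe.load_lora_weights("some/lora-weights")  # with the old backend removed, this requires `peft` to be installed
pipe.fuse_lora(lora_scale=0.7)  # merges LoRA into the base weights (BaseTunerLayer.merge under the hood)
image = pipe("A painting of a squirrel eating a burger").images[0]
pipe.unfuse_lora()  # restores the original weights (unmerge)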

File diff suppressed because it is too large

View File

@@ -1,64 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import DiffusionPipeline
-from diffusers.utils.testing_utils import torch_device
-
-
-class PEFTLoRALoading(unittest.TestCase):
-    def get_dummy_inputs(self):
-        pipeline_inputs = {
-            "prompt": "A painting of a squirrel eating a burger",
-            "num_inference_steps": 2,
-            "guidance_scale": 6.0,
-            "output_type": "np",
-            "generator": torch.manual_seed(0),
-        }
-        return pipeline_inputs
-
-    def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
-        sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
-        # This LoRA was obtained using similarly as how it's done in the training scripts.
-        # For details on how the LoRA was obtained, refer to:
-        # https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
-        sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
-
-        inputs = self.get_dummy_inputs()
-        outputs = sd_pipe(**inputs).images
-
-        predicted_slice = outputs[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
-
-        self.assertTrue(outputs.shape == (1, 64, 64, 3))
-        assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
-
-    def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
-        sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
-        # This LoRA was obtained using similarly as how it's done in the training scripts.
-        sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
-
-        inputs = self.get_dummy_inputs()
-        outputs = sd_pipe(**inputs).images
-
-        predicted_slice = outputs[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
-
-        self.assertTrue(outputs.shape == (1, 64, 64, 3))
-        assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)