Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: 49 commits, if-tests ... peftpart-1
| SHA1 |
|---|
| ece3b02f13 |
| 0985d17ea9 |
| 920333ffaa |
| 71650d403d |
| 5e6f343d16 |
| 724b52bc56 |
| bd46ae9db7 |
| e072655765 |
| b72ef23dfc |
| b412adc158 |
| 325462dcd2 |
| cb484056c8 |
| e836b145e8 |
| d01a29273e |
| 3d7c567f90 |
| 27e3da69dc |
| ea05959c6a |
| 40a489457d |
| c90f85d3f0 |
| 9cb8563b1d |
| 74e33a9376 |
| b83fcbaf86 |
| dc83fa0ec7 |
| 3ba2d4eb05 |
| f8e87f6220 |
| f8909061ee |
| 6f1adcd65d |
| 9d650c9032 |
| ecbc7144f1 |
| 78a01d5151 |
| 78a860d276 |
| 1d13f40548 |
| 4162ddfdba |
| c4295c9432 |
| 0c62ef3daf |
| 40a60286b4 |
| ec87c196f3 |
| d56a14db7b |
| c06c40bad6 |
| 14db139116 |
| 7918851640 |
| 691368b060 |
| cdbe7391a8 |
| 961e776298 |
| 5a150b2059 |
| 01f6d1d88c |
| 2a6e5358a0 |
| c17634c39e |
| ba24f2a5ce |
setup.py (1 line changed)

@@ -128,6 +128,7 @@ _deps = [
    "torchvision",
    "transformers>=4.25.1",
    "urllib3<=2.0.0",
    "peft>=0.5.0"
]

# this is a lookup table with items like:

@@ -41,4 +41,5 @@ deps = {
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
    "urllib3": "urllib3<=2.0.0",
    "peft": "peft>=0.5.0",
}
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
from collections import defaultdict
@@ -23,6 +24,7 @@ import requests
import safetensors
import torch
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn

from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
@@ -30,11 +32,20 @@ from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
    _get_model_file,
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    deprecate,
    get_adapter_name,
    get_rank_and_alpha_pattern,
    is_accelerate_available,
    is_omegaconf_available,
    is_peft_available,
    is_transformers_available,
    logging,
    recurse_remove_peft_layers,
    scale_lora_layers,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .utils.import_utils import BACKENDS_MAPPING

@@ -61,6 +72,21 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"


# Below is `True` if the current versions of `peft` and `transformers` are compatible with the PEFT backend.
# The PEFT backend is used automatically when compatible versions of both libraries are available.
# For PEFT the version has to be greater than 0.5 (i.e. at least 0.6.0) and for transformers greater than 4.33.
_required_peft_version = is_peft_available() and version.parse(
    version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = version.parse(
    version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")

USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
LORA_DEPRECATION_MESSAGE = "You are using an old version of the LoRA backend. This will be deprecated in the next releases in favor of PEFT; make sure to install the latest PEFT and transformers packages in the future."


class PatchedLoraProjection(nn.Module):
    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
        super().__init__()
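For orientation, a minimal, self-contained sketch of the same version gate the constants above implement; the helper name `peft_backend_enabled` is illustrative and not part of the diff.

```python
import importlib.metadata

from packaging import version


def peft_backend_enabled() -> bool:
    """Return True when the installed peft (> 0.5) and transformers (> 4.33) support the PEFT backend."""
    try:
        peft_ok = version.parse(
            version.parse(importlib.metadata.version("peft")).base_version
        ) > version.parse("0.5")
        transformers_ok = version.parse(
            version.parse(importlib.metadata.version("transformers")).base_version
        ) > version.parse("4.33")
    except importlib.metadata.PackageNotFoundError:
        # One of the libraries is not installed at all, so the PEFT backend cannot be used.
        return False
    return peft_ok and transformers_ok


if __name__ == "__main__":
    print("PEFT backend enabled:", peft_backend_enabled())
```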
@@ -1077,8 +1103,11 @@ class LoraLoaderMixin:
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
unet_name = UNET_NAME
|
||||
num_fused_loras = 0
|
||||
use_peft_backend = USE_PEFT_BACKEND
|
||||
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
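As a usage sketch of the `adapter_name` argument this hunk adds to `load_lora_weights`, something like the following should work once the PEFT backend is active; the repository ids and the adapter name are placeholders, not values from the diff.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Register the LoRA under an explicit adapter name; when adapter_name is omitted,
# a name of the form "default_{i}" is generated, as described further down in the diff.
pipe.load_lora_weights("some-user/some-lora-checkpoint", adapter_name="my_style")
```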
|
||||
@@ -1122,6 +1151,7 @@ class LoraLoaderMixin:
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
_pipeline=self,
|
||||
adapter_name=adapter_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1478,6 +1508,7 @@ class LoraLoaderMixin:
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
_pipeline=None,
|
||||
adapter_name=None,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -1500,6 +1531,9 @@ class LoraLoaderMixin:
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where `i` is the total number of adapters being loaded.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
@@ -1520,55 +1554,35 @@ class LoraLoaderMixin:
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
if cls.use_peft_backend:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
|
||||
# Convert from the old naming convention to the new naming convention.
|
||||
#
|
||||
# Previously, the old LoRA layers were stored on the state dict at the
|
||||
# same level as the attention block i.e.
|
||||
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
|
||||
#
|
||||
# There is no actual module at that point; they were monkey-patched onto the
|
||||
# existing module. We want to be able to load them via their actual state dict.
|
||||
# They're in `PatchedLoraProjection.lora_linear_layer` now.
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.q_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.k_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.v_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.q_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.k_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.v_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
|
||||
text_encoder_lora_state_dict[
|
||||
f"{name}.out_proj.lora_linear_layer.down.weight"
|
||||
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
else:
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
|
||||
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -1578,56 +1592,90 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if cls.use_peft_backend:
|
||||
from peft import LoraConfig
|
||||
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
|
||||
rank, network_alphas, text_encoder_lora_state_dict
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=lora_alpha,
|
||||
rank_pattern=rank_pattern,
|
||||
alpha_pattern=alpha_pattern,
|
||||
)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
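The PEFT branch above converts the state dict, builds a `LoraConfig`, injects it with `load_adapter`, and finally rescales the layers. A condensed, hedged sketch of that flow is shown below; the helper name `load_text_encoder_lora` and the literal config values are illustrative, not taken from the codebase.

```python
from typing import Dict

import torch
from peft import LoraConfig


def load_text_encoder_lora(text_encoder, peft_state_dict: Dict[str, torch.Tensor], adapter_name: str = "default_0"):
    """Inject a PEFT LoRA adapter into a transformers text encoder and load its weights."""
    lora_config = LoraConfig(
        r=4,              # most common rank found in the checkpoint
        lora_alpha=4,     # most common network alpha
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        rank_pattern={},  # per-module rank overrides, used when ranks differ across layers
        alpha_pattern={}, # per-module alpha overrides
    )
    # `load_adapter` (transformers >= 4.33) adds the LoRA layers and loads the state dict.
    text_encoder.load_adapter(
        adapter_name=adapter_name,
        adapter_state_dict=peft_state_dict,
        peft_config=lora_config,
    )
    return text_encoder
```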
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
|
||||
for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(
|
||||
text_encoder_lora_state_dict, strict=False
|
||||
)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@@ -1645,10 +1693,20 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
if self.use_peft_backend:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
@@ -1675,6 +1733,7 @@ class LoraLoaderMixin:
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
@@ -2049,24 +2108,38 @@ class LoraLoaderMixin:
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora(lora_scale)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
module.merge()
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder)
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2)
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
@@ -2088,18 +2161,29 @@ class LoraLoaderMixin:
|
||||
if unfuse_unet:
|
||||
self.unet.unfuse_lora()
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -2109,6 +2193,65 @@ class LoraLoaderMixin:
|
||||
|
||||
self.num_fused_loras -= 1
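A short end-user sketch of the fuse/unfuse pair these hunks rework for the PEFT backend; the checkpoint id is a placeholder.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora-checkpoint")  # placeholder repo id

# Merge the LoRA weights into the base weights (optionally down-weighted via lora_scale),
# run inference, then restore the original weights.
pipe.fuse_lora(lora_scale=0.7)
image = pipe("an astronaut riding a horse on the moon").images[0]
pipe.unfuse_lora()
```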
|
||||
|
||||
def set_adapter(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
unet_weights: List[float] = None,
|
||||
te_weights: List[float] = None,
|
||||
te2_weights: List[float] = None,
|
||||
):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights]
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
# To Do
|
||||
# Handle the UNET
|
||||
|
||||
# Handle the Text Encoder
|
||||
te_weights = process_weights(adapter_names, te_weights)
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights)
|
||||
te2_weights = process_weights(adapter_names, te2_weights)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights)
|
||||
|
||||
def disable_lora(self):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
# To Do
|
||||
# Disable unet adapters
|
||||
|
||||
# Disable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_adapter_layers(self.text_encoder, enabled=False)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_adapter_layers(self.text_encoder_2, enabled=False)
|
||||
|
||||
def enable_lora(self):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
# To Do
|
||||
# Enable unet adapters
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_adapter_layers(self.text_encoder, enabled=True)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_adapter_layers(self.text_encoder_2, enabled=True)
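Assuming a pipeline with the PEFT backend active and two LoRAs loaded, the adapter-management methods added above could be driven roughly as follows; adapter names and repo ids are placeholders, and the UNet side is still marked "To Do" in this revision.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/lora-style-a", adapter_name="style_a")  # placeholder ids
pipe.load_lora_weights("some-user/lora-style-b", adapter_name="style_b")

# Activate both text-encoder adapters with per-adapter weights (signature as defined in this diff).
pipe.set_adapter(["style_a", "style_b"], te_weights=[1.0, 0.5])

# Temporarily run without any LoRA influence, then switch it back on.
pipe.disable_lora()
pipe.enable_lora()
```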
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
@@ -2810,5 +2953,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
if self.use_peft_backend:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
@@ -19,24 +19,27 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from ..utils import logging
|
||||
from ..utils import logging, scale_lora_layers
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
|
||||
if use_peft_backend:
|
||||
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
|
||||
else:
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
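The rewritten helper above is what each pipeline's prompt-encoding step calls; from the user's side the scale still travels through `cross_attention_kwargs`, e.g. as below (checkpoint ids are placeholders).

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-lora-checkpoint")  # placeholder repo id

# The "scale" entry is forwarded as `lora_scale` to encode_prompt, which in turn calls
# adjust_lora_scale_text_encoder(..., use_peft_backend) as the pipeline hunks below show.
image = pipe(
    "a watercolor painting of a fox",
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```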
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
|
||||
@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -315,7 +315,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -288,7 +288,7 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -326,7 +326,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -297,7 +297,7 @@ class StableDiffusionInpaintPipelineLegacy(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -211,7 +211,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -272,7 +272,7 @@ class StableDiffusionLDM3DPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -221,7 +221,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -256,7 +256,7 @@ class StableDiffusionParadigmsPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -446,7 +446,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -244,7 +244,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -240,7 +240,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -346,7 +346,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -264,7 +264,7 @@ class StableDiffusionXLPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -271,7 +271,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -420,7 +420,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -272,7 +272,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
|
||||
@@ -296,7 +296,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -288,7 +288,7 @@ class StableDiffusionXLAdapterPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -228,7 +228,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -290,7 +290,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -67,6 +67,7 @@ from .import_utils import (
|
||||
is_note_seq_available,
|
||||
is_omegaconf_available,
|
||||
is_onnx_available,
|
||||
is_peft_available,
|
||||
is_scipy_available,
|
||||
is_tensorboard_available,
|
||||
is_torch_available,
|
||||
@@ -82,7 +83,16 @@ from .import_utils import (
|
||||
from .loading_utils import load_image
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
get_adapter_name,
|
||||
get_rank_and_alpha_pattern,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -267,6 +267,14 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_invisible_watermark_available = False
|
||||
|
||||
|
||||
_peft_available = importlib.util.find_spec("peft") is not None
|
||||
try:
|
||||
_peft_version = importlib_metadata.version("peft")
|
||||
logger.debug(f"Successfully imported accelerate version {_peft_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_peft_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
@@ -351,6 +359,10 @@ def is_invisible_watermark_available():
|
||||
return _invisible_watermark_available
|
||||
|
||||
|
||||
def is_peft_available():
|
||||
return _peft_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
src/diffusers/utils/peft_utils.py (new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
PEFT utilities: Utilities related to peft library
|
||||
"""
|
||||
from .import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def recurse_remove_peft_layers(model):
|
||||
r"""
|
||||
Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
|
||||
"""
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
## compound module, go inside it
|
||||
recurse_remove_peft_layers(module)
|
||||
|
||||
module_replaced = False
|
||||
|
||||
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
|
||||
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
|
||||
module.weight.device
|
||||
)
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
|
||||
module_replaced = True
|
||||
elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
|
||||
new_module = torch.nn.Conv2d(
|
||||
module.in_channels,
|
||||
module.out_channels,
|
||||
module.kernel_size,
|
||||
module.stride,
|
||||
module.padding,
|
||||
module.dilation,
|
||||
module.groups,
|
||||
module.bias,
|
||||
).to(module.weight.device)
|
||||
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
|
||||
module_replaced = True
|
||||
|
||||
if module_replaced:
|
||||
setattr(model, name, new_module)
|
||||
del module
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def scale_lora_layers(model, lora_weightage):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.scale_layer(lora_weightage)
|
||||
|
||||
|
||||
def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict):
|
||||
rank_pattern = None
|
||||
alpha_pattern = None
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occurring the most number of times
|
||||
r = max(set(rank_dict.values()), key=list(rank_dict.values()).count)
|
||||
|
||||
# for modules with a rank different from the most occurring rank, add it to the `rank_pattern`
|
||||
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
|
||||
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
|
||||
|
||||
if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occurring the most number of times
|
||||
lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count)
|
||||
|
||||
# for modules with an alpha different from the most occurring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
|
||||
# layer names without the Diffusers-specific LoRA suffix
|
||||
target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()}
|
||||
|
||||
return r, lora_alpha, rank_pattern, alpha_pattern, target_modules
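To make `get_rank_and_alpha_pattern` above concrete, here is a small worked example; the layer names and ranks are made up, and the import assumes the new `diffusers.utils.peft_utils` module introduced by this file is installed.

```python
from diffusers.utils.peft_utils import get_rank_and_alpha_pattern  # new module added in this PR

rank_dict = {
    "layers.0.out_proj.lora_B.weight": 4,
    "layers.1.out_proj.lora_B.weight": 4,
    "layers.2.out_proj.lora_B.weight": 8,  # one layer uses a larger rank
}
peft_state_dict = {name: None for name in rank_dict}  # only the key names are inspected

r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
    rank_dict, network_alpha_dict=None, peft_state_dict=peft_state_dict
)
# r == 4 and lora_alpha == 4 (the most common rank, reused as alpha when no alphas are given),
# rank_pattern == {"layers.2.out_proj": 8}, alpha_pattern is None, and
# target_modules == {"layers.0.out_proj", "layers.1.out_proj", "layers.2.out_proj"}
```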
|
||||
|
||||
|
||||
def get_adapter_name(model):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return f"default_{len(module.r)}"
|
||||
return "default_0"
|
||||
|
||||
|
||||
def set_adapter_layers(model, enabled=True):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.disable_adapters = False if enabled else True
|
||||
|
||||
|
||||
def set_weights_and_activate_adapters(model, adapter_names, weights):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
# iterate over each adapter, make it active and set the corresponding scaling weight
|
||||
for adapter_name, weight in zip(adapter_names, weights):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.active_adapter = adapter_name
|
||||
module.scale_layer(weight)
|
||||
|
||||
# set multiple active adapters
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.active_adapter = adapter_names
|
||||
src/diffusers/utils/state_dict_utils.py (new file, 180 lines)
@@ -0,0 +1,180 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
State dict utilities: utility methods for converting state dicts easily
|
||||
"""
|
||||
import enum
|
||||
|
||||
|
||||
class StateDictType(enum.Enum):
|
||||
"""
|
||||
The mode to use when converting state dicts.
|
||||
"""
|
||||
|
||||
DIFFUSERS_OLD = "diffusers_old"
|
||||
# KOHYA_SS = "kohya_ss" # TODO: implement this
|
||||
PEFT = "peft"
|
||||
DIFFUSERS = "diffusers"
|
||||
|
||||
|
||||
DIFFUSERS_TO_PEFT = {
|
||||
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
|
||||
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
|
||||
".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
|
||||
".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
|
||||
".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
|
||||
".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
|
||||
".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
|
||||
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_PEFT = {
|
||||
".to_q_lora.up": ".q_proj.lora_B",
|
||||
".to_q_lora.down": ".q_proj.lora_A",
|
||||
".to_k_lora.up": ".k_proj.lora_B",
|
||||
".to_k_lora.down": ".k_proj.lora_A",
|
||||
".to_v_lora.up": ".v_proj.lora_B",
|
||||
".to_v_lora.down": ".v_proj.lora_A",
|
||||
".to_out_lora.up": ".out_proj.lora_B",
|
||||
".to_out_lora.down": ".out_proj.lora_A",
|
||||
}
|
||||
|
||||
PEFT_TO_DIFFUSERS = {
|
||||
".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
|
||||
".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
|
||||
".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
|
||||
".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
|
||||
".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
|
||||
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
|
||||
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
|
||||
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_DIFFUSERS = {
|
||||
".to_q_lora.up": ".q_proj.lora_linear_layer.up",
|
||||
".to_q_lora.down": ".q_proj.lora_linear_layer.down",
|
||||
".to_k_lora.up": ".k_proj.lora_linear_layer.up",
|
||||
".to_k_lora.down": ".k_proj.lora_linear_layer.down",
|
||||
".to_v_lora.up": ".v_proj.lora_linear_layer.up",
|
||||
".to_v_lora.down": ".v_proj.lora_linear_layer.down",
|
||||
".to_out_lora.up": ".out_proj.lora_linear_layer.up",
|
||||
".to_out_lora.down": ".out_proj.lora_linear_layer.down",
|
||||
}
|
||||
|
||||
PEFT_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
|
||||
StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
|
||||
}
|
||||
|
||||
DIFFUSERS_STATE_DICT_MAPPINGS = {
|
||||
StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
|
||||
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
|
||||
}
|
||||
|
||||
|
||||
def convert_state_dict(state_dict, mapping):
|
||||
r"""
|
||||
Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
mapping (`dict[str, str]`):
|
||||
The mapping to use for conversion, the mapping should be a dictionary with the following structure:
|
||||
- key: the pattern to replace
|
||||
- value: the pattern to replace with
|
||||
|
||||
Returns:
|
||||
converted_state_dict (`dict`)
|
||||
The converted state dict.
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if any(pattern in k for pattern in mapping.keys()):
|
||||
for old, new in mapping.items():
|
||||
k = k.replace(old, new)
|
||||
|
||||
converted_state_dict[k] = v
|
||||
return converted_state_dict
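A tiny illustration of the key renaming `convert_state_dict` performs, using the `DIFFUSERS_OLD_TO_PEFT` mapping defined above; the import assumes the new `diffusers.utils.state_dict_utils` module introduced by this file, and the tensor value is a dummy.

```python
import torch

from diffusers.utils.state_dict_utils import DIFFUSERS_OLD_TO_PEFT, convert_state_dict  # new module in this PR

old_style = {
    "text_model.encoder.layers.11.self_attn.to_q_lora.up.weight": torch.zeros(32, 4),  # dummy tensor
}
converted = convert_state_dict(old_style, DIFFUSERS_OLD_TO_PEFT)
# The ".to_q_lora.up" pattern is rewritten to ".q_proj.lora_B":
assert list(converted) == ["text_model.encoder.layers.11.self_attn.q_proj.lora_B.weight"]
```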
|
||||
|
||||
|
||||
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to the PEFT format. The state dict can be from the previous diffusers format (`DIFFUSERS_OLD`) or
the new diffusers format (`DIFFUSERS`). The method currently only supports converting from diffusers old/new to PEFT.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
"""
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in PEFT_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = PEFT_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
|
||||
|
||||
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
|
||||
r"""
|
||||
Converts a state dict to the new diffusers format. The state dict can be from the previous diffusers format
(`DIFFUSERS_OLD`), the PEFT format (`PEFT`), or the new diffusers format (`DIFFUSERS`). In the last case the method
returns the state dict as is.

The method currently only supports converting from the old diffusers format or the PEFT format to the new diffusers format.
|
||||
|
||||
Args:
|
||||
state_dict (`dict[str, torch.Tensor]`):
|
||||
The state dict to convert.
|
||||
original_type (`StateDictType`, *optional*):
|
||||
The original type of the state dict, if not provided, the method will try to infer it automatically.
|
||||
kwargs (`dict`, *args*):
|
||||
Additional arguments to pass to the method.
|
||||
|
||||
- **adapter_name**: in the PEFT case, for example, some keys are prefixed with the adapter name and therefore
  need special handling. By default PEFT also takes care of this in its `get_peft_model_state_dict` method:
  https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
  but it is handled here as well in case we do not want to rely on that method.
|
||||
"""
|
||||
peft_adapter_name = kwargs.pop("adapter_name", "")
|
||||
peft_adapter_name = "." + peft_adapter_name
|
||||
|
||||
if original_type is None:
|
||||
# Old diffusers to PEFT
|
||||
if any("to_out_lora" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.DIFFUSERS_OLD
|
||||
elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
|
||||
original_type = StateDictType.PEFT
|
||||
elif any("lora_linear_layer" in k for k in state_dict.keys()):
|
||||
# nothing to do
|
||||
return state_dict
|
||||
else:
|
||||
raise ValueError("Could not automatically infer state dict type")
|
||||
|
||||
if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
|
||||
raise ValueError(f"Original type {original_type} is not supported")
|
||||
|
||||
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
|
||||
return convert_state_dict(state_dict, mapping)
|
||||
tests/lora/test_lora_layers_peft.py (new file, 147 lines)
@@ -0,0 +1,147 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import (
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor
|
||||
|
||||
|
||||
def create_unet_lora_layers(unet: nn.Module):
|
||||
lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
|
||||
return lora_attn_procs, unet_lora_layers
|
||||
|
||||
|
||||
class LoraLoaderMixinTests(unittest.TestCase):
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
|
||||
|
||||
pipeline_components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
lora_components = {
|
||||
"unet_lora_layers": unet_lora_layers,
|
||||
"unet_lora_attn_procs": unet_lora_attn_procs,
|
||||
}
|
||||
return pipeline_components, lora_components
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 10
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
# copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
|
||||
def get_dummy_tokens(self):
|
||||
max_seq_length = 77
|
||||
|
||||
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
|
||||
|
||||
prepared_inputs = {}
|
||||
prepared_inputs["input_ids"] = inputs
|
||||
return prepared_inputs
|
||||