Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-09 22:14:43 +08:00)

Compare commits: memory-opt ... lora-hot-s (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | c738f14cb0 |  |
|  | d3fbd7bbc1 |  |
@@ -63,7 +63,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     text_encoder_name = TEXT_ENCODER_NAME

     def load_lora_weights(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -88,6 +92,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            hotswap (`bool`, *optional*, defaults to `False`): Whether to swap the weights of an already-loaded adapter with the same `adapter_name` in place instead of registering a new adapter.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -109,6 +114,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
             _pipeline=self,
+            hotswap=hotswap,
         )
         self.load_lora_into_text_encoder(
             state_dict,
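
For orientation, a minimal pipeline-level usage sketch of the new `hotswap` flag (not part of the diff; the checkpoint id and LoRA paths are placeholders, and an adapter with the same `adapter_name` must already be loaded before hot-swapping):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# first load registers the adapter as usual
pipe.load_lora_weights("path/to/lora_1.safetensors", adapter_name="default_0")
image_1 = pipe("a prompt", num_inference_steps=20).images[0]

# second load swaps the new weights into the existing adapter in place
pipe.load_lora_weights("path/to/lora_2.safetensors", adapter_name="default_0", hotswap=True)
image_2 = pipe("a prompt", num_inference_steps=20).images[0]
```
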
@@ -232,7 +238,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         return state_dict, network_alphas

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -250,6 +256,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            hotswap (`bool`, *optional*, defaults to `False`): Whether to swap the weights of an already-loaded adapter with the same `adapter_name` in place instead of registering a new adapter.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -263,7 +270,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         # Load the layers corresponding to UNet.
         logger.info(f"Loading {cls.unet_name}.")
         unet.load_attn_procs(
-            state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+            state_dict,
+            network_alphas=network_alphas,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            hotswap=hotswap,
         )

     @classmethod

@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from collections import defaultdict
+import collections
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Callable, Dict, Union
@@ -56,6 +57,56 @@ logger = logging.get_logger(__name__)
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

+def pad_lora_weights(state_dict, target_rank):
+    """
+    Pad LoRA weights in a state dict to a target rank by zero-filling the extra rank dimensions, which preserves
+    the adapter's original behavior.
+
+    Args:
+        state_dict (dict): The state dict containing LoRA weights
+        target_rank (int): The target rank to pad to
+
+    Returns:
+        new_state_dict: A new state dict with padded LoRA weights
+    """
+    new_state_dict = {}
+
+    for key, weight in state_dict.items():
+        if "lora_A" in key or "lora_B" in key:
+            is_conv = weight.dim() == 4
+
+            if "lora_A" in key:
+                original_rank = weight.size(0)
+                if original_rank >= target_rank:
+                    new_state_dict[key] = weight
+                    continue
+
+                if is_conv:
+                    padded = torch.zeros(target_rank, weight.size(1), weight.size(2), weight.size(3),
+                                         device=weight.device, dtype=weight.dtype)
+                    padded[:original_rank, :, :, :] = weight
+                else:
+                    padded = torch.zeros(target_rank, weight.size(1), device=weight.device, dtype=weight.dtype)
+                    padded[:original_rank, :] = weight
+
+            elif "lora_B" in key:
+                original_rank = weight.size(1)
+                if original_rank >= target_rank:
+                    new_state_dict[key] = weight
+                    continue
+
+                if is_conv:
+                    padded = torch.zeros(weight.size(0), target_rank, weight.size(2), weight.size(3),
+                                         device=weight.device, dtype=weight.dtype)
+                    padded[:, :original_rank, :, :] = weight
+                else:
+                    padded = torch.zeros(weight.size(0), target_rank, device=weight.device, dtype=weight.dtype)
+                    padded[:, :original_rank] = weight
+
+            new_state_dict[key] = padded
+        else:
+            new_state_dict[key] = weight
+
+    return new_state_dict

 class UNet2DConditionLoadersMixin:
     """
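
A quick standalone sanity check of `pad_lora_weights` (not part of the diff; the key names are invented and the function above is assumed to be in scope):

```python
import torch

# toy linear-layer LoRA pair with rank 4: lora_A is (r, in_features), lora_B is (out_features, r)
state_dict = {
    "down_blocks.0.attentions.0.proj_in.lora_A.weight": torch.randn(4, 32),
    "down_blocks.0.attentions.0.proj_in.lora_B.weight": torch.randn(64, 4),
}
padded = pad_lora_weights(state_dict, target_rank=8)

assert padded["down_blocks.0.attentions.0.proj_in.lora_A.weight"].shape == (8, 32)
assert padded["down_blocks.0.attentions.0.proj_in.lora_B.weight"].shape == (64, 8)

# zero-padding the extra rank dimensions leaves the effective update lora_B @ lora_A unchanged
old = state_dict["down_blocks.0.attentions.0.proj_in.lora_B.weight"] @ state_dict["down_blocks.0.attentions.0.proj_in.lora_A.weight"]
new = padded["down_blocks.0.attentions.0.proj_in.lora_B.weight"] @ padded["down_blocks.0.attentions.0.proj_in.lora_A.weight"]
assert torch.allclose(old, new)
```
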
@@ -66,7 +117,7 @@ class UNet2DConditionLoadersMixin:
     unet_name = UNET_NAME

     @validate_hf_hub_args
-    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs):
         r"""
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
@@ -115,6 +166,7 @@ class UNet2DConditionLoadersMixin:
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            hotswap (`bool`, *optional*, defaults to `False`): Whether to swap the weights of an already-loaded adapter with the same `adapter_name` in place instead of registering a new adapter.

         Example:

@@ -209,6 +261,7 @@ class UNet2DConditionLoadersMixin:
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                hotswap=hotswap,
             )
         else:
             raise ValueError(
@@ -268,7 +321,7 @@ class UNet2DConditionLoadersMixin:

         return attn_processors

-    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         # format. For legacy format no filtering is applied.
@@ -299,23 +352,38 @@ class UNet2DConditionLoadersMixin:
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

         if len(state_dict_to_be_used) > 0:
-            if adapter_name in getattr(self, "peft_config", {}):
+            if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                 raise ValueError(
                     f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
                 )
+            elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
+                raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")
+
+            def get_rank(state_dict):
+                rank = {}
+                for key, val in state_dict.items():
+                    if "lora_B" in key:
+                        rank[key] = val.shape[1]
+                return rank
+
+            def get_r(rank_dict):
+                r = list(rank_dict.values())[0]
+                if len(set(rank_dict.values())) > 1:
+                    # get the rank occurring most often
+                    r = collections.Counter(rank_dict.values()).most_common()[0][0]
+                return r

             state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
+            r = get_r(get_rank(state_dict))
+
+            state_dict = pad_lora_weights(state_dict, 128)

             if network_alphas is not None:
                 # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
                 # `convert_unet_state_dict_to_peft` method.
                 network_alphas = convert_unet_state_dict_to_peft(network_alphas)

-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
+            rank = get_rank(state_dict)
             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
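
To illustrate the two helpers above (standalone sketch; key names are invented and `get_rank` / `get_r` are treated as top-level functions here): `lora_B` matrices have shape `(out_features, r)`, so the rank is read from dimension 1, and when a state dict mixes ranks the most common one wins. The subsequent `pad_lora_weights(state_dict, 128)` call then appears to be what keeps every adapter at one fixed rank, so hot-swapping adapters of different ranks does not change tensor shapes.

```python
import torch

rank_dict = get_rank({
    "layer1.lora_B.weight": torch.randn(64, 4),
    "layer2.lora_B.weight": torch.randn(64, 8),
    "layer2.lora_A.weight": torch.randn(8, 32),  # ignored: not a lora_B key
    "layer3.lora_B.weight": torch.randn(64, 8),
})
assert rank_dict == {"layer1.lora_B.weight": 4, "layer2.lora_B.weight": 8, "layer3.lora_B.weight": 8}

# mixed ranks: get_r picks the most frequent rank, 8 in this case
assert get_r(rank_dict) == 8
```
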
@@ -336,8 +404,128 @@ class UNet2DConditionLoadersMixin:
             # otherwise loading LoRA weights will lead to an error
             is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

-            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+            def _check_hotswap_configs_compatible(config0, config1):
+                # To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be wrong. E.g. if they
+                # use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the
+                # weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
+                # values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.
+
+                # TODO: This is a very rough check at the moment and there are probably better ways than to error out
+                config_keys_to_check = ["use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
+                config0 = config0.to_dict()
+                config1 = config1.to_dict()
+                for key in config_keys_to_check:
+                    val0 = config0[key]
+                    val1 = config1[key]
+                    if val0 != val1:
+                        raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
+
+            def _update_scaling(model, adapter_name, scaling_factor=None):
+                target_modules = model.peft_config[adapter_name].target_modules
+                for name, lora_module in model.named_modules():
+                    if name in target_modules and hasattr(lora_module, "scaling"):
+                        if not isinstance(lora_module.scaling[adapter_name], torch.Tensor):
+                            lora_module.scaling[adapter_name] = torch.tensor(scaling_factor, device=lora_module.weight.device)
+                        else:
+                            lora_module.scaling[adapter_name].fill_(scaling_factor)
+
+            def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
+                """
+                Swap out the LoRA weights from the model with the weights from state_dict.
+
+                It is assumed that the existing adapter and the new adapter are compatible.
+
+                Args:
+                    model: nn.Module
+                        The model with the loaded adapter.
+                    state_dict: dict[str, torch.Tensor]
+                        The state dict of the new adapter, which needs to be compatible (targeting same modules etc.).
+                    adapter_name: Optional[str]
+                        The name of the adapter that should be hot-swapped.
+
+                Raises:
+                    RuntimeError
+                        If the old and the new adapter are not compatible, a RuntimeError is raised.
+                """
+                from operator import attrgetter
+
+                #######################
+                # INSERT ADAPTER NAME #
+                #######################
+
+                remapped_state_dict = {}
+                expected_str = adapter_name + "."
+                for key, val in state_dict.items():
+                    if expected_str not in key:
+                        prefix, _, suffix = key.rpartition(".")
+                        key = f"{prefix}.{adapter_name}.{suffix}"
+                    remapped_state_dict[key] = val
+                state_dict = remapped_state_dict
+
+                ####################
+                # CHECK STATE_DICT #
+                ####################
+
+                # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
+                # hot-swapping is not possible
+                parameter_prefix = "lora_"  # hard-coded for now
+                is_compiled = hasattr(model, "_orig_mod")
+                # TODO: there is probably a more precise way to identify the adapter keys
+                missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
+                unexpected_keys = set()
+
+                # first: dry run, not swapping anything
+                for key, new_val in state_dict.items():
+                    try:
+                        old_val = attrgetter(key)(model)
+                    except AttributeError:
+                        unexpected_keys.add(key)
+                        continue
+
+                    if is_compiled:
+                        missing_keys.remove("_orig_mod." + key)
+                    else:
+                        missing_keys.remove(key)
+
+                if missing_keys or unexpected_keys:
+                    msg = "Hot swapping the adapter did not succeed."
+                    if missing_keys:
+                        msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
+                    if unexpected_keys:
+                        msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
+                    raise RuntimeError(msg)
+
+                ###################
+                # ACTUAL SWAPPING #
+                ###################
+
+                for key, new_val in state_dict.items():
+                    # no need to account for potential _orig_mod in key here, as torch handles that
+                    old_val = attrgetter(key)(model)
+                    # print(f"  dtype: {old_val.data.dtype}/{new_val.data.dtype}, layout: {old_val.data.layout}/{new_val.data.layout}")
+                    old_val.data.copy_(new_val.data.to(device=old_val.device))
+                    # TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter
+                    # torch.utils.swap_tensors(old_val.data, new_val.data)
+
+            if hotswap:
+                _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+                self.peft_config[adapter_name] = lora_config
+                # update r & scaling
+                self.peft_config[adapter_name].r = r
+                new_scaling_factor = self.peft_config[adapter_name].lora_alpha / self.peft_config[adapter_name].r
+                _update_scaling(self, adapter_name, new_scaling_factor)
+
+                _hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
+                # the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None
+                incompatible_keys = None
+            else:
+                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
+                incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+                # update r & scaling
+                self.peft_config[adapter_name].r = r
+                new_scaling_factor = self.peft_config[adapter_name].lora_alpha / r
+                _update_scaling(self, adapter_name, new_scaling_factor)

             if incompatible_keys is not None:
                 # check only for unexpected keys

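
The "INSERT ADAPTER NAME" step above rewrites plain converted keys into PEFT's per-adapter key layout. A small self-contained illustration of the `rpartition` remapping (the key name is invented):

```python
key = "down_blocks.0.attentions.0.proj_in.lora_A.weight"
adapter_name = "default_0"

prefix, _, suffix = key.rpartition(".")
print(f"{prefix}.{adapter_name}.{suffix}")
# down_blocks.0.attentions.0.proj_in.lora_A.default_0.weight
```

Relatedly, `_update_scaling` writes the new scale into a tensor with `fill_` instead of rebinding a Python float; this looks like it is meant to let a compiled model pick up the new value without recompiling, since a tensor's contents can change without invalidating compilation guards.
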
@@ -18,6 +18,7 @@ import json
 import os
 import random
 import shutil
+import subprocess
 import sys
 import tempfile
 import traceback
@@ -2014,3 +2015,40 @@ class PipelineNightlyTests(unittest.TestCase):

         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_images - ddim_images).max() < 1e-1
+
+
+class TestLoraHotSwapping:
+    def test_hotswapping_peft_config_incompatible_raises(self):
+        # TODO
+        pass
+
+    def test_hotswapping_no_existing_adapter_raises(self):
+        # TODO
+        pass
+
+    def test_hotswapping_works(self):
+        # TODO
+        pass
+
+    def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
+        # TODO: kinda slow, should it get a slow marker?
+        # extend (rather than replace) the current environment so the subprocess keeps PATH, caches, etc.
+        env = {**os.environ, "TORCH_LOGS": "guards,recompiles"}
+        here = os.path.dirname(__file__)
+        file_name = os.path.join(here, "run_compiled_model_hotswap.py")
+
+        process = subprocess.Popen(
+            [sys.executable, file_name],
+            env=env,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE
+        )
+
+        # communicate() reads the output and error streams, preventing deadlock
+        stdout, stderr = process.communicate()
+        exit_code = process.returncode
+
+        # sanity check:
+        assert exit_code == 0
+
+        # check that the recompilation message is not present
+        assert "__recompiles" not in stderr.decode()
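
The helper script `run_compiled_model_hotswap.py` is referenced by the test but not included in this compare view. A rough sketch of what such a script might contain, inferred from the test's intent (the model id, LoRA paths, and prompt are placeholders; the real script may differ):

```python
# hypothetical contents of run_compiled_model_hotswap.py, not part of this diff
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.unet = torch.compile(pipe.unet)

# first adapter: triggers the initial compilation
pipe.load_lora_weights("path/to/lora_1.safetensors", adapter_name="default_0")
pipe("a test prompt", num_inference_steps=3)

# hot-swap a second adapter in place; with TORCH_LOGS="guards,recompiles" set by the test,
# any "__recompiles" line on stderr would make the assertion in the test fail
pipe.load_lora_weights("path/to/lora_2.safetensors", adapter_name="default_0", hotswap=True)
pipe("a test prompt", num_inference_steps=3)
```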