Compare commits

...

2 Commits

Author SHA1 Message Date
yiyixuxu
c738f14cb0 update scaling dict
add padding draft

update
2024-09-27 17:59:05 +02:00
Benjamin Bossan
d3fbd7bbc1 [WIP][LoRA] Implement hot-swapping of LoRA
This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can
offload existing adapters or they can unload them (i.e. delete them).
However, they cannot "hotswap" adapters yet, i.e. substitute the weights
from one LoRA adapter with the weights of another, without the need to
create a separate LoRA adapter.

Generally, hot-swapping may not appear not super useful but when the
model is compiled, it is necessary to prevent recompilation. See #9279
for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target
exactly the same layers and the "hyper-parameters" of the two adapters
should be identical. For instance, the LoRA alpha has to be the same:
Given that we keep the alpha from the first adapter, the LoRA scaling
would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values
derived from the second adapter's config, but changing the dict will
trigger a guard for recompilation, defeating the main purpose of the
feature.

I also found that compilation flags can have an impact on whether this
works or not. E.g. when passing "reduce-overhead", there will be errors
of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is
problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which
direction to take this. If this PR turns out to be useful, the
hot-swapping functions will be added to PEFT itself and can be imported
here (or there is a separate copy in diffusers to avoid the need for a
min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature,
although we don't necessarily need tests for the hot-swapping
functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other
pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before
putting more time into finalizing it.
2024-09-17 16:09:01 +02:00
3 changed files with 250 additions and 13 deletions

View File

@@ -63,7 +63,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_name = TEXT_ENCODER_NAME
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -88,6 +92,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
hotswap TODO
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -109,6 +114,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
_pipeline=self,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -232,7 +238,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
return state_dict, network_alphas
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -250,6 +256,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
hotswap TODO
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -263,7 +270,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
hotswap=hotswap,
)
@classmethod

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import os
from collections import defaultdict
import collections
from contextlib import nullcontext
from pathlib import Path
from typing import Callable, Dict, Union
@@ -56,6 +57,56 @@ logger = logging.get_logger(__name__)
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
def pad_lora_weights(state_dict, target_rank):
"""
Pad LoRA weights in a state dict to a target rank while preserving the original behavior.
Args:
state_dict (dict): The state dict containing LoRA weights
target_rank (int): The target rank to pad to
Returns:
new_state_dict: A new state dict with padded LoRA weights
"""
new_state_dict = {}
for key, weight in state_dict.items():
if "lora_A" in key or "lora_B" in key:
is_conv = weight.dim() == 4
if "lora_A" in key:
original_rank = weight.size(0)
if original_rank >= target_rank:
new_state_dict[key] = weight
continue
if is_conv:
padded = torch.zeros(target_rank, weight.size(1), weight.size(2), weight.size(3),
device=weight.device, dtype=weight.dtype)
padded[:original_rank, :, :, :] = weight
else:
padded = torch.zeros(target_rank, weight.size(1), device=weight.device, dtype=weight.dtype)
padded[:original_rank, :] = weight
elif "lora_B" in key:
original_rank = weight.size(1)
if original_rank >= target_rank:
new_state_dict[key] = weight
continue
if is_conv:
padded = torch.zeros(weight.size(0), target_rank, weight.size(2), weight.size(3),
device=weight.device, dtype=weight.dtype)
padded[:, :original_rank, :, :] = weight
else:
padded = torch.zeros(weight.size(0), target_rank, device=weight.device, dtype=weight.dtype)
padded[:, :original_rank] = weight
new_state_dict[key] = padded
else:
new_state_dict[key] = weight
return new_state_dict
class UNet2DConditionLoadersMixin:
"""
@@ -66,7 +117,7 @@ class UNet2DConditionLoadersMixin:
unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
@@ -115,6 +166,7 @@ class UNet2DConditionLoadersMixin:
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
hotswap TODO
Example:
@@ -209,6 +261,7 @@ class UNet2DConditionLoadersMixin:
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
hotswap=hotswap,
)
else:
raise ValueError(
@@ -268,7 +321,7 @@ class UNet2DConditionLoadersMixin:
return attn_processors
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
@@ -299,23 +352,38 @@ class UNet2DConditionLoadersMixin:
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")
def get_rank(state_dict):
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
return rank
def get_r(rank_dict):
r = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank occuring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
return r
state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
r = get_r(get_rank(state_dict))
state_dict = pad_lora_weights(state_dict, 128)
if network_alphas is not None:
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
# `convert_unet_state_dict_to_peft` method.
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
rank = get_rank(state_dict)
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
@@ -336,8 +404,128 @@ class UNet2DConditionLoadersMixin:
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
def _check_hotswap_configs_compatible(config0, config1):
# To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be false. E.g. if they
# use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the
# weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
# values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.
# TODO: This is a very rough check at the moment and there are probably better ways than to error out
config_keys_to_check = ["use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
config0 = config0.to_dict()
config1 = config1.to_dict()
for key in config_keys_to_check:
val0 = config0[key]
val1 = config1[key]
if val0 != val1:
raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
def _update_scaling(model, adapter_name, scaling_factor=None):
target_modules = model.peft_config[adapter_name].target_modules
for name, lora_module in model.named_modules():
if name in target_modules and hasattr(lora_module, "scaling"):
if not isinstance(lora_module.scaling[adapter_name], torch.Tensor):
lora_module.scaling[adapter_name] = torch.tensor(scaling_factor, device=lora_module.weight.device)
else:
lora_module.scaling[adapter_name].fill_(scaling_factor)
def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
"""
Swap out the LoRA weights from the model with the weights from state_dict.
It is assumed that the existing adapter and the new adapter are compatible.
Args:
model: nn.Module
The model with the loaded adapter.
state_dict: dict[str, torch.Tensor]
The state dict of the new adapter, which needs to be compatible (targeting same modules etc.).
adapter_name: Optional[str]
The name of the adapter that should be hot-swapped.
Raises:
RuntimeError
If the old and the new adapter are not compatible, a RuntimeError is raised.
"""
from operator import attrgetter
#######################
# INSERT ADAPTER NAME #
#######################
remapped_state_dict = {}
expected_str = adapter_name + "."
for key, val in state_dict.items():
if expected_str not in key:
prefix, _, suffix = key.rpartition(".")
key = f"{prefix}.{adapter_name}.{suffix}"
remapped_state_dict[key] = val
state_dict = remapped_state_dict
####################
# CHECK STATE_DICT #
####################
# Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
# hot-swapping is not possible
parameter_prefix = "lora_" # hard-coded for now
is_compiled = hasattr(model, "_orig_mod")
# TODO: there is probably a more precise way to identify the adapter keys
missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
unexpected_keys = set()
# first: dry run, not swapping anything
for key, new_val in state_dict.items():
try:
old_val = attrgetter(key)(model)
except AttributeError:
unexpected_keys.add(key)
continue
if is_compiled:
missing_keys.remove("_orig_mod." + key)
else:
missing_keys.remove(key)
if missing_keys or unexpected_keys:
msg = "Hot swapping the adapter did not succeed."
if missing_keys:
msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
if unexpected_keys:
msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
raise RuntimeError(msg)
###################
# ACTUAL SWAPPING #
###################
for key, new_val in state_dict.items():
# no need to account for potential _orig_mod in key here, as torch handles that
old_val = attrgetter(key)(model)
# print(f" dtype: {old_val.data.dtype}/{new_val.data.dtype}, layout: {old_val.data.layout}/{new_val.data.layout}")
old_val.data.copy_(new_val.data.to(device=old_val.device))
# TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter
# torch.utils.swap_tensors(old_val.data, new_val.data)
if hotswap:
_check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
self.peft_config[adapter_name] = lora_config
# update r & scaling
self.peft_config[adapter_name].r = r
new_scaling_factor = self.peft_config[adapter_name].lora_alpha/self.peft_config[adapter_name].r
_update_scaling(self, adapter_name, new_scaling_factor)
_hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
# update r & scaling
self.peft_config[adapter_name].r = r
new_scaling_factor = self.peft_config[adapter_name].lora_alpha/r
_update_scaling(self, adapter_name, new_scaling_factor)
if incompatible_keys is not None:
# check only for unexpected keys

View File

@@ -18,6 +18,7 @@ import json
import os
import random
import shutil
import subprocess
import sys
import tempfile
import traceback
@@ -2014,3 +2015,40 @@ class PipelineNightlyTests(unittest.TestCase):
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
class TestLoraHotSwapping:
def test_hotswapping_peft_config_incompatible_raises(self):
# TODO
pass
def test_hotswapping_no_existing_adapter_raises(self):
# TODO
pass
def test_hotswapping_works(self):
# TODO
pass
def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
# TODO: kinda slow, should it get a slow marker?
env = {"TORCH_LOGS": "guards,recompiles"}
here = os.path.dirname(__file__)
file_name = os.path.join(here, "run_compiled_model_hotswap.py")
process = subprocess.Popen(
[sys.executable, file_name],
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
# Communicate will read the output and error streams, preventing deadlock
stdout, stderr = process.communicate()
exit_code = process.returncode
# sanity check:
assert exit_code == 0
# check that the recompilation message is not present
assert "__recompiles" not in stderr.decode()