mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
1 Commits
torchao-lo
...
unload-sin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf6aafbedf |
@@ -177,11 +177,15 @@ def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = Non
|
||||
set_adapter_layers(text_encoder, enabled=True)
|
||||
|
||||
|
||||
def _remove_text_encoder_monkey_patch(text_encoder):
|
||||
recurse_remove_peft_layers(text_encoder)
|
||||
if getattr(text_encoder, "peft_config", None) is not None:
|
||||
del text_encoder.peft_config
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
def _remove_text_encoder_monkey_patch(text_encoder, adapter_names=None):
|
||||
if adapter_names is None:
|
||||
recurse_remove_peft_layers(text_encoder)
|
||||
if getattr(text_encoder, "peft_config", None) is not None:
|
||||
del text_encoder.peft_config
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
else:
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(text_encoder, adapter_name)
|
||||
|
||||
|
||||
def _fetch_state_dict(
|
||||
@@ -361,10 +365,12 @@ class LoraBaseMixin:
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
|
||||
def unload_lora_weights(self):
|
||||
def unload_lora_weights(self, adapter_names: Optional[Union[List[str], str]] = None):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
TODO: args
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
@@ -376,13 +382,16 @@ class LoraBaseMixin:
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if adapter_names is not None:
|
||||
self._raise_error_for_missing_adapters(adapter_names=adapter_names)
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.unload_lora()
|
||||
model.unload_lora(adapter_names=adapter_names)
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
_remove_text_encoder_monkey_patch(model)
|
||||
_remove_text_encoder_monkey_patch(model, adapter_names=adapter_names) # TODO: adapter_names
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
@@ -539,14 +548,9 @@ class LoraBaseMixin:
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
self._raise_error_for_missing_adapters(adapter_names=adapter_names)
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
invert_list_adapters = {
|
||||
@@ -689,6 +693,17 @@ class LoraBaseMixin:
|
||||
|
||||
return set_adapters
|
||||
|
||||
def _raise_error_for_missing_adapters(self, adapter_names):
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
|
||||
@@ -664,15 +664,28 @@ class PeftAdapterMixin:
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
def unload_lora(self):
|
||||
def unload_lora(self, adapter_names: Optional[Union[List[str], str]] = None):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `unload_lora()`.")
|
||||
|
||||
from ..utils import recurse_remove_peft_layers
|
||||
if adapter_names is None:
|
||||
from ..utils import recurse_remove_peft_layers
|
||||
|
||||
recurse_remove_peft_layers(self)
|
||||
if hasattr(self, "peft_config"):
|
||||
del self.peft_config
|
||||
recurse_remove_peft_layers(self)
|
||||
if hasattr(self, "peft_config"):
|
||||
del self.peft_config
|
||||
else:
|
||||
# We cannot completely unload a particular adapter, so, we temporally deactivate it.
|
||||
# See more details in https://github.com/huggingface/diffusers/issues/9325#issuecomment-2535510486
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = list(adapter_names)
|
||||
present_adapters = getattr(self, "peft_config", {})
|
||||
for adapter_name in adapter_names:
|
||||
if adapter_name not in present_adapters:
|
||||
raise ValueError(
|
||||
f"{adapter_name} not found in the current list of adapters: {set(present_adapters.keys())}."
|
||||
)
|
||||
self.delete_adapters(adapter_names=adapter_names)
|
||||
|
||||
def disable_lora(self):
|
||||
"""
|
||||
|
||||
@@ -54,6 +54,8 @@ def recurse_remove_peft_layers(model):
|
||||
else:
|
||||
# This is for backwards compatibility with PEFT <= 0.6.2.
|
||||
# TODO can be removed once that PEFT version is no longer supported.
|
||||
# If we drop v0.6.0 PEFT support, we could consider a much cleaner code path
|
||||
# as noted in https://github.com/huggingface/diffusers/issues/9325#issuecomment-2535510486.
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
for name, module in model.named_children():
|
||||
|
||||
Reference in New Issue
Block a user