Compare commits

...

1 Commit

Author     SHA1         Message                               Date
sayakpaul  cf6aafbedf   ability to delete a single adapter.   2024-12-12 17:58:58 +05:30
3 changed files with 49 additions and 19 deletions

View File

@@ -177,11 +177,15 @@ def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = Non
     set_adapter_layers(text_encoder, enabled=True)
 
 
-def _remove_text_encoder_monkey_patch(text_encoder):
-    recurse_remove_peft_layers(text_encoder)
-    if getattr(text_encoder, "peft_config", None) is not None:
-        del text_encoder.peft_config
-        text_encoder._hf_peft_config_loaded = None
+def _remove_text_encoder_monkey_patch(text_encoder, adapter_names=None):
+    if adapter_names is None:
+        recurse_remove_peft_layers(text_encoder)
+        if getattr(text_encoder, "peft_config", None) is not None:
+            del text_encoder.peft_config
+            text_encoder._hf_peft_config_loaded = None
+    else:
+        for adapter_name in adapter_names:
+            delete_adapter_layers(text_encoder, adapter_name)
 
 
 def _fetch_state_dict(
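For context, the new branch chooses between two existing diffusers utilities. A minimal sketch contrasting them (the toy module and adapter names are placeholders, not from this PR):

```python
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model
from diffusers.utils import delete_adapter_layers, recurse_remove_peft_layers

# Toy stand-in for the text encoder, with two LoRA adapters injected.
model = nn.Sequential(nn.Linear(8, 8))
inject_adapter_in_model(LoraConfig(target_modules=["0"]), model, adapter_name="one")
inject_adapter_in_model(LoraConfig(target_modules=["0"]), model, adapter_name="two")

# New selective path: delete only "one"; "two" keeps its LoRA weights.
delete_adapter_layers(model, "one")

# Old catch-all path (adapter_names is None): strip every remaining PEFT layer.
recurse_remove_peft_layers(model)
```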
@@ -361,10 +365,12 @@ class LoraBaseMixin:
         deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
         return _best_guess_weight_name(*args, **kwargs)
 
-    def unload_lora_weights(self):
+    def unload_lora_weights(self, adapter_names: Optional[Union[List[str], str]] = None):
         """
         Unloads the LoRA parameters.
 
+        TODO: args
+
         Examples:
 
         ```python
@@ -376,13 +382,16 @@ class LoraBaseMixin:
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        if adapter_names is not None:
+            self._raise_error_for_missing_adapters(adapter_names=adapter_names)
+
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
                 if issubclass(model.__class__, ModelMixin):
-                    model.unload_lora()
+                    model.unload_lora(adapter_names=adapter_names)
                 elif issubclass(model.__class__, PreTrainedModel):
-                    _remove_text_encoder_monkey_patch(model)
+                    _remove_text_encoder_monkey_patch(model, adapter_names=adapter_names)  # TODO: adapter_names
 
     def fuse_lora(
         self,
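At the pipeline level, the change lets callers target a single adapter. A hedged usage sketch with this PR applied (repo IDs and adapter names are placeholders):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("some-user/style-lora", adapter_name="style")      # placeholder repo
pipe.load_lora_weights("some-user/subject-lora", adapter_name="subject")  # placeholder repo

# New: unload only "style"; "subject" stays loaded and usable.
pipe.unload_lora_weights(adapter_names="style")

# Unchanged default: unload every adapter at once.
pipe.unload_lora_weights()
```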
@@ -539,14 +548,9 @@ class LoraBaseMixin:
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)
self._raise_error_for_missing_adapters(adapter_names=adapter_names)
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
# eg ["adapter1", "adapter2"]
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
missing_adapters = set(adapter_names) - all_adapters
if len(missing_adapters) > 0:
raise ValueError(
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
)
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
invert_list_adapters = {
@@ -689,6 +693,17 @@ class LoraBaseMixin:
 
         return set_adapters
 
+    def _raise_error_for_missing_adapters(self, adapter_names):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+
+        # eg ["adapter1", "adapter2"]
+        all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
+        missing_adapters = set(adapter_names) - all_adapters
+        if len(missing_adapters) > 0:
+            raise ValueError(
+                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
+            )
+
     def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
         """
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
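The new helper centralizes the validation previously inlined in `set_adapters`. Continuing the pipeline sketch above, the failure mode looks roughly like this ("missing" is a made-up name that was never loaded):

```python
try:
    pipe.unload_lora_weights(adapter_names=["missing"])
except ValueError as err:
    # Roughly: Adapter name(s) {'missing'} not in the list of present
    # adapters: {'style', 'subject'}.
    print(err)
```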

View File

@@ -664,15 +664,28 @@ class PeftAdapterMixin:
             if isinstance(module, BaseTunerLayer):
                 module.unmerge()
 
-    def unload_lora(self):
+    def unload_lora(self, adapter_names: Optional[Union[List[str], str]] = None):
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for `unload_lora()`.")
 
-        from ..utils import recurse_remove_peft_layers
+        if adapter_names is None:
+            from ..utils import recurse_remove_peft_layers
 
-        recurse_remove_peft_layers(self)
-        if hasattr(self, "peft_config"):
-            del self.peft_config
+            recurse_remove_peft_layers(self)
+            if hasattr(self, "peft_config"):
+                del self.peft_config
+        else:
+            # We cannot completely unload a particular adapter, so we temporarily deactivate it.
+            # See more details in https://github.com/huggingface/diffusers/issues/9325#issuecomment-2535510486
+            if isinstance(adapter_names, str):
+                adapter_names = [adapter_names]
+            present_adapters = getattr(self, "peft_config", {})
+            for adapter_name in adapter_names:
+                if adapter_name not in present_adapters:
+                    raise ValueError(
+                        f"{adapter_name} not found in the current list of adapters: {set(present_adapters.keys())}."
+                    )
+            self.delete_adapters(adapter_names=adapter_names)
 
     def disable_lora(self):
         """

View File

@@ -54,6 +54,8 @@ def recurse_remove_peft_layers(model):
     else:
         # This is for backwards compatibility with PEFT <= 0.6.2.
         # TODO can be removed once that PEFT version is no longer supported.
+        # If we drop v0.6.0 PEFT support, we could consider a much cleaner code path
+        # as noted in https://github.com/huggingface/diffusers/issues/9325#issuecomment-2535510486.
         from peft.tuners.lora import LoraLayer
 
         for name, module in model.named_children():
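For reference, the "much cleaner code path" the new comment points at could look roughly like this once old PEFT is dropped. This is a hypothetical sketch, not code from the PR, relying on `BaseTunerLayer.get_base_layer()` from newer PEFT:

```python
import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer

def remove_peft_layers(model: nn.Module) -> nn.Module:
    # Recurse depth-first, then swap each tuner-wrapped layer back to the
    # plain layer that newer PEFT exposes via get_base_layer().
    for name, module in model.named_children():
        remove_peft_layers(module)
        if isinstance(module, BaseTunerLayer):
            setattr(model, name, module.get_base_layer())
    return model
```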