Mirror of https://github.com/huggingface/diffusers.git

Compare commits: add-uv-scr ... unload-sin (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | cf6aafbedf |  |
@@ -177,11 +177,15 @@ def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = Non
     set_adapter_layers(text_encoder, enabled=True)
 
 
-def _remove_text_encoder_monkey_patch(text_encoder):
-    recurse_remove_peft_layers(text_encoder)
-    if getattr(text_encoder, "peft_config", None) is not None:
-        del text_encoder.peft_config
-        text_encoder._hf_peft_config_loaded = None
+def _remove_text_encoder_monkey_patch(text_encoder, adapter_names=None):
+    if adapter_names is None:
+        recurse_remove_peft_layers(text_encoder)
+        if getattr(text_encoder, "peft_config", None) is not None:
+            del text_encoder.peft_config
+            text_encoder._hf_peft_config_loaded = None
+    else:
+        for adapter_name in adapter_names:
+            delete_adapter_layers(text_encoder, adapter_name)
 
 
 def _fetch_state_dict(
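As a rough illustration of the two branches above (the toy module, adapter names, and direct `peft` calls are assumptions for illustration, not part of the diff), the difference is between deleting one named adapter and stripping every PEFT layer:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model

from diffusers.utils import delete_adapter_layers, recurse_remove_peft_layers

# Toy stand-in for the text encoder with two LoRA adapters injected via PEFT.
model = torch.nn.Sequential(torch.nn.Linear(16, 16))
model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["0"]), model, adapter_name="style")
model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["0"]), model, adapter_name="character")

delete_adapter_layers(model, "style")  # per-adapter path: only "style" is dropped, "character" stays
recurse_remove_peft_layers(model)      # adapter_names=None path: strip every remaining PEFT layer
```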
@@ -361,10 +365,12 @@ class LoraBaseMixin:
         deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
         return _best_guess_weight_name(*args, **kwargs)
 
-    def unload_lora_weights(self):
+    def unload_lora_weights(self, adapter_names: Optional[Union[List[str], str]] = None):
         """
         Unloads the LoRA parameters.
 
+        TODO: args
+
         Examples:
 
         ```python
@@ -376,13 +382,16 @@ class LoraBaseMixin:
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        if adapter_names is not None:
+            self._raise_error_for_missing_adapters(adapter_names=adapter_names)
+
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
                 if issubclass(model.__class__, ModelMixin):
-                    model.unload_lora()
+                    model.unload_lora(adapter_names=adapter_names)
                 elif issubclass(model.__class__, PreTrainedModel):
-                    _remove_text_encoder_monkey_patch(model)
+                    _remove_text_encoder_monkey_patch(model, adapter_names=adapter_names)  # TODO: adapter_names
 
     def fuse_lora(
         self,
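A hedged usage sketch of the new keyword at the pipeline level (the checkpoint id and LoRA paths are placeholders; the per-adapter call only works on this branch):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/style_lora", adapter_name="style")
pipe.load_lora_weights("path/to/character_lora", adapter_name="character")

pipe.unload_lora_weights(adapter_names="style")  # new: unload only the "style" adapter
pipe.unload_lora_weights()                       # unchanged: unload everything that is left
```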
@@ -539,14 +548,9 @@ class LoraBaseMixin:
                     f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
                 )
 
+        self._raise_error_for_missing_adapters(adapter_names=adapter_names)
         list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
-        # eg ["adapter1", "adapter2"]
         all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
-        missing_adapters = set(adapter_names) - all_adapters
-        if len(missing_adapters) > 0:
-            raise ValueError(
-                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
-            )
 
         # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
         invert_list_adapters = {
@@ -689,6 +693,17 @@ class LoraBaseMixin:
 
         return set_adapters
 
+    def _raise_error_for_missing_adapters(self, adapter_names):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        # eg ["adapter1", "adapter2"]
+        all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
+        missing_adapters = set(adapter_names) - all_adapters
+        if len(missing_adapters) > 0:
+            raise ValueError(
+                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
+            )
+
     def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
         """
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
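A standalone rerun of the check this helper centralizes, on made-up adapter names, showing the failure mode it guards against before any unloading or fusing happens:

```python
# Same set arithmetic as _raise_error_for_missing_adapters, on hypothetical inputs.
list_adapters = {"unet": ["style", "character"], "text_encoder": ["style"]}
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}

missing_adapters = {"background"} - all_adapters
if len(missing_adapters) > 0:
    raise ValueError(
        f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
    )
# Raises: ValueError: Adapter name(s) {'background'} not in the list of present adapters: ...
```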
@@ -664,15 +664,28 @@ class PeftAdapterMixin:
             if isinstance(module, BaseTunerLayer):
                 module.unmerge()
 
-    def unload_lora(self):
+    def unload_lora(self, adapter_names: Optional[Union[List[str], str]] = None):
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for `unload_lora()`.")
 
-        from ..utils import recurse_remove_peft_layers
+        if adapter_names is None:
+            from ..utils import recurse_remove_peft_layers
 
-        recurse_remove_peft_layers(self)
-        if hasattr(self, "peft_config"):
-            del self.peft_config
+            recurse_remove_peft_layers(self)
+            if hasattr(self, "peft_config"):
+                del self.peft_config
+        else:
+            # We cannot completely unload a particular adapter, so we temporarily deactivate it.
+            # See more details in https://github.com/huggingface/diffusers/issues/9325#issuecomment-2535510486
+            if isinstance(adapter_names, str):
+                adapter_names = [adapter_names]
+            present_adapters = getattr(self, "peft_config", {})
+            for adapter_name in adapter_names:
+                if adapter_name not in present_adapters:
+                    raise ValueError(
+                        f"{adapter_name} not found in the current list of adapters: {set(present_adapters.keys())}."
+                    )
+            self.delete_adapters(adapter_names=adapter_names)
 
     def disable_lora(self):
         """
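At the model level the same keyword becomes available; a sketch assuming this branch (the checkpoint id and LoRA path are placeholders, and `load_lora_adapter` is the existing `PeftAdapterMixin` loader):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.load_lora_adapter("path/to/style_lora", adapter_name="style")

unet.unload_lora(adapter_names=["style"])  # delegates to delete_adapters(["style"])
unet.unload_lora()                         # strips every remaining PEFT layer
```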
@@ -54,6 +54,8 @@ def recurse_remove_peft_layers(model):
     else:
         # This is for backwards compatibility with PEFT <= 0.6.2.
         # TODO can be removed once that PEFT version is no longer supported.
+        # If we drop v0.6.0 PEFT support, we could consider a much cleaner code path
+        # as noted in https://github.com/huggingface/diffusers/issues/9325#issuecomment-2535510486.
         from peft.tuners.lora import LoraLayer
 
         for name, module in model.named_children():