Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: 3 commits, ruff-updat ... lora-compi
| Author | SHA1 | Date |
|---|---|---|
|  | 8f5007333b |  |
|  | a2c931aaeb |  |
|  | 0e7204abcd |  |
@@ -38,6 +38,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.torch_utils import is_compiled_module


 if is_transformers_available():
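The new import is the key to the whole change set. As a rough illustration (not taken from this diff), a check like `is_compiled_module` can be implemented by testing for the `OptimizedModule` wrapper that `torch.compile()` returns; `is_compiled_module_sketch` below is a hypothetical stand-in, not the library function.

```python
import torch

def is_compiled_module_sketch(module) -> bool:
    """Hypothetical approximation of a check like diffusers' is_compiled_module."""
    # torch._dynamo only exists on PyTorch 2.x builds, so guard the attribute access.
    if not hasattr(torch, "_dynamo"):
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
```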
@@ -371,6 +372,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.unload_lora()
                 elif issubclass(model.__class__, PreTrainedModel):
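The same one-line unwrap is repeated in the hunks below (fuse_lora, unfuse_lora, set_adapters, disable_lora, enable_lora, delete_adapters, and the adapter-listing helpers). A minimal sketch of why it is needed, using a toy module rather than a diffusers model: `torch.compile()` returns an `OptimizedModule` wrapper whose class is unrelated to the wrapped model's class, so the `issubclass(model.__class__, ...)` dispatch above would miss compiled components; the original module is kept on the wrapper's `_orig_mod` attribute.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
compiled = torch.compile(model)  # wraps the module; compilation itself is deferred to the first call

print(isinstance(compiled, TinyModel))   # expected False: the wrapper is an OptimizedModule
print(compiled._orig_mod is model)       # expected True: the original module lives here

# The pattern added throughout LoraBaseMixin:
unwrapped = compiled._orig_mod if hasattr(compiled, "_orig_mod") else compiled
print(isinstance(unwrapped, TinyModel))  # expected True after unwrapping
```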
@@ -446,6 +448,7 @@ class LoraBaseMixin:

             model = getattr(self, fuse_component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 # check if diffusers model
                 if issubclass(model.__class__, ModelMixin):
                     model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
@@ -506,6 +509,7 @@ class LoraBaseMixin:

             model = getattr(self, fuse_component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                     for module in model.modules():
                         if isinstance(module, BaseTunerLayer):
@@ -569,6 +573,7 @@ class LoraBaseMixin:
                 _component_adapter_weights.setdefault(component, [])
                 _component_adapter_weights[component].append(component_adapter_weights)

+            model = model._orig_mod if is_compiled_module(model) else model
             if issubclass(model.__class__, ModelMixin):
                 model.set_adapters(adapter_names, _component_adapter_weights[component])
             elif issubclass(model.__class__, PreTrainedModel):
@@ -581,6 +586,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.disable_lora()
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -593,6 +599,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.enable_lora()
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -614,6 +621,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.delete_adapters(adapter_names)
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -645,6 +653,7 @@ class LoraBaseMixin:

         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
+            model = model._orig_mod if is_compiled_module(model) else model
             if model is not None and issubclass(model.__class__, ModelMixin):
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
@@ -666,12 +675,10 @@ class LoraBaseMixin:

         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
-            if (
-                model is not None
-                and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
-                and hasattr(model, "peft_config")
-            ):
-                set_adapters[component] = list(model.peft_config.keys())
+            if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
+                if issubclass(model.__class__, (ModelMixin, PreTrainedModel)) and hasattr(model, "peft_config"):
+                    set_adapters[component] = list(model.peft_config.keys())

         return set_adapters
@@ -30,6 +30,7 @@ from ..utils import (
     logging,
     scale_lora_layers,
 )
+from ..utils.torch_utils import is_compiled_module
 from .lora_base import LoraBaseMixin
 from .lora_conversion_utils import (
     _convert_kohya_flux_lora_to_diffusers,
@@ -1752,6 +1753,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

         inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+        if is_compiled_module(transformer):
+            state_dict = {"_orig_mod." + k: v for k, v in state_dict.items()}
         incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

         if incompatible_keys is not None:
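The key remapping above exists because a `torch.compile()` wrapper registers the original model under its `_orig_mod` attribute, so the wrapper's parameter names (the keys `set_peft_model_state_dict` has to match) carry an `_orig_mod.` prefix. A small sketch of that behavior on a toy module, not taken from the diff:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(list(model.state_dict().keys()))     # ['weight', 'bias']
print(list(compiled.state_dict().keys()))  # expected ['_orig_mod.weight', '_orig_mod.bias']

# Loading a checkpoint saved from the plain module into the compiled wrapper
# therefore needs the same prefix the diff adds to the LoRA state dict:
plain_sd = model.state_dict()
remapped = {"_orig_mod." + k: v for k, v in plain_sd.items()}
compiled.load_state_dict(remapped)
```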
@@ -1744,6 +1744,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         for name, component in pipeline.components.items():
             if name in expected_modules and name not in passed_class_obj:
                 # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
+                component = component._orig_mod if is_compiled_module(component) else component
                 if (
                     not isinstance(component, ModelMixin)
                     or type(component) in component_types[name]
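This last hunk unwraps pipeline components before the `isinstance`/`type` comparison, because for a compiled component both checks see the wrapper rather than the model. A toy illustration, with hypothetical class names standing in for diffusers types:

```python
import torch
import torch.nn as nn

class FakeModelMixin(nn.Module):  # hypothetical stand-in for diffusers' ModelMixin
    pass

class FakeUNet(FakeModelMixin):
    def forward(self, x):
        return x

component = torch.compile(FakeUNet())

print(type(component).__name__)               # expected 'OptimizedModule', not 'FakeUNet'
print(isinstance(component, FakeModelMixin))  # expected False on the wrapper

unwrapped = getattr(component, "_orig_mod", component)
print(type(unwrapped).__name__)               # 'FakeUNet'
print(isinstance(unwrapped, FakeModelMixin))  # True
```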