Compare commits

...

3 Commits

Author     SHA1        Message                                                  Date
sayakpaul  8f5007333b  more compile compatibility.                              2024-09-09 18:33:12 +05:30
sayakpaul  a2c931aaeb  merge main                                               2024-09-09 18:27:16 +05:30
sayakpaul  0e7204abcd  rejig lora state dict to account for compiled modules.   2024-08-27 17:22:24 +05:30
3 changed files with 17 additions and 6 deletions
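The recurring change in this compare view is unwrapping `torch.compile`-wrapped components before doing any LoRA bookkeeping. As a minimal sketch (illustrative only, not code from the PR): `torch.compile()` returns an `OptimizedModule` wrapper that keeps the original module under `_orig_mod`, which is what diffusers' `is_compiled_module` helper detects.

```python
# Minimal sketch of the wrapper these diffs unwrap; illustrative only, not code from the PR.
import torch
import torch.nn as nn

from diffusers.utils.torch_utils import is_compiled_module

model = nn.Linear(4, 4)
compiled = torch.compile(model)  # compilation is lazy; the wrapping happens immediately

print(type(compiled).__name__)      # OptimizedModule
print(compiled._orig_mod is model)  # True: the original module lives under `_orig_mod`

# The one-liner added throughout lora_base.py resolves the underlying module first:
target = compiled._orig_mod if is_compiled_module(compiled) else compiled
```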

View File

@@ -38,6 +38,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.torch_utils import is_compiled_module
 
 
 if is_transformers_available():
@@ -371,6 +372,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.unload_lora()
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -446,6 +448,7 @@ class LoraBaseMixin:
             model = getattr(self, fuse_component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 # check if diffusers model
                 if issubclass(model.__class__, ModelMixin):
                     model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
@@ -506,6 +509,7 @@ class LoraBaseMixin:
             model = getattr(self, fuse_component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                     for module in model.modules():
                         if isinstance(module, BaseTunerLayer):
@@ -569,6 +573,7 @@ class LoraBaseMixin:
                 _component_adapter_weights.setdefault(component, [])
                 _component_adapter_weights[component].append(component_adapter_weights)
 
+            model = model._orig_mod if is_compiled_module(model) else model
             if issubclass(model.__class__, ModelMixin):
                 model.set_adapters(adapter_names, _component_adapter_weights[component])
             elif issubclass(model.__class__, PreTrainedModel):
@@ -581,6 +586,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.disable_lora()
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -593,6 +599,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.enable_lora()
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -614,6 +621,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
             if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
                 if issubclass(model.__class__, ModelMixin):
                     model.delete_adapters(adapter_names)
                 elif issubclass(model.__class__, PreTrainedModel):
@@ -645,6 +653,7 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
+            model = model._orig_mod if is_compiled_module(model) else model
             if model is not None and issubclass(model.__class__, ModelMixin):
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
@@ -666,12 +675,10 @@ class LoraBaseMixin:
         for component in self._lora_loadable_modules:
             model = getattr(self, component, None)
-            if (
-                model is not None
-                and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
-                and hasattr(model, "peft_config")
-            ):
-                set_adapters[component] = list(model.peft_config.keys())
+            if model is not None:
+                model = model._orig_mod if is_compiled_module(model) else model
+                if issubclass(model.__class__, (ModelMixin, PreTrainedModel)) and hasattr(model, "peft_config"):
+                    set_adapters[component] = list(model.peft_config.keys())
 
         return set_adapters
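With every `LoraBaseMixin` entry point above unwrapping compiled components, the usual LoRA workflow is intended to keep working after a component has been passed through `torch.compile()`. A hedged usage sketch, assuming a pipeline class that mixes in one of the LoRA loader mixins; the checkpoint and LoRA ids are placeholders, not taken from this PR:

```python
# Hedged usage sketch: ids are placeholders; assumes a pipeline class with a LoRA loader mixin.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.float16)
pipe.load_lora_weights("some/lora-repo", adapter_name="style")

# Compile a LoRA-loadable component (the attribute name depends on the pipeline:
# `unet` for Stable Diffusion, `transformer` for Flux/SD3).
pipe.unet = torch.compile(pipe.unet)

# These entry points now resolve the module underneath the compiled wrapper first:
pipe.set_adapters(["style"], adapter_weights=[0.8])
pipe.fuse_lora()
pipe.unfuse_lora()
pipe.unload_lora_weights()
```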

View File

@@ -30,6 +30,7 @@ from ..utils import (
     logging,
     scale_lora_layers,
 )
+from ..utils.torch_utils import is_compiled_module
 from .lora_base import LoraBaseMixin
 from .lora_conversion_utils import (
     _convert_kohya_flux_lora_to_diffusers,
@@ -1752,6 +1753,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
 
             inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
+            if is_compiled_module(transformer):
+                state_dict = {"_orig_mod." + k: v for k, v in state_dict.items()}
             incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
 
             if incompatible_keys is not None:
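The key remapping added here mirrors how `torch.compile()` changes parameter names: the wrapper registers the original module as a submodule named `_orig_mod`, so the names that `set_peft_model_state_dict` must match carry that prefix. A small self-contained illustration (not code from the PR):

```python
# Toy illustration of the `_orig_mod.` prefix; illustrative only, not code from the PR.
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
compiled = torch.compile(module)

print(list(module.state_dict()))    # ['weight', 'bias']
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']

# The same remapping the hunk applies to the LoRA state dict before loading it
# into a compiled transformer:
lora_state_dict = {"weight": torch.zeros(4, 4)}
remapped = {"_orig_mod." + k: v for k, v in lora_state_dict.items()}
print(list(remapped))               # ['_orig_mod.weight']
```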

View File

@@ -1744,6 +1744,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         for name, component in pipeline.components.items():
             if name in expected_modules and name not in passed_class_obj:
                 # for model components, we will not switch over if the class does not match the type hint in the new pipeline's signature
+                component = component._orig_mod if is_compiled_module(component) else component
                 if (
                     not isinstance(component, ModelMixin)
                     or type(component) in component_types[name]
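This last hunk touches `DiffusionPipeline.from_pipe()`: the model-type check now runs against the module underneath a compiled wrapper, so components that were wrapped with `torch.compile()` are treated like their uncompiled counterparts when building a new pipeline from an existing one. A hedged sketch (checkpoint id is a placeholder, not from the PR):

```python
# Hedged sketch of the from_pipe() path; the checkpoint id is a placeholder.
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.float16)
pipe.unet = torch.compile(pipe.unet)

# Reuses the already-loaded components instead of reloading them from disk; the type
# check inside from_pipe now sees the unwrapped UNet2DConditionModel rather than the
# OptimizedModule wrapper.
pipe_img2img = StableDiffusionImg2ImgPipeline.from_pipe(pipe)
```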