Compare commits

..

1 Commit

Author: sayakpaul
SHA1: afa4a23c6c
Message: feat: implement apply_lora_scale to remove boilerplate.
Date: 2026-01-19 10:04:24 +05:30
4 changed files with 62 additions and 38 deletions

src/diffusers/models/transformers/transformer_flux.py

@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -675,20 +676,6 @@ class FluxTransformer2DModel(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
         hidden_states = self.x_embedder(hidden_states)
@@ -785,10 +772,6 @@ class FluxTransformer2DModel(
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
 
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
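The change above only moves the scale handling into the decorator; callers still pass the LoRA scale the same way. A minimal usage sketch, assuming a Flux checkpoint and a LoRA adapter (both names are placeholders, not part of this diff):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder adapter path

# `scale` is popped by @apply_lora_scale in FluxTransformer2DModel.forward and
# applied to the PEFT layers for this call only.
image = pipe(
    "a photo of an astronaut riding a horse",
    joint_attention_kwargs={"scale": 0.8},
).images[0]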

src/diffusers/pipelines/pipeline_utils.py

@@ -111,7 +111,7 @@ LIBRARIES = []
 for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)
 
-SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
 
 logger = logging.get_logger(__name__)
@@ -467,7 +467,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
+        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1186,7 +1187,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
 
-        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
+        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1310,7 +1311,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
         self.remove_all_hooks()
 
-        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
+        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2199,21 +2200,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 return True
         return False
 
-    def _is_pipeline_device_mapped(self):
-        # We support passing `device_map="cuda"`, for example. This is helpful, in case
-        # users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
-        # in limited VRAM environments because quantized models often initialize directly on the accelerator.
-        device_map = self.hf_device_map
-        is_device_type_map = False
-        if isinstance(device_map, str):
-            try:
-                torch.device(device_map)
-                is_device_type_map = True
-            except RuntimeError:
-                pass
-
-        return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
-
 
 class StableDiffusionMixin:
     r"""

src/diffusers/utils/__init__.py

@@ -130,6 +130,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
+    apply_lora_scale,
     check_peft_version,
     delete_adapter_layers,
     get_adapter_name,

src/diffusers/utils/peft_utils.py

@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
 """
 
 import collections
+import functools
 import importlib
 from typing import Optional
@@ -275,6 +276,59 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
                 module.set_scale(adapter_name, get_module_weight(weight, module_name))
 
 
+def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
+    """
+    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
+
+    This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
+    pass, and ensures unscaling happens after, even if an exception occurs.
+
+    Args:
+        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
+            The name of the keyword argument that contains the LoRA scale. Common values include
+            "joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
+    """
+
+    def decorator(forward_fn):
+        @functools.wraps(forward_fn)
+        def wrapper(self, *args, **kwargs):
+            from . import USE_PEFT_BACKEND
+
+            lora_scale = 1.0
+            attention_kwargs = kwargs.get(kwargs_name)
+            if attention_kwargs is not None:
+                attention_kwargs = attention_kwargs.copy()
+                kwargs[kwargs_name] = attention_kwargs
+                lora_scale = attention_kwargs.pop("scale", 1.0)
+            else:
+                if (
+                    not USE_PEFT_BACKEND
+                    and attention_kwargs is not None
+                    and attention_kwargs.get("scale", None) is not None
+                ):
+                    logger.warning(
+                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
+                    )
+
+            # Apply LoRA scaling if using PEFT backend
+            if USE_PEFT_BACKEND:
+                scale_lora_layers(self, lora_scale)
+
+            try:
+                # Execute the forward pass
+                result = forward_fn(self, *args, **kwargs)
+                return result
+            finally:
+                # Always unscale, even if forward pass raises an exception
+                if USE_PEFT_BACKEND:
+                    unscale_lora_layers(self, lora_scale)
+
+        return wrapper
+
+    return decorator
+
+
 def check_peft_version(min_version: str) -> None:
     r"""
     Checks if the version of PEFT is compatible.
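Beyond the Flux model touched in this diff, the decorator can be reused on any forward that reads a LoRA scale from a kwargs dict. A minimal sketch, assuming `apply_lora_scale` is imported from `diffusers.utils` as exported above; the toy module and the `attention_kwargs` name are illustrative, not part of this change:

import torch

from diffusers.utils import apply_lora_scale


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @apply_lora_scale("attention_kwargs")
    def forward(self, hidden_states, attention_kwargs=None):
        # By the time this body runs, the decorator has popped `scale` from
        # `attention_kwargs` and, when the PEFT backend is available, scaled any
        # LoRA layers on `self`; they are unscaled again after the call returns.
        return self.proj(hidden_states)


model = ToyTransformer()
out = model(torch.randn(2, 8), attention_kwargs={"scale": 0.5})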