mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-19 18:15:32 +08:00
Compare commits
1 Commits
device-map
...
apply-lora
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afa4a23c6c |
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -675,20 +676,6 @@ class FluxTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
@@ -785,10 +772,6 @@ class FluxTransformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ LIBRARIES = []
|
||||
for library in LOADABLE_CLASSES:
|
||||
LIBRARIES.append(library)
|
||||
|
||||
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
|
||||
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -467,7 +467,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
pipeline_is_sequentially_offloaded = any(
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
@@ -1186,7 +1187,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
||||
|
||||
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
||||
@@ -1310,7 +1311,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
self.remove_all_hooks()
|
||||
|
||||
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
||||
@@ -2199,21 +2200,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_pipeline_device_mapped(self):
|
||||
# We support passing `device_map="cuda"`, for example. This is helpful, in case
|
||||
# users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
|
||||
# in limited VRAM environments because quantized models often initialize directly on the accelerator.
|
||||
device_map = self.hf_device_map
|
||||
is_device_type_map = False
|
||||
if isinstance(device_map, str):
|
||||
try:
|
||||
torch.device(device_map)
|
||||
is_device_type_map = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
|
||||
|
||||
|
||||
class StableDiffusionMixin:
|
||||
r"""
|
||||
|
||||
@@ -130,6 +130,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
apply_lora_scale,
|
||||
check_peft_version,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
|
||||
@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
|
||||
"""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
@@ -275,6 +276,59 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
|
||||
module.set_scale(adapter_name, get_module_weight(weight, module_name))
|
||||
|
||||
|
||||
def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
|
||||
"""
|
||||
Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
|
||||
|
||||
This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
|
||||
pass, and ensures unscaling happens after, even if an exception occurs.
|
||||
|
||||
Args:
|
||||
kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
|
||||
The name of the keyword argument that contains the LoRA scale. Common values include
|
||||
"joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
|
||||
"""
|
||||
|
||||
def decorator(forward_fn):
|
||||
@functools.wraps(forward_fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
from . import USE_PEFT_BACKEND
|
||||
|
||||
lora_scale = 1.0
|
||||
attention_kwargs = kwargs.get(kwargs_name)
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
kwargs[kwargs_name] = attention_kwargs
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
if (
|
||||
not USE_PEFT_BACKEND
|
||||
and attention_kwargs is not None
|
||||
and attention_kwargs.get("scale", None) is not None
|
||||
):
|
||||
logger.warning(
|
||||
f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# Apply LoRA scaling if using PEFT backend
|
||||
if USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
try:
|
||||
# Execute the forward pass
|
||||
result = forward_fn(self, *args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
# Always unscale, even if forward pass raises an exception
|
||||
if USE_PEFT_BACKEND:
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_peft_version(min_version: str) -> None:
|
||||
r"""
|
||||
Checks if the version of PEFT is compatible.
|
||||
|
||||
Reference in New Issue
Block a user