mirror of https://github.com/huggingface/diffusers.git
synced 2026-03-25 09:58:17 +08:00
Compare commits
7 Commits
fix-torcha ... cuda-devic
| Author | SHA1 | Date |
|---|---|---|
| | e4f83d1046 | |
| | 7bbd96da5d | |
| | 62777fa819 | |
| | f1fd515257 | |
| | afdda57f61 | |
| | 5fc2bd2c8f | |
| | 6350a7690a | |
@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

@@ -35,54 +35,6 @@ if is_accelerate_available():
logger = get_logger(__name__)  # pylint: disable=invalid-name


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    if not is_torchao_available():
        return False
    from torchao.utils import TorchAOBaseTensor

    return isinstance(tensor, TorchAOBaseTensor)


def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
    """Get names of all internal tensor data attributes from a TorchAO tensor."""
    cls = type(tensor)
    names = list(getattr(cls, "tensor_data_names", []))
    for attr_name in getattr(cls, "optional_tensor_data_names", []):
        if getattr(tensor, attr_name, None) is not None:
            names.append(attr_name)
    return names


def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Move a TorchAO parameter to the device of `source` via `swap_tensors`.

    `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
    the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
    original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
    that any dict keyed by `id(param)` remains valid.

    Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
    """
    torch.utils.swap_tensors(param, source)


def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.

    Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
    modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
    `cpu_param_dict`).
    """
    for attr_name in _get_torchao_inner_tensor_names(source):
        setattr(param, attr_name, getattr(source, attr_name))


def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Record stream for all internal tensors of a TorchAO parameter."""
    for attr_name in _get_torchao_inner_tensor_names(param):
        getattr(param, attr_name).record_stream(stream)


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"

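The `_swap_torchao_tensor` docstring above hinges on `torch.utils.swap_tensors` keeping the parameter's Python identity while exchanging its entire contents. A minimal sketch with plain tensors (illustrative only, not part of the diff) shows that behavior:

import torch

a = torch.zeros(4)
b = torch.ones(4)
key = id(a)                       # identity-keyed lookups (like cpu_param_dict) rely on this staying stable
torch.utils.swap_tensors(a, b)    # swaps the full contents of the two tensors in-place
assert id(a) == key               # same Python object after the swap
assert torch.equal(a, torch.ones(4)) and torch.equal(b, torch.zeros(4))
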
@@ -205,16 +157,9 @@ class ModuleGroup:
        pinned_dict = None

    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
        moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
        if _is_torchao_tensor(tensor):
            _swap_torchao_tensor(tensor, moved)
        else:
            tensor.data = moved
        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
        if self.record_stream:
            if _is_torchao_tensor(tensor):
                _record_stream_torchao_tensor(tensor, default_stream)
            else:
                tensor.data.record_stream(default_stream)
            tensor.data.record_stream(default_stream)

    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
        for group_module in self.modules:
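The `record_stream` branch above follows the usual pattern for asynchronous host-to-device copies: issue the copy with `non_blocking=True` from pinned memory on a side stream, then tell the caching allocator that the result is still in use on the consuming stream. A minimal sketch (assumes a CUDA device; not the diffusers implementation):

import torch

pinned = torch.empty(1024, pin_memory=True)        # pinned CPU memory enables truly async copies
copy_stream = torch.cuda.Stream()
default_stream = torch.cuda.current_stream()

with torch.cuda.stream(copy_stream):
    on_gpu = pinned.to("cuda", non_blocking=True)  # copy issued on the side stream

default_stream.wait_stream(copy_stream)            # consumer waits for the copy to finish
on_gpu.record_stream(default_stream)               # allocator: this memory is also used on default_stream
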
@@ -300,35 +245,18 @@ class ModuleGroup:

            for group_module in self.modules:
                for param in group_module.parameters():
                    if _is_torchao_tensor(param):
                        _restore_torchao_tensor(param, self.cpu_param_dict[param])
                    else:
                        param.data = self.cpu_param_dict[param]
            for param in self.parameters:
                if _is_torchao_tensor(param):
                    _restore_torchao_tensor(param, self.cpu_param_dict[param])
                else:
                    param.data = self.cpu_param_dict[param]
            for param in self.parameters:
                param.data = self.cpu_param_dict[param]
            for buffer in self.buffers:
                if _is_torchao_tensor(buffer):
                    _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
                else:
                    buffer.data = self.cpu_param_dict[buffer]
                buffer.data = self.cpu_param_dict[buffer]
        else:
            for group_module in self.modules:
                group_module.to(self.offload_device, non_blocking=False)
            for param in self.parameters:
                if _is_torchao_tensor(param):
                    moved = param.data.to(self.offload_device, non_blocking=False)
                    _swap_torchao_tensor(param, moved)
                else:
                    param.data = param.data.to(self.offload_device, non_blocking=False)
                param.data = param.data.to(self.offload_device, non_blocking=False)
            for buffer in self.buffers:
                if _is_torchao_tensor(buffer):
                    moved = buffer.data.to(self.offload_device, non_blocking=False)
                    _swap_torchao_tensor(buffer, moved)
                else:
                    buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)

    @torch.compiler.disable()
    def onload_(self):

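The `else` branch above moves every parameter and buffer to the offload device by reassigning `.data`. A self-contained sketch of that pattern (illustrative helper, not the diffusers implementation):

import torch
import torch.nn as nn

def offload_module(module: nn.Module, device: torch.device) -> None:
    # Reassigning .data moves the storage while keeping the Parameter/buffer objects intact.
    for param in module.parameters():
        param.data = param.data.to(device, non_blocking=False)
    for buffer in module.buffers():
        buffer.data = buffer.data.to(device, non_blocking=False)

layer = nn.Linear(8, 8)
offload_module(layer, torch.device("cpu"))  # same pattern as the offload_device branch above
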
@@ -87,7 +87,14 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
            t in str(x.dtype) for t in ("float4_", "float8_")
        )
        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
            x.dtype
        )

        return normalized * self.scale * self.gamma + self.bias


class HunyuanImageRefinerAttnBlock(nn.Module):

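The same upcast-normalize-downcast change is applied to the HunyuanVideo15, QwenImage, and Wan RMS norm layers in the hunks that follow. A minimal standalone sketch of the pattern (illustrative shapes, not taken from the diff):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)
# Normalize in fp32 for reduced-precision inputs, then cast back to the input dtype.
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=-1).to(x.dtype)
assert normalized.dtype == torch.bfloat16
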
@@ -87,7 +87,14 @@ class HunyuanVideo15RMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
            t in str(x.dtype) for t in ("float4_", "float8_")
        )
        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
            x.dtype
        )

        return normalized * self.scale * self.gamma + self.bias


class HunyuanVideo15AttnBlock(nn.Module):

@@ -105,7 +105,14 @@ class QwenImageRMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
            t in str(x.dtype) for t in ("float4_", "float8_")
        )
        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
            x.dtype
        )

        return normalized * self.scale * self.gamma + self.bias


class QwenImageUpsample(nn.Upsample):

@@ -196,7 +196,14 @@ class WanRMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
            t in str(x.dtype) for t in ("float4_", "float8_")
        )
        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
            x.dtype
        )

        return normalized * self.scale * self.gamma + self.bias


class WanUpsample(nn.Upsample):

@@ -933,6 +933,7 @@ class QwenImageTransformer2DModel(
            batch_size, image_seq_len = hidden_states.shape[:2]
            image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
            joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
            joint_attention_mask = joint_attention_mask[:, None, None, :]
            block_attention_kwargs["attention_mask"] = joint_attention_mask

        for index_block, block in enumerate(self.transformer_blocks):

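This block builds a joint attention mask over text and image tokens and broadcasts it to (batch, 1, 1, seq) for the transformer blocks. A small shape check with illustrative sizes (not from the diff):

import torch

batch_size, text_len, image_seq_len = 2, 5, 16
encoder_hidden_states_mask = torch.ones(batch_size, text_len, dtype=torch.bool)
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)[:, None, None, :]
assert joint_attention_mask.shape == (2, 1, 1, 21)  # (batch, 1, 1, text_len + image_seq_len)
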
@@ -16,22 +16,29 @@ from typing import Callable

import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
    is_cosmos_guardrail_available,
    is_torch_xla_available,
    is_torchvision_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput


if is_torchvision_available():
    import torchvision.transforms.functional


if is_cosmos_guardrail_available():
    from cosmos_guardrail import CosmosSafetyChecker
else:

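The change makes `torchvision` an optional dependency, imported only when an availability check passes instead of unconditionally at module load. A generic sketch of that pattern (hypothetical helper name, not diffusers' actual utility):

import importlib.util

def _torchvision_installed() -> bool:  # stand-in for an is_*_available() helper
    return importlib.util.find_spec("torchvision") is not None

if _torchvision_installed():
    import torchvision.transforms.functional as TF
else:
    TF = None

def to_tensor(image):
    # Fail only when the optional code path is actually used.
    if TF is None:
        raise ImportError("torchvision is required for this code path.")
    return TF.to_tensor(image)
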
@@ -29,6 +29,7 @@ from numpy.linalg import norm
from packaging import version

from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .deprecation_utils import deprecate
from .import_utils import (
    BACKENDS_MAPPING,
    is_accelerate_available,

@@ -67,9 +68,11 @@ else:
    global_rng = random.Random()

logger = get_logger(__name__)
logger.warning(
    "diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
    "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
deprecate(
    "diffusers.utils.testing_utils",
    "1.0.0",
    "diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
    "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
)
_required_peft_version = is_peft_available() and version.parse(
    version.parse(importlib.metadata.version("peft")).base_version

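At its core, a module-level `deprecate(name, version, message)` call boils down to emitting a `FutureWarning` at import time, which is what the new test further below checks for. A simplified sketch of that behavior (not diffusers' actual implementation, which also handles version gating and keyword-argument deprecation):

import warnings

def deprecate_sketch(name: str, removal_version: str, message: str) -> None:
    # Simplified stand-in for the real helper.
    warnings.warn(message, FutureWarning, stacklevel=2)

deprecate_sketch(
    "diffusers.utils.testing_utils",
    "1.0.0",
    "diffusers.utils.testing_utils is deprecated and will be removed in a future version.",
)
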
@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
        # ensure that enable_lora_hotswap is called before loading the first adapter
        import logging

        from diffusers.utils import logging as diffusers_logging

        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
        msg = (
            "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
        )
        with caplog.at_level(logging.WARNING):
            model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
        assert any(msg in record.message for record in caplog.records)
        diffusers_logging.enable_propagation()
        try:
            with caplog.at_level(logging.WARNING):
                model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
            assert any(msg in record.message for record in caplog.records)
        finally:
            diffusers_logging.disable_propagation()

    def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
        # check possibility to ignore the error/warning
        import logging

        from diffusers.utils import logging as diffusers_logging

        lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)
        model.add_adapter(lora_config)
        with caplog.at_level(logging.WARNING):
            model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
        assert len(caplog.records) == 0
        diffusers_logging.enable_propagation()
        try:
            with caplog.at_level(logging.WARNING):
                model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
            assert len(caplog.records) == 0
        finally:
            diffusers_logging.disable_propagation()

    def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
        # check that wrong argument value raises an error
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
        # check the error and log
        import logging

        from diffusers.utils import logging as diffusers_logging

        # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
        target_modules0 = ["to_q"]
        target_modules1 = ["to_q", "to_k"]
        with pytest.raises(RuntimeError):  # peft raises RuntimeError
            with caplog.at_level(logging.ERROR):
                self._check_model_hotswap(
                    tmp_path,
                    do_compile=True,
                    rank0=8,
                    rank1=8,
                    target_modules0=target_modules0,
                    target_modules1=target_modules1,
                )
        assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
        diffusers_logging.enable_propagation()
        try:
            with pytest.raises(RuntimeError):  # peft raises RuntimeError
                with caplog.at_level(logging.ERROR):
                    self._check_model_hotswap(
                        tmp_path,
                        do_compile=True,
                        rank0=8,
                        rank1=8,
                        target_modules0=target_modules0,
                        target_modules1=target_modules1,
                    )
            assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
        finally:
            diffusers_logging.disable_propagation()

    @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
    @require_torch_version_greater("2.7.1")

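These tests now wrap the `caplog` capture in `enable_propagation()` / `disable_propagation()` because diffusers' library root logger does not propagate records to the root logger by default, so pytest's `caplog` would otherwise see nothing. A minimal hypothetical test showing the pattern:

import logging

from diffusers.utils import logging as diffusers_logging

def test_caplog_sees_diffusers_records(caplog):
    diffusers_logging.enable_propagation()  # let records reach the root logger (and caplog)
    try:
        with caplog.at_level(logging.WARNING):
            diffusers_logging.get_logger("diffusers.some_module").warning("hello from diffusers")
        assert any("hello from diffusers" in record.message for record in caplog.records)
    finally:
        diffusers_logging.disable_propagation()
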
@@ -200,7 +200,6 @@ class ContextParallelTesterMixin:
                f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
            )

    @pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
    def test_context_parallel_batch_inputs(self, cp_type):
        self.test_context_parallel_inference(cp_type, batch_size=2)

@@ -286,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    """LoRA hot-swapping tests for QwenImage Transformer."""

    @pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
    def test_hotswapping_compiled_model_linear(self):
        super().test_hotswapping_compiled_model_linear()

    @pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
    def test_hotswapping_compiled_model_both_linear_and_other(self):
        super().test_hotswapping_compiled_model_both_linear_and_other()

    @property
    def different_shapes_for_compilation(self):
        return [(4, 4), (4, 8), (8, 8)]

@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import unittest
import warnings

import pytest

@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
        assert str(warning.warning) == "This message is better!!!"
        assert "diffusers/tests/others/test_utils.py" in warning.filename

    def test_deprecate_testing_utils_module(self):
        import diffusers.utils.testing_utils

        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")
            importlib.reload(diffusers.utils.testing_utils)

        deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
        assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"

        messages = [str(w.message) for w in deprecation_warnings]
        assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
            f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
        )
        assert any(
            "diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
            for msg in messages
        ), f"Expected deprecation message substring not found, got: {messages}"


# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):

@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
        pipe.set_progress_bar_config(disable=None)

        pipe.to("cpu")
        model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
        model_devices = [
            component.device.type for component in components.values() if getattr(component, "device", None)
        ]
        self.assertTrue(all(device == "cpu" for device in model_devices))

        output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
        self.assertTrue(np.isnan(output_cpu).sum() == 0)

        pipe.to(torch_device)
        model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
        model_devices = [
            component.device.type for component in components.values() if getattr(component, "device", None)
        ]
        self.assertTrue(all(device == torch_device for device in model_devices))

        output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
        model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

        pipe.to(dtype=torch.float16)
        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
        model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

    def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

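The switch from `hasattr` to `getattr(..., None)` skips components whose `device` (or `dtype`) attribute exists but is `None`, which would otherwise fail on the subsequent attribute access. A toy illustration (hypothetical classes, not real pipeline components):

import torch

class WithDevice:
    device = torch.device("cpu")

class DeviceIsNone:
    device = None  # passes hasattr(), but .device.type would raise AttributeError

components = {"model": WithDevice(), "tokenizer": DeviceIsNone()}
model_devices = [c.device.type for c in components.values() if getattr(c, "device", None)]
assert model_devices == ["cpu"]
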
@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):


def fetch_pipeline_objects():
    models = api.list_models(library="diffusers")
    models = api.list_models(filter="diffusers")
    downloads = defaultdict(int)

    for model in models:
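For context, a minimal sketch of querying the Hub this way with `huggingface_hub` (illustrative only; the `limit` and the download handling are assumptions, not part of the script above):

from collections import defaultdict

from huggingface_hub import HfApi

api = HfApi()
downloads = defaultdict(int)
for model in api.list_models(filter="diffusers", limit=10):
    downloads[model.id] += model.downloads or 0  # downloads may be None for some entries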