Compare commits

..

7 Commits

Author SHA1 Message Date
sayakpaul
e4f83d1046 fix to device and to dtype tests. 2026-03-24 17:01:12 +05:30
Dhruv Nair
7bbd96da5d [CI] Update fetching pipelines for latest HF Hub Version (#13322)
update
2026-03-24 16:42:32 +05:30
Dhruv Nair
62777fa819 Fix unguarded torchvision import in Cosmos (#13321)
update
2026-03-24 16:00:24 +05:30
Sayak Paul
f1fd515257 [tests] fix lora logging tests for models. (#13318)
* fix lora logging tests for models.

* make style
2026-03-24 15:48:03 +05:30
Cheung Ka Wai
afdda57f61 Fix the attention mask in ulysses SP for QwenImage (#13278)
* fix mask in SP

* change the modification to qwen specific

* drop xfail since qwen-image mask is fixed

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-24 02:12:50 -07:00
YangKai0616
5fc2bd2c8f Stabilize low-precision custom autoencoder RMS normalization (#13316)
* Stabilize low-precision custom autoencoder RMS normalization

* Add fp8/4

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-24 02:00:05 -07:00
Sayak Paul
6350a7690a [chore] properly deprecate src.diffusers.utils.testing_utils. (#13314)
properly deprecate src.diffusers.utils.testing_utils.
2026-03-24 10:54:35 +05:30
14 changed files with 131 additions and 114 deletions

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,54 +35,6 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
"""
torch.utils.swap_tensors(param, source)
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
`cpu_param_dict`).
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -205,16 +157,9 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -300,35 +245,18 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):

View File

@@ -87,7 +87,14 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):

View File

@@ -87,7 +87,14 @@ class HunyuanVideo15RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
class HunyuanVideo15AttnBlock(nn.Module):

View File

@@ -105,7 +105,14 @@ class QwenImageRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
class QwenImageUpsample(nn.Upsample):

View File

@@ -196,7 +196,14 @@ class WanRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):

View File

@@ -933,6 +933,7 @@ class QwenImageTransformer2DModel(
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):

View File

@@ -16,22 +16,29 @@ from typing import Callable
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
is_cosmos_guardrail_available,
is_torch_xla_available,
is_torchvision_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_torchvision_available():
import torchvision.transforms.functional
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:

View File

@@ -29,6 +29,7 @@ from numpy.linalg import norm
from packaging import version
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .deprecation_utils import deprecate
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -67,9 +68,11 @@ else:
global_rng = random.Random()
logger = get_logger(__name__)
logger.warning(
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
deprecate(
"diffusers.utils.testing_utils",
"1.0.0",
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version

View File

@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
finally:
diffusers_logging.disable_propagation()
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
# check the error and log
import logging
from diffusers.utils import logging as diffusers_logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
diffusers_logging.enable_propagation()
try:
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")

View File

@@ -200,7 +200,6 @@ class ContextParallelTesterMixin:
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)

View File

@@ -286,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_linear(self):
super().test_hotswapping_compiled_model_linear()
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_both_linear_and_other(self):
super().test_hotswapping_compiled_model_both_linear_and_other()
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]

View File

@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import unittest
import warnings
import pytest
@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/others/test_utils.py" in warning.filename
def test_deprecate_testing_utils_module(self):
import diffusers.utils.testing_utils
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
importlib.reload(diffusers.utils.testing_utils)
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
messages = [str(w.message) for w in deprecation_warnings]
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
)
assert any(
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
for msg in messages
), f"Expected deprecation message substring not found, got: {messages}"
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):

View File

@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to(torch_device)
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@ class PipelineTesterMixin:
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

View File

@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
def fetch_pipeline_objects():
models = api.list_models(library="diffusers")
models = api.list_models(filter="diffusers")
downloads = defaultdict(int)
for model in models: