Compare commits

...

3 Commits

Author SHA1 Message Date
DN6
bffa3a9754 update 2025-11-14 15:48:19 +05:30
DN6
1c558712e8 Merge branch 'main' into model-test-refactor 2025-11-12 10:18:07 +05:30
DN6
1f026ad14e update 2025-11-12 10:17:54 +05:30
16 changed files with 3640 additions and 21 deletions

View File

@@ -32,6 +32,20 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
def pytest_addoption(parser):

View File

@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
)
assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
"Model parameters don't match!"
)
assert all(
torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
), "Model parameters don't match!"
# Remove a shard file
cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):
# Verify error mentions the missing shard
error_msg = str(context.exception)
assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
f"Expected error about missing shard, got: {error_msg}"
)
assert (
cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
), f"Expected error about missing shard, got: {error_msg}"
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
)
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 3, (
"3 HEAD requests one for config, one for model, and one for shard index file."
)
assert (
download_requests.count("HEAD") == 3
), "3 HEAD requests one for config, one for model, and one for shard index file."
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
)
cache_requests = [r.method for r in m.request_history]
assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
"We should call only `model_info` to check for commit hash and knowing if shard index is present."
)
assert (
"HEAD" == cache_requests[0] and len(cache_requests) == 2
), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:

View File

@@ -0,0 +1,37 @@
from .attention import AttentionTesterMixin
from .common import ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .quantization import (
BitsAndBytesTesterMixin,
GGUFTesterMixin,
ModelOptTesterMixin,
QuantizationTesterMixin,
QuantoTesterMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BitsAndBytesTesterMixin",
"CPUOffloadTesterMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"QuantizationTesterMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]

View File

@@ -0,0 +1,180 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import is_attention, require_accelerator, torch_device
@is_attention
@require_accelerator
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
"""
base_precision = 1e-3
def test_fuse_unfuse_qkv_projections(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
# Get output before fusion
with torch.no_grad():
output_before_fusion = model(**inputs_dict)
if isinstance(output_before_fusion, dict):
output_before_fusion = output_before_fusion.to_tuple()[0]
# Fuse projections
model.fuse_qkv_projections()
# Verify fusion occurred by checking for fused attributes
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break
if has_fused_projections:
# Get output after fusion
with torch.no_grad():
output_after_fusion = model(**inputs_dict)
if isinstance(output_after_fusion, dict):
output_after_fusion = output_after_fusion.to_tuple()[0]
# Verify outputs match
assert torch.allclose(
output_before_fusion, output_after_fusion, atol=self.base_precision
), "Output should not change after fusing projections"
# Unfuse projections
model.unfuse_qkv_projections()
# Verify unfusion occurred
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
# Get output after unfusion
with torch.no_grad():
output_after_unfusion = model(**inputs_dict)
if isinstance(output_after_unfusion, dict):
output_after_unfusion = output_after_unfusion.to_tuple()[0]
# Verify outputs still match
assert torch.allclose(
output_before_fusion, output_after_unfusion, atol=self.base_precision
), "Output should match original after unfusing projections"
def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")
# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"
# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"
# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
# Set processors using dict
model.set_attn_processor(new_processors)
# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

View File

@@ -0,0 +1,514 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
from typing import Dict, List, Tuple
import pytest
import torch
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import torch_device
def compute_module_persistent_sizes(
model: nn.Module,
dtype: Optional[Union[str, torch.device]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
def calculate_expected_num_shards(index_map_path):
"""
Calculate expected number of shards from index file.
Args:
index_map_path: Path to the sharded checkpoint index file
Returns:
int: Expected number of shards
"""
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
def check_device_map_is_respected(model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(
param_device
), f"Expected device {param_device} for {param_name}, got {param.device}"
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
model_class = None
base_precision = 1e-3
model_split_percents = [0.5, 0.7]
def get_init_dict(self):
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
def get_dummy_inputs(self):
raise NotImplementedError(
"get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict."
)
def test_from_save_pretrained(self, expected_max_diff=5e-5):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
# check if all parameters shape are the same
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert (
param_1.shape == param_2.shape
), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
with torch.no_grad():
image = model(**self.get_dummy_inputs())
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**self.get_dummy_inputs())
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert (
max_diff <= expected_max_diff
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmpdirname)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
with torch.no_grad():
image = model(**self.get_dummy_inputs())
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**self.get_dummy_inputs())
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert (
max_diff <= expected_max_diff
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
def test_from_save_pretrained_dtype(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if (
hasattr(self.model_class, "_keep_in_fp32_modules")
and self.model_class._keep_in_fp32_modules is None
):
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
)
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**self.get_dummy_inputs())
if isinstance(first, dict):
first = first.to_tuple()[0]
second = model(**self.get_dummy_inputs())
if isinstance(second, dict):
second = second.to_tuple()[0]
# Remove NaN values and compute max difference
first_flat = first.flatten()
second_flat = second.flatten()
# Filter out NaN values
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
max_diff = torch.abs(first_filtered - second_filtered).max().item()
assert (
max_diff <= expected_max_diff
), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None, "Model output is None"
assert (
output.shape == expected_output_shape
), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
def test_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
assert torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
), (
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_model_config_to_json_string(self):
model = self.model_class(**self.get_init_dict())
json_string = model.config.to_json_string()
assert isinstance(json_string, str), "Config to_json_string should return a string"
assert len(json_string) > 0, "JSON string should not be empty"
@require_accelerator
@pytest.mark.skipif(torch_device not in ["cuda", "xpu"])
def test_from_save_pretrained_float16_bfloat16(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules
with tempfile.TemporaryDirectory() as tmp_dir:
for torch_dtype in [torch.bfloat16, torch.float16]:
model.to(torch_dtype).save_pretrained(tmp_dir)
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == torch_dtype
with torch.no_grad():
output = model(**get_dummy_inputs())
output_loaded = model_loaded(**get_dummy_inputs())
assert torch.allclose(
output, output_loaded, atol=1e-4
), f"Loaded model output differs for {torch_dtype}"
@require_accelerator
def test_sharded_checkpoints(self):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match after sharded save/load"
@require_accelerator
def test_sharded_checkpoints_with_variant(self):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(
os.path.join(tmp_dir, index_filename)
), f"Variant index file {index_filename} should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match after variant sharded save/load"
@require_accelerator
def test_sharded_checkpoints_with_parallel_loading(self):
import time
from diffusers.utils import constants
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
# Save original values to restore after test
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert (
actual_num_shards == expected_num_shards
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
start_time = time.time()
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
sequential_load_time = time.time() - start_time
model_sequential = model_sequential.to(torch_device)
torch.manual_seed(0)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
start_time = time.time()
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
parallel_load_time = time.time() - start_time
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
assert torch.allclose(
base_output[0], output_parallel[0], atol=1e-5
), "Output should match with parallel loading"
# Verify parallel loading is faster or at least not significantly slower
# For small test models, the difference might be negligible or even slightly slower due to overhead
# so we just check that parallel loading completed successfully and outputs match
assert (
parallel_load_time < sequential_load_time
), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
if original_parallel_workers is not None:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
def test_model_parallelism(self):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with model parallelism"

View File

@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import pytest
import torch
from ...testing_utils import (
backend_empty_cache,
is_torch_compile,
require_accelerator,
require_torch_version_greater,
torch_device,
)
@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: compile
Use `pytest -m "not compile"` to skip these tests
"""
different_shapes_for_compilation = None
def setup_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_torch_compile_recompilation_and_graph_break(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch._dynamo.config.cache_size_limit = 10000
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.eval()
group_offload_kwargs = {
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
"use_stream": True,
"non_blocking": True,
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
def test_compile_works_with_aot(self):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)

View File

@@ -0,0 +1,109 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import uuid
import pytest
import torch
from huggingface_hub.utils import is_jinja_available
from ...others.test_utils import TOKEN, USER, is_staging_test
@is_staging_test
class ModelPushToHubTesterMixin:
"""
Mixin class for testing push_to_hub functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
"""
identifier = uuid.uuid4()
repo_id = f"test-model-{identifier}"
org_repo_id = f"valid_org/{repo_id}-org"
def test_push_to_hub(self):
"""Test pushing model to hub and loading it back."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.repo_id, token=TOKEN)
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"
# Reset repo
delete_repo(token=TOKEN, repo_id=self.repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(
p1, p2
), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
def test_push_to_hub_in_organization(self):
"""Test pushing model to hub in organization namespace."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.org_repo_id, token=TOKEN)
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"
# Reset repo
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
new_model = self.model_class.from_pretrained(self.org_repo_id)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
assert torch.equal(
p1, p2
), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
# Reset repo
delete_repo(self.org_repo_id, token=TOKEN)
def test_push_to_hub_library_name(self):
"""Test that library_name in model card is set to 'diffusers'."""
if not is_jinja_available():
pytest.skip("Model card tests cannot be performed without Jinja installed.")
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.push_to_hub(self.repo_id, token=TOKEN)
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
assert (
model_card.library_name == "diffusers"
), f"Expected library_name 'diffusers', got {model_card.library_name}"
# Reset repo
delete_repo(self.repo_id, token=TOKEN)

View File

@@ -0,0 +1,205 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import torch
from diffusers.models.attention_processor import IPAdapterAttnProcessor
from ...testing_utils import is_ip_adapter, torch_device
def create_ip_adapter_state_dict(model):
"""
Create a dummy IP Adapter state dict for testing.
Args:
model: The model to create IP adapter weights for
Returns:
dict: IP adapter state dict with to_k_ip and to_v_ip weights
"""
ip_state_dict = {}
key_id = 1
for name in model.attn_processors.keys():
# Skip self-attention processors
cross_attention_dim = getattr(model.config, "cross_attention_dim", None)
if cross_attention_dim is None:
continue
# Get hidden size based on model architecture
hidden_size = getattr(model.config, "hidden_size", cross_attention_dim)
# Create IP adapter processor to get state dict structure
sd = IPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).state_dict()
ip_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
key_id += 2
return {"ip_adapter": ip_state_dict}
def check_if_ip_adapter_correctly_set(model) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, IPAdapterAttnProcessor):
return True
return False
@is_ip_adapter
class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: ip_adapter
Use `pytest -m "not ip_adapter"` to skip these tests
"""
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
# Create dummy IP adapter state dict
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
# Load IP adapter
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict_with_adapter = inputs_dict.copy()
inputs_dict_with_adapter["image_embeds"] = image_embeds
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(
output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
), "Output should differ with IP Adapter enabled"
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create and load dummy IP adapter state dict
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Test scale = 1.0 (full effect)
model.set_ip_adapter_scale(1.0)
torch.manual_seed(0)
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(
output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
), "Output should differ with different IP Adapter scales"
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
# Create and load IP adapter
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"
def test_ip_adapter_save_load(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict()
model._load_ip_adapter_weights([ip_adapter_state_dict])
torch.manual_seed(0)
output_before_save = model(**inputs_dict, return_dict=False)[0]
with tempfile.TemporaryDirectory() as tmpdir:
# Save the IP adapter weights
save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
import safetensors.torch
safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
# Unload and reload
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
# Reload from saved file
loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
model._load_ip_adapter_weights([loaded_state_dict])
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
torch.manual_seed(0)
output_after_load = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should match before and after save/load
assert torch.allclose(
output_before_save, output_after_load, atol=1e-4, rtol=1e-4
), "Output should match before and after save/load"

View File

@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import pytest
import safetensors.torch
import torch
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import is_lora, require_peft_backend, torch_device
def check_if_lora_correctly_set(model) -> bool:
"""
Check if LoRA layers are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if LoRA is correctly set, False otherwise
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
@is_lora
@require_peft_backend
class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: lora
Use `pytest -m "not lora"` to skip these tests
"""
def setup_method(self):
from diffusers.loaders.peft import PeftAdapterMixin
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(
output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4
), "Output should differ with LoRA enabled"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
assert os.path.isfile(
os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
), "LoRA weights file not created"
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(
output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
), "Output should differ with LoRA enabled"
assert torch.allclose(
outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
), "Outputs should match before and after save/load"
def test_lora_wrong_adapter_name_raises_error(self):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)

View File

@@ -0,0 +1,443 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import glob
import inspect
import tempfile
from functools import wraps
import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes
from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
is_cpu_offload,
is_group_offload,
is_memory,
require_accelerator,
torch_device,
)
from .common import check_device_map_is_respected
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
"""Helper to cast tensor inputs from one dtype to another."""
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
inputs_dict[key] = value.to(to_dtype)
return inputs_dict
def require_offload_support(func):
"""
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
return func(self, *args, **kwargs)
return wrapper
def require_group_offload_support(func):
"""
Decorator to skip tests if model doesn't support group offloading.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
return func(self, *args, **kwargs)
return wrapper
@is_cpu_offload
class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- model_split_percents: List of percentages for splitting model across devices
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cpu_offload
Use `pytest -m "not cpu_offload"` to skip these tests
"""
model_split_percents = [0.5, 0.7]
@require_offload_support
def test_cpu_offload(self):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with CPU offloading"
@require_offload_support
def test_disk_offload_without_safetensors(self):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"
@require_offload_support
def test_disk_offload_with_safetensors(self):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(
base_output[0], new_output[0], atol=1e-5
), "Output should match with disk offloading (safetensors)"
@is_group_offload
class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: group_offload
Use `pytest -m "not group_offload"` to skip these tests
"""
@require_group_offload_support
def test_group_offloading(self, record_stream=False):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
), "Group offloading hook should be set"
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading1, atol=1e-5
), "Output should match with block-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading2, atol=1e-5
), "Output should match with non-blocking block-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading3, atol=1e-5
), "Output should match with leaf-level offloading"
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading4, atol=1e-5
), "Output should match with leaf-level offloading with stream"
@require_group_offload_support
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@require_group_offload_support
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert torch.allclose(
output_without_group_offloading, output_with_group_offloading, atol=atol
), "Output should match with disk-based group offloading"
class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
backend_synchronize(torch_device)
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**config).eval()
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
torch.float8_e4m3fn, torch.bfloat16
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert (
fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
), "Memory footprint should decrease with lower precision storage"
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert (
fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
), "Peak memory should be lower with bf16 compute on newer GPUs"
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
"""
Combined mixin class for all memory optimization tests including CPU/disk offloading,
group offloading, and layerwise dtype casting.
This mixin inherits from:
- CPUOffloadTesterMixin: CPU and disk offloading tests
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""
pass

View File

@@ -0,0 +1,833 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import pytest
import torch
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_nvidia_modelopt_available,
is_optimum_quanto_available,
)
from ...testing_utils import (
backend_empty_cache,
is_bitsandbytes,
is_gguf,
is_modelopt,
is_quanto,
is_torchao,
nightly,
require_accelerate,
require_accelerator,
require_bitsandbytes_version_greater,
require_gguf_version_greater_or_equal,
require_quanto,
require_torchao_version_greater_or_equal,
torch_device,
)
if is_nvidia_modelopt_available():
import modelopt.torch.quantization as mtq
if is_bitsandbytes_available():
import bitsandbytes as bnb
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
@require_accelerator
class QuantizationTesterMixin:
"""
Base mixin class providing common test implementations for quantization testing.
Backend-specific mixins should:
1. Implement _create_quantized_model(config_kwargs)
2. Implement _verify_if_layer_quantized(name, module, config_kwargs)
3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.)
4. Use @pytest.mark.parametrize to create tests that call the common test methods below
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods in test classes:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
"""
Create a quantized model with the given config kwargs.
Args:
config_kwargs: Quantization config parameters
**extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder)
"""
raise NotImplementedError("Subclass must implement _create_quantized_model")
def _verify_if_layer_quantized(self, name, module, config_kwargs):
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
def _is_module_quantized(self, module):
"""
Check if a module is quantized. Returns True if quantized, False otherwise.
Default implementation tries _verify_if_layer_quantized and catches exceptions.
Subclasses can override for more efficient checking.
"""
try:
self._verify_if_layer_quantized("", module, {})
return True
except (AssertionError, AttributeError):
return False
def _load_unquantized_model(self):
kwargs = getattr(self, "pretrained_model_kwargs", {})
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _test_quantization_num_parameters(self, config_kwargs):
model = self._load_unquantized_model()
num_params = model.num_parameters()
model_quantized = self._create_quantized_model(config_kwargs)
num_params_quantized = model_quantized.num_parameters()
assert (
num_params == num_params_quantized
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
model = self._load_unquantized_model()
mem = model.get_memory_footprint()
model_quantized = self._create_quantized_model(config_kwargs)
mem_quantized = model_quantized.get_memory_footprint()
ratio = mem / mem_quantized
assert (
ratio >= expected_memory_reduction
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model_quantized(**inputs)
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
def _test_quantization_dtype_assignment(self, config_kwargs):
model = self._create_quantized_model(config_kwargs)
with pytest.raises(ValueError):
model.to(torch.float16)
with pytest.raises(ValueError):
device_0 = f"{torch_device}:0"
model.to(device=device_0, dtype=torch.float16)
with pytest.raises(ValueError):
model.float()
with pytest.raises(ValueError):
model.half()
model.to(torch_device)
def _test_quantization_lora_inference(self, config_kwargs):
try:
from peft import LoraConfig
except ImportError:
pytest.skip("peft is not available")
from diffusers.loaders.peft import PeftAdapterMixin
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})")
model = self._create_quantized_model(config_kwargs)
lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
)
model.add_adapter(lora_config)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs)
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None with LoRA"
assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
def _test_quantization_serialization(self, config_kwargs):
model = self._create_quantized_model(config_kwargs)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir, safe_serialization=True)
model_loaded = self.model_class.from_pretrained(tmpdir)
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model_loaded(**inputs)
if isinstance(output, tuple):
output = output[0]
assert not torch.isnan(output).any(), "Loaded model output contains NaN"
def _test_quantized_layers(self, config_kwargs):
model_fp = self._load_unquantized_model()
num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear))
model_quantized = self._create_quantized_model(config_kwargs)
num_fp32_modules = 0
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
for name, module in model_quantized.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
num_fp32_modules += 1
expected_quantized_layers = num_linear_layers - num_fp32_modules
num_quantized_layers = 0
for name, module in model_quantized.named_modules():
if isinstance(module, torch.nn.Linear):
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
continue
self._verify_if_layer_quantized(name, module, config_kwargs)
num_quantized_layers += 1
assert (
num_quantized_layers > 0
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
assert (
num_quantized_layers == expected_quantized_layers
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
"""
Test that modules specified in modules_to_not_convert are not quantized.
Args:
config_kwargs: Base quantization config kwargs
modules_to_not_convert: List of module names to exclude from quantization
"""
# Create config with modules_to_not_convert
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
# Find a module that should NOT be quantized
found_excluded = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is in the exclusion list
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(
module
), f"Module {name} should not be quantized but was found to be quantized"
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
# Find a module that SHOULD be quantized (not in exclusion list)
found_quantized = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_not_convert):
if self._is_module_quantized(module):
found_quantized = True
break
assert found_quantized, "No quantized layers found outside of excluded modules"
# Compare memory footprint with fully quantized model
model_fully_quantized = self._create_quantized_model(config_kwargs)
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
assert (
mem_with_exclusion > mem_fully_quantized
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
def _test_quantization_device_map(self, config_kwargs):
"""
Test that quantized models work correctly with device_map="auto".
Args:
config_kwargs: Base quantization config kwargs
"""
model = self._create_quantized_model(config_kwargs, device_map="auto")
# Verify device map is set
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
assert model.hf_device_map is not None, "hf_device_map should not be None"
# Verify inference works
with torch.no_grad():
inputs = self.get_dummy_inputs()
output = model(**inputs)
if isinstance(output, tuple):
output = output[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
@is_bitsandbytes
@nightly
@require_accelerator
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
class BitsAndBytesTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing BitsAndBytes quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test
Pytest mark: bitsandbytes
Use `pytest -m "not bitsandbytes"` to skip these tests
"""
# Standard BnB configs tested for all models
# Subclasses can override to add or modify configs
BNB_CONFIGS = {
"4bit_nf4": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.float16,
},
"4bit_fp4": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "fp4",
"bnb_4bit_compute_dtype": torch.float16,
},
"8bit": {
"load_in_8bit": True,
},
}
BNB_EXPECTED_MEMORY_REDUCTIONS = {
"4bit_nf4": 3.0,
"4bit_fp4": 3.0,
"8bit": 1.5,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = BitsAndBytesConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
assert (
module.weight.__class__ == expected_weight_class
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_num_parameters(self, config_name):
self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_memory_footprint(self, config_name):
expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected)
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_inference(self, config_name):
self._test_quantization_inference(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
def test_bnb_quantization_dtype_assignment(self, config_name):
self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
def test_bnb_quantization_lora_inference(self, config_name):
self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
def test_bnb_quantization_serialization(self, config_name):
self._test_quantization_serialization(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantized_layers(self, config_name):
self._test_quantized_layers(self.BNB_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
def test_bnb_quantization_config_serialization(self, config_name):
model = self._create_quantized_model(self.BNB_CONFIGS[config_name])
assert "quantization_config" in model.config, "Missing quantization_config"
_ = model.config["quantization_config"].to_dict()
_ = model.config["quantization_config"].to_diff_dict()
_ = model.config["quantization_config"].to_json_string()
def test_bnb_original_dtype(self):
config_name = list(self.BNB_CONFIGS.keys())[0]
config_kwargs = self.BNB_CONFIGS[config_name]
model = self._create_quantized_model(config_kwargs)
assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype"
assert model.config["_pre_quantization_dtype"] in [
torch.float16,
torch.float32,
torch.bfloat16,
], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"
def test_bnb_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
config_kwargs = self.BNB_CONFIGS["4bit_nf4"]
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
self.model_class._keep_in_fp32_modules = ["proj_out"]
try:
model = self._create_quantized_model(config_kwargs)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert (
module.weight.dtype == torch.float32
), f"Module {name} should be FP32 but is {module.weight.dtype}"
else:
assert (
module.weight.dtype == torch.uint8
), f"Module {name} should be uint8 but is {module.weight.dtype}"
with torch.no_grad():
inputs = self.get_dummy_inputs()
_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
def test_bnb_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude)
def test_bnb_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
@is_quanto
@nightly
@require_quanto
@require_accelerate
@require_accelerator
class QuantoTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing Quanto quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype
Pytest mark: quanto
Use `pytest -m "not quanto"` to skip these tests
"""
QUANTO_WEIGHT_TYPES = {
"float8": {"weights_dtype": "float8"},
"int8": {"weights_dtype": "int8"},
"int4": {"weights_dtype": "int4"},
"int2": {"weights_dtype": "int2"},
}
QUANTO_EXPECTED_MEMORY_REDUCTIONS = {
"float8": 1.5,
"int8": 1.5,
"int4": 3.0,
"int2": 7.0,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = QuantoConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
def test_quanto_quantization_num_parameters(self, weight_type_name):
self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
def test_quanto_quantization_memory_footprint(self, weight_type_name):
expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
self._test_quantization_memory_footprint(
self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
)
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
def test_quanto_quantization_inference(self, weight_type_name):
self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", ["int8"])
def test_quanto_quantized_layers(self, weight_type_name):
self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", ["int8"])
def test_quanto_quantization_lora_inference(self, weight_type_name):
self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
@pytest.mark.parametrize("weight_type_name", ["int8"])
def test_quanto_quantization_serialization(self, weight_type_name):
self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name])
def test_quanto_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude)
def test_quanto_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"])
@is_torchao
@require_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing TorchAO quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- TORCHAO_QUANT_TYPES: Dict of quantization type strings to test
Pytest mark: torchao
Use `pytest -m "not torchao"` to skip these tests
"""
TORCHAO_QUANT_TYPES = {
"int4wo": {"quant_type": "int4_weight_only"},
"int8wo": {"quant_type": "int8_weight_only"},
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
}
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
"int4wo": 3.0,
"int8wo": 1.5,
"int8dq": 1.5,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = TorchAoConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
def test_torchao_quantization_num_parameters(self, quant_type):
self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
def test_torchao_quantization_memory_footprint(self, quant_type):
expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
self._test_quantization_memory_footprint(
self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
)
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
def test_torchao_quantization_inference(self, quant_type):
self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", ["int8wo"])
def test_torchao_quantized_layers(self, quant_type):
self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", ["int8wo"])
def test_torchao_quantization_lora_inference(self, quant_type):
self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type])
@pytest.mark.parametrize("quant_type", ["int8wo"])
def test_torchao_quantization_serialization(self, quant_type):
self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type])
def test_torchao_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
# Get a module name that exists in the model - this needs to be set by test classes
# For now, use a generic pattern that should work with transformer models
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
)
def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"])
@is_gguf
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing GGUF quantization on models.
Expected class attributes:
- model_class: The model class to test
- gguf_filename: URL or path to the GGUF file
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: gguf
Use `pytest -m "not gguf"` to skip these tests
"""
gguf_filename = None
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
if config_kwargs is None:
config_kwargs = {"compute_dtype": torch.bfloat16}
config = GGUFQuantizationConfig(**config_kwargs)
kwargs = {
"quantization_config": config,
"torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
}
kwargs.update(extra_kwargs)
return self.model_class.from_single_file(self.gguf_filename, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
from diffusers.quantizers.gguf.utils import GGUFParameter
assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"
def test_gguf_quantization_inference(self):
self._test_quantization_inference({"compute_dtype": torch.bfloat16})
def test_gguf_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
_keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
self.model_class._keep_in_fp32_modules = ["proj_out"]
try:
model = self._create_quantized_model()
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
finally:
self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules
def test_gguf_quantization_dtype_assignment(self):
self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})
def test_gguf_quantization_lora_inference(self):
self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
def test_gguf_dequantize_model(self):
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
model = self._create_quantized_model()
model.dequantize()
def _check_for_gguf_linear(model):
has_children = list(model.children())
if not has_children:
return
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
for name, module in model.named_children():
_check_for_gguf_linear(module)
def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
@is_modelopt
@nightly
@require_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptTesterMixin(QuantizationTesterMixin):
"""
Mixin class for testing NVIDIA ModelOpt quantization on models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
Expected methods to be implemented by subclasses:
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional class attributes:
- MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test
Pytest mark: modelopt
Use `pytest -m "not modelopt"` to skip these tests
"""
MODELOPT_CONFIGS = {
"fp8": {"quant_type": "FP8"},
"int8": {"quant_type": "INT8"},
"int4": {"quant_type": "INT4"},
}
MODELOPT_EXPECTED_MEMORY_REDUCTIONS = {
"fp8": 1.5,
"int8": 1.5,
"int4": 3.0,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = NVIDIAModelOptConfig(**config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs.update(extra_kwargs)
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_num_parameters(self, config_name):
self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
def test_modelopt_quantization_memory_footprint(self, config_name):
expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
self._test_quantization_memory_footprint(
self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
)
@pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
def test_modelopt_quantization_inference(self, config_name):
self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_dtype_assignment(self, config_name):
self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_lora_inference(self, config_name):
self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantization_serialization(self, config_name):
self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])
@pytest.mark.parametrize("config_name", ["fp8"])
def test_modelopt_quantized_layers(self, config_name):
self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])
def test_modelopt_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude)
def test_modelopt_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"])

View File

@@ -0,0 +1,247 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
backend_empty_cache,
is_single_file,
nightly,
require_torch_accelerator,
torch_device,
)
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
"""Download a single file checkpoint from the Hub to a temporary directory."""
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
return path
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
"""Download diffusers config files (excluding weights) from a repository."""
path = snapshot_download(
pretrained_model_name_or_path,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
Expected class attributes:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- ckpt_path: Path or Hub path to the single file checkpoint
- subfolder: (Optional) Subfolder within the repo
- torch_dtype: (Optional) torch dtype to use for testing
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
pretrained_model_name_or_path = None
ckpt_path = None
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
pretrained_kwargs = {}
single_file_kwargs = {}
pretrained_kwargs["device"] = torch_device
single_file_kwargs["device"] = torch_device
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading: "
f"pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_model_parameters(self):
pretrained_kwargs = {}
single_file_kwargs = {}
pretrained_kwargs["device"] = torch_device
single_file_kwargs["device"] = torch_device
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
state_dict = model.state_dict()
state_dict_single_file = model_single_file.state_dict()
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
"Model parameters keys differ between pretrained and single file loading. "
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
)
for key in state_dict.keys():
param = state_dict[key]
param_single_file = state_dict_single_file[key]
assert param.shape == param_single_file.shape, (
f"Parameter shape mismatch for {key}: "
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
f"Parameter values differ for {key}: "
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
)
def test_single_file_loading_local_files_only(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
with tempfile.TemporaryDirectory() as tmpdir:
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load with config parameter
model_single_file = self.model_class.from_single_file(
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
)
# Load pretrained for comparison
pretrained_kwargs = {}
if hasattr(self, "subfolder") and self.subfolder:
pretrained_kwargs["subfolder"] = self.subfolder
if hasattr(self, "torch_dtype") and self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
# Compare configs
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
def test_single_file_loading_with_diffusers_config_local_files_only(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
with tempfile.TemporaryDirectory() as tmpdir:
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
# Cleanup
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
del model
gc.collect()
backend_empty_cache(torch_device)

View File

@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pytest
import torch
from diffusers.training_utils import EMAModel
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Expected properties to be implemented by subclasses:
- output_shape: Tuple defining the expected output shape
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_training_with_ema(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
init_dict = self.get_init_dict()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
# check enable works
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
# check disable works
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
def test_gradient_checkpointing_is_applied(self, expected_set=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if expected_set is None:
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
init_dict = self.get_init_dict()
model_class_copy = copy.copy(self.model_class)
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} "
f"do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if skip is None:
skip = set()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict)
if isinstance(out, dict):
out = out.sample if hasattr(out, "sample") else out.to_tuple()[0]
# run the backwards pass on the model
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy)
if isinstance(out_2, dict):
out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0]
# run the backwards pass on the model
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
assert (
loss - loss_2
).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
if param.grad is None:
continue
assert torch_all_close(
param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol
), f"Gradient mismatch for {name}"
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
# Test with bfloat16
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()

View File

@@ -0,0 +1,316 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BitsAndBytesTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class FluxTransformerTesterConfig:
model_class = FluxTransformer2DModel
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
pretrained_model_kwargs = {"subfolder": "transformer"}
def get_init_dict(self):
"""Return Flux model initialization arguments."""
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"pooled_projection_dim": 32,
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self):
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
@property
def input_shape(self):
return (16, 4)
@property
def output_shape(self):
return (16, 4)
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output_1 = model(**inputs_dict).to_tuple()[0]
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
inputs_dict["txt_ids"] = text_ids_3d
inputs_dict["img_ids"] = image_ids_3d
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
)
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
pass
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
pass
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
pass
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
def create_ip_adapter_state_dict(self, model):
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
),
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
pass
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height=4, width=4):
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height=4, width=4):
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
"img_ids": randn_tensor((height * width, num_image_channels)),
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
pass
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self):
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}

View File

@@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
assert torch.allclose(
output_native, output_cuda, 1e-2
), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
@nightly

View File

@@ -241,7 +241,6 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -282,12 +281,128 @@ def nightly(test_case):
def is_torch_compile(test_case):
"""
Decorator marking a test that runs compile tests in the diffusers CI.
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
Decorator marking a test as a torch.compile test. These tests can be filtered using:
pytest -m "not compile" to skip
pytest -m compile to run only these tests
"""
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
return pytest.mark.compile(test_case)
def is_single_file(test_case):
"""
Decorator marking a test as a single file loading test. These tests can be filtered using:
pytest -m "not single_file" to skip
pytest -m single_file to run only these tests
"""
return pytest.mark.single_file(test_case)
def is_lora(test_case):
"""
Decorator marking a test as a LoRA test. These tests can be filtered using:
pytest -m "not lora" to skip
pytest -m lora to run only these tests
"""
return pytest.mark.lora(test_case)
def is_ip_adapter(test_case):
"""
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
pytest -m "not ip_adapter" to skip
pytest -m ip_adapter to run only these tests
"""
return pytest.mark.ip_adapter(test_case)
def is_training(test_case):
"""
Decorator marking a test as a training test. These tests can be filtered using:
pytest -m "not training" to skip
pytest -m training to run only these tests
"""
return pytest.mark.training(test_case)
def is_attention(test_case):
"""
Decorator marking a test as an attention test. These tests can be filtered using:
pytest -m "not attention" to skip
pytest -m attention to run only these tests
"""
return pytest.mark.attention(test_case)
def is_memory(test_case):
"""
Decorator marking a test as a memory optimization test. These tests can be filtered using:
pytest -m "not memory" to skip
pytest -m memory to run only these tests
"""
return pytest.mark.memory(test_case)
def is_cpu_offload(test_case):
"""
Decorator marking a test as a CPU offload test. These tests can be filtered using:
pytest -m "not cpu_offload" to skip
pytest -m cpu_offload to run only these tests
"""
return pytest.mark.cpu_offload(test_case)
def is_group_offload(test_case):
"""
Decorator marking a test as a group offload test. These tests can be filtered using:
pytest -m "not group_offload" to skip
pytest -m group_offload to run only these tests
"""
return pytest.mark.group_offload(test_case)
def is_bitsandbytes(test_case):
"""
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
pytest -m "not bitsandbytes" to skip
pytest -m bitsandbytes to run only these tests
"""
return pytest.mark.bitsandbytes(test_case)
def is_quanto(test_case):
"""
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
pytest -m "not quanto" to skip
pytest -m quanto to run only these tests
"""
return pytest.mark.quanto(test_case)
def is_torchao(test_case):
"""
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
pytest -m "not torchao" to skip
pytest -m torchao to run only these tests
"""
return pytest.mark.torchao(test_case)
def is_gguf(test_case):
"""
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
pytest -m "not gguf" to skip
pytest -m gguf to run only these tests
"""
return pytest.mark.gguf(test_case)
def is_modelopt(test_case):
"""
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
pytest -m "not modelopt" to skip
pytest -m modelopt to run only these tests
"""
return pytest.mark.modelopt(test_case)
def require_torch(test_case):