Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-09 05:54:24 +08:00)

Comparing commits: flux2-fix ... model-test (3 commits)

| SHA1 |
|---|
| bffa3a9754 |
| 1c558712e8 |
| 1f026ad14e |
@@ -32,6 +32,20 @@ warnings.simplefilter(action="ignore", category=FutureWarning)


def pytest_configure(config):
    config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
    config.addinivalue_line("markers", "training: marks tests for training functionality")
    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")


def pytest_addoption(parser):
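The markers registered above are the ones the tester mixins below opt into. A hedged sketch of how such a marker is typically applied and selected from the command line (the test name here is illustrative, not part of the diff):

import pytest


@pytest.mark.lora  # resolves because conftest.py registers the "lora" marker above
def test_lora_smoke_example():
    assert True

# Select or deselect marked groups at the command line, for example:
#   pytest tests/models -m "lora and not compile"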
@@ -317,9 +317,9 @@ class ModelUtilsTest(unittest.TestCase):
            repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
        )

        assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
            "Model parameters don't match!"
        )
        assert all(
            torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())
        ), "Model parameters don't match!"

        # Remove a shard file
        cached_shard_file = try_to_load_from_cache(
@@ -335,9 +335,9 @@ class ModelUtilsTest(unittest.TestCase):

        # Verify error mentions the missing shard
        error_msg = str(context.exception)
        assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
            f"Expected error about missing shard, got: {error_msg}"
        )
        assert (
            cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg
        ), f"Expected error about missing shard, got: {error_msg}"

    @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
    @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
@@ -354,9 +354,9 @@ class ModelUtilsTest(unittest.TestCase):
            )

            download_requests = [r.method for r in m.request_history]
            assert download_requests.count("HEAD") == 3, (
                "3 HEAD requests one for config, one for model, and one for shard index file."
            )
            assert (
                download_requests.count("HEAD") == 3
            ), "3 HEAD requests one for config, one for model, and one for shard index file."
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

        with requests_mock.mock(real_http=True) as m:
@@ -368,9 +368,9 @@ class ModelUtilsTest(unittest.TestCase):
            )

            cache_requests = [r.method for r in m.request_history]
            assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
                "We should call only `model_info` to check for commit hash and knowing if shard index is present."
            )
            assert (
                "HEAD" == cache_requests[0] and len(cache_requests) == 2
            ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."

    def test_weight_overwrite(self):
        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
tests/models/testing_utils/__init__.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from .attention import AttentionTesterMixin
from .common import ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .quantization import (
    BitsAndBytesTesterMixin,
    GGUFTesterMixin,
    ModelOptTesterMixin,
    QuantizationTesterMixin,
    QuantoTesterMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
    "AttentionTesterMixin",
    "BitsAndBytesTesterMixin",
    "CPUOffloadTesterMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "QuantizationTesterMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
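A hedged sketch of how a model test module is expected to compose these exported mixins; the class name, attribute values, and import path below are illustrative placeholders, not part of the diff:

from .testing_utils import AttentionTesterMixin, LoraTesterMixin, ModelTesterMixin, TorchCompileTesterMixin


class MyModelTests(ModelTesterMixin, AttentionTesterMixin, LoraTesterMixin, TorchCompileTesterMixin):
    model_class = None  # set to the diffusers model class under test
    main_input_name = "hidden_states"

    def get_init_dict(self):
        # return a tiny config dict so the test model stays small and fast
        raise NotImplementedError

    def get_dummy_inputs(self):
        # return forward-pass kwargs as tensors placed on `torch_device`
        raise NotImplementedError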
tests/models/testing_utils/attention.py (new file, 180 lines)
@@ -0,0 +1,180 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
from ...testing_utils import is_attention, require_accelerator, torch_device
|
||||
|
||||
|
||||
@is_attention
|
||||
@require_accelerator
|
||||
class AttentionTesterMixin:
|
||||
"""
|
||||
Mixin class for testing attention processor and module functionality on models.
|
||||
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
|
||||
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: attention
|
||||
Use `pytest -m "not attention"` to skip these tests
|
||||
"""
|
||||
|
||||
base_precision = 1e-3
|
||||
|
||||
def test_fuse_unfuse_qkv_projections(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if not hasattr(model, "fuse_qkv_projections"):
|
||||
pytest.skip("Model does not support QKV projection fusion.")
|
||||
|
||||
# Get output before fusion
|
||||
with torch.no_grad():
|
||||
output_before_fusion = model(**inputs_dict)
|
||||
if isinstance(output_before_fusion, dict):
|
||||
output_before_fusion = output_before_fusion.to_tuple()[0]
|
||||
|
||||
# Fuse projections
|
||||
model.fuse_qkv_projections()
|
||||
|
||||
# Verify fusion occurred by checking for fused attributes
|
||||
has_fused_projections = False
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
|
||||
has_fused_projections = True
|
||||
assert module.fused_projections, "fused_projections flag should be True"
|
||||
break
|
||||
|
||||
if has_fused_projections:
|
||||
# Get output after fusion
|
||||
with torch.no_grad():
|
||||
output_after_fusion = model(**inputs_dict)
|
||||
if isinstance(output_after_fusion, dict):
|
||||
output_after_fusion = output_after_fusion.to_tuple()[0]
|
||||
|
||||
# Verify outputs match
|
||||
assert torch.allclose(
|
||||
output_before_fusion, output_after_fusion, atol=self.base_precision
|
||||
), "Output should not change after fusing projections"
|
||||
|
||||
# Unfuse projections
|
||||
model.unfuse_qkv_projections()
|
||||
|
||||
# Verify unfusion occurred
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
|
||||
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
|
||||
assert not module.fused_projections, "fused_projections flag should be False"
|
||||
|
||||
# Get output after unfusion
|
||||
with torch.no_grad():
|
||||
output_after_unfusion = model(**inputs_dict)
|
||||
if isinstance(output_after_unfusion, dict):
|
||||
output_after_unfusion = output_after_unfusion.to_tuple()[0]
|
||||
|
||||
# Verify outputs still match
|
||||
assert torch.allclose(
|
||||
output_before_fusion, output_after_unfusion, atol=self.base_precision
|
||||
), "Output should match original after unfusing projections"
|
||||
|
||||
def test_get_set_processor(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# Check if model has attention processors
|
||||
if not hasattr(model, "attn_processors"):
|
||||
pytest.skip("Model does not have attention processors.")
|
||||
|
||||
# Test getting processors
|
||||
processors = model.attn_processors
|
||||
assert isinstance(processors, dict), "attn_processors should return a dict"
|
||||
assert len(processors) > 0, "Model should have at least one attention processor"
|
||||
|
||||
# Test that all processors can be retrieved via get_processor
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
processor = module.get_processor()
|
||||
assert processor is not None, "get_processor should return a processor"
|
||||
|
||||
# Test setting a new processor
|
||||
new_processor = AttnProcessor()
|
||||
module.set_processor(new_processor)
|
||||
retrieved_processor = module.get_processor()
|
||||
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
|
||||
|
||||
def test_attention_processor_dict(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict of new processors
|
||||
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
|
||||
|
||||
# Set processors using dict
|
||||
model.set_attn_processor(new_processors)
|
||||
|
||||
# Verify all processors were set
|
||||
updated_processors = model.attn_processors
|
||||
for key in current_processors.keys():
|
||||
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
|
||||
|
||||
def test_attention_processor_count_mismatch_raises_error(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict with wrong number of processors
|
||||
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
|
||||
|
||||
# Verify error is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
tests/models/testing_utils/common.py (new file, 514 lines)
@@ -0,0 +1,514 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import pytest
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
|
||||
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: Optional[Union[str, torch.device]] = None,
|
||||
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
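A quick, standalone check of the per-tensor sizing rule used above (it assumes only that accelerate is installed):

import torch
from accelerate.utils.modeling import dtype_byte_size

t = torch.zeros(10, 10, dtype=torch.float16)
assert t.numel() * dtype_byte_size(t.dtype) == 200  # 100 elements * 2 bytes each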
|
||||
|
||||
|
||||
def calculate_expected_num_shards(index_map_path):
|
||||
"""
|
||||
Calculate expected number of shards from index file.
|
||||
|
||||
Args:
|
||||
index_map_path: Path to the sharded checkpoint index file
|
||||
|
||||
Returns:
|
||||
int: Expected number of shards
|
||||
"""
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
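A small sanity check of the filename parsing above, using the standard sharded-checkpoint naming scheme (pure string handling, no files involved):

_loc = "diffusion_pytorch_model-00001-of-00002.safetensors"
assert int(_loc.split("-")[-1].split(".")[0]) == 2  # the "-of-00002" suffix means 2 shards are expected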
|
||||
|
||||
|
||||
def check_device_map_is_respected(model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
while len(param_name) > 0 and param_name not in device_map:
|
||||
param_name = ".".join(param_name.split(".")[:-1])
|
||||
if param_name not in device_map:
|
||||
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||
|
||||
param_device = device_map[param_name]
|
||||
if param_device in ["cpu", "disk"]:
|
||||
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||
else:
|
||||
assert param.device == torch.device(
|
||||
param_device
|
||||
), f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
"""
|
||||
Base mixin class for model testing with common test methods.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
|
||||
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
model_class = None
|
||||
base_precision = 1e-3
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
def get_init_dict(self):
|
||||
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
raise NotImplementedError(
|
||||
"get_dummy_inputs must be implemented by subclasses. " "It should return inputs_dict."
|
||||
)
|
||||
|
||||
def test_from_save_pretrained(self, expected_max_diff=5e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
|
||||
# check if all parameters shape are the same
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert (
|
||||
param_1.shape == param_2.shape
|
||||
), f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
|
||||
|
||||
# non-variant cannot be loaded
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmpdirname)
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs())
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**self.get_dummy_inputs())
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if (
|
||||
hasattr(self.model_class, "_keep_in_fp32_modules")
|
||||
and self.model_class._keep_in_fp32_modules is None
|
||||
):
|
||||
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
|
||||
)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self, expected_max_diff=1e-5):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
first = model(**self.get_dummy_inputs())
|
||||
if isinstance(first, dict):
|
||||
first = first.to_tuple()[0]
|
||||
|
||||
second = model(**self.get_dummy_inputs())
|
||||
if isinstance(second, dict):
|
||||
second = second.to_tuple()[0]
|
||||
|
||||
# Remove NaN values and compute max difference
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
|
||||
# Filter out NaN values
|
||||
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||
first_filtered = first_flat[mask]
|
||||
second_filtered = second_flat[mask]
|
||||
|
||||
max_diff = torch.abs(first_filtered - second_filtered).max().item()
|
||||
assert (
|
||||
max_diff <= expected_max_diff
|
||||
), f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
|
||||
|
||||
def test_output(self, expected_output_shape=None):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
assert (
|
||||
output.shape == expected_output_shape
|
||||
), f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
|
||||
# Track progress in https://github.com/pytorch/pytorch/issues/77764
|
||||
device = t.device
|
||||
if device.type == "mps":
|
||||
t = t.to("cpu")
|
||||
t[t != t] = 0
|
||||
return t.to(device)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
assert torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
), (
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
)
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_model_config_to_json_string(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
|
||||
json_string = model.config.to_json_string()
|
||||
assert isinstance(json_string, str), "Config to_json_string should return a string"
|
||||
assert len(json_string) > 0, "JSON string should not be empty"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(torch_device not in ["cuda", "xpu"], reason="Requires a CUDA or XPU device for float16/bfloat16.")
|
||||
def test_from_save_pretrained_float16_bfloat16(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for torch_dtype in [torch.bfloat16, torch.float16]:
|
||||
model.to(torch_dtype).save_pretrained(tmp_dir)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == torch_dtype
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**self.get_dummy_inputs())
output_loaded = model_loaded(**self.get_dummy_inputs())
|
||||
|
||||
assert torch.allclose(
|
||||
output, output_loaded, atol=1e-4
|
||||
), f"Loaded model output differs for {torch_dtype}"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints(self):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match after sharded save/load"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_variant(self):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
variant = "fp16"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
assert os.path.exists(
|
||||
os.path.join(tmp_dir, index_filename)
|
||||
), f"Variant index file {index_filename} should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match after variant sharded save/load"
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_parallel_loading(self):
|
||||
import time
|
||||
|
||||
from diffusers.utils import constants
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
# Save original values to restore after test
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
assert (
|
||||
actual_num_shards == expected_num_shards
|
||||
), f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
|
||||
# Load without parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
start_time = time.time()
|
||||
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
sequential_load_time = time.time() - start_time
|
||||
model_sequential = model_sequential.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Load with parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
start_time = time.time()
|
||||
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
|
||||
parallel_load_time = time.time() - start_time
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], output_parallel[0], atol=1e-5
|
||||
), "Output should match with parallel loading"
|
||||
|
||||
# Verify parallel loading is faster or at least not significantly slower
|
||||
# For small test models, the difference might be negligible or even slightly slower due to overhead
|
||||
# so we just check that parallel loading completed successfully and outputs match
|
||||
assert (
|
||||
parallel_load_time < sequential_load_time
|
||||
), f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
|
||||
finally:
|
||||
# Restore original values
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_model_parallelism(self):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert torch.allclose(
|
||||
base_output[0], new_output[0], atol=1e-5
|
||||
), "Output should match with model parallelism"
|
||||
tests/models/testing_utils/compile.py (new file, 162 lines)
@@ -0,0 +1,162 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_compile,
|
||||
require_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_torch_compile
|
||||
@require_accelerator
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class TorchCompileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing torch.compile functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: compile
|
||||
Use `pytest -m "not compile"` to skip these tests
|
||||
"""
|
||||
|
||||
different_shapes_for_compilation = None
|
||||
|
||||
def setup_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
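The recompilation check above relies on dynamo's `error_on_recompile` flag; a standalone sketch of the same pattern on a plain module (the tiny model is chosen purely for illustration):

import torch
import torch.nn as nn

net = torch.compile(nn.Linear(8, 8).eval(), fullgraph=True)
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    x = torch.randn(2, 8)
    net(x)  # first call triggers compilation
    net(x)  # second call must reuse the compiled graph; a recompile would raise here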
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
if self.model_class._repeated_blocks is None:
|
||||
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_with_group_offloading(self):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch_device,
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_works_with_aot(self):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
tests/models/testing_utils/hub.py (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
|
||||
|
||||
from ...others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ModelPushToHubTesterMixin:
|
||||
"""
|
||||
Mixin class for testing push_to_hub functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
"""
|
||||
|
||||
identifier = uuid.uuid4()
|
||||
repo_id = f"test-model-{identifier}"
|
||||
org_repo_id = f"valid_org/{repo_id}-org"
|
||||
|
||||
def test_push_to_hub(self):
|
||||
"""Test pushing model to hub and loading it back."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.repo_id, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=TOKEN, repo_id=self.repo_id)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(f"{USER}/{self.repo_id}")
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(
|
||||
p1, p2
|
||||
), "Parameters don't match after save_pretrained with push_to_hub and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
"""Test pushing model to hub in organization namespace."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.org_repo_id, token=TOKEN)
|
||||
|
||||
new_model = self.model_class.from_pretrained(self.org_repo_id)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(p1, p2), "Parameters don't match after push_to_hub to org and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=TOKEN, repo_id=self.org_repo_id)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
|
||||
|
||||
new_model = self.model_class.from_pretrained(self.org_repo_id)
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
assert torch.equal(
|
||||
p1, p2
|
||||
), "Parameters don't match after save_pretrained with push_to_hub to org and from_pretrained"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.org_repo_id, token=TOKEN)
|
||||
|
||||
def test_push_to_hub_library_name(self):
|
||||
"""Test that library_name in model card is set to 'diffusers'."""
|
||||
if not is_jinja_available():
|
||||
pytest.skip("Model card tests cannot be performed without Jinja installed.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.push_to_hub(self.repo_id, token=TOKEN)
|
||||
|
||||
model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
|
||||
assert (
|
||||
model_card.library_name == "diffusers"
|
||||
), f"Expected library_name 'diffusers', got {model_card.library_name}"
|
||||
|
||||
# Reset repo
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
tests/models/testing_utils/ip_adapter.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention_processor import IPAdapterAttnProcessor
|
||||
|
||||
from ...testing_utils import is_ip_adapter, torch_device
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
"""
|
||||
Create a dummy IP Adapter state dict for testing.
|
||||
|
||||
Args:
|
||||
model: The model to create IP adapter weights for
|
||||
|
||||
Returns:
|
||||
dict: IP adapter state dict with to_k_ip and to_v_ip weights
|
||||
"""
|
||||
ip_state_dict = {}
|
||||
key_id = 1
|
||||
|
||||
for name in model.attn_processors.keys():
|
||||
# Skip self-attention processors
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", None)
|
||||
if cross_attention_dim is None:
|
||||
continue
|
||||
|
||||
# Get hidden size based on model architecture
|
||||
hidden_size = getattr(model.config, "hidden_size", cross_attention_dim)
|
||||
|
||||
# Create IP adapter processor to get state dict structure
|
||||
sd = IPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
|
||||
ip_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
key_id += 2
|
||||
|
||||
return {"ip_adapter": ip_state_dict}
|
||||
|
||||
|
||||
def check_if_ip_adapter_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if IP Adapter processors are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if IP Adapter is correctly set, False otherwise
|
||||
"""
|
||||
for module in model.attn_processors.values():
|
||||
if isinstance(module, IPAdapterAttnProcessor):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_ip_adapter
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
Mixin class for testing IP Adapter functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: ip_adapter
|
||||
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||
"""
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||
|
||||
def test_load_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create dummy IP adapter state dict
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
|
||||
# Load IP adapter
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter processors not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Create dummy image embeds for IP adapter
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
inputs_dict_with_adapter = inputs_dict.copy()
|
||||
inputs_dict_with_adapter["image_embeds"] = image_embeds
|
||||
|
||||
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with IP Adapter enabled"
|
||||
|
||||
def test_ip_adapter_scale(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create and load dummy IP adapter state dict
|
||||
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])

# Build inputs that include dummy image embeds for the IP adapter branch
cross_attention_dim = getattr(model.config, "cross_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict_with_adapter = inputs_dict.copy()
inputs_dict_with_adapter["image_embeds"] = image_embeds
|
||||
|
||||
# Test scale = 0.0 (no effect)
|
||||
model.set_ip_adapter_scale(0.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Test scale = 1.0 (full effect)
|
||||
model.set_ip_adapter_scale(1.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should differ with different scales
|
||||
assert not torch.allclose(
|
||||
output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with different IP Adapter scales"
|
||||
|
||||
def test_unload_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Save original processors
|
||||
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be set"
|
||||
|
||||
# Unload IP adapter
|
||||
model.unload_ip_adapter()
|
||||
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
|
||||
|
||||
# Verify processors are restored
|
||||
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||
|
||||
def test_ip_adapter_save_load(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_before_save = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save the IP adapter weights
|
||||
save_path = os.path.join(tmpdir, "ip_adapter.safetensors")
|
||||
import safetensors.torch
|
||||
|
||||
safetensors.torch.save_file(ip_adapter_state_dict["ip_adapter"], save_path)
|
||||
|
||||
# Unload and reload
|
||||
model.unload_ip_adapter()
|
||||
assert not check_if_ip_adapter_correctly_set(model), "IP Adapter should be unloaded"
|
||||
|
||||
# Reload from saved file
|
||||
loaded_state_dict = {"ip_adapter": safetensors.torch.load_file(save_path)}
|
||||
model._load_ip_adapter_weights([loaded_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model), "IP Adapter should be loaded"
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_after_load = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Outputs should match before and after save/load
|
||||
assert torch.allclose(
|
||||
output_before_save, output_after_load, atol=1e-4, rtol=1e-4
|
||||
), "Output should match before and after save/load"
|
||||
tests/models/testing_utils/lora.py (new file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||
|
||||
from ...testing_utils import is_lora, require_peft_backend, torch_device
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if LoRA layers are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if LoRA is correctly set, False otherwise
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
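A hedged, standalone sketch of this helper in action; it assumes only that `peft` is installed and uses peft's public `inject_adapter_in_model` entry point (the tests themselves go through the diffusers `add_adapter` API instead):

import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

base = nn.Sequential(nn.Linear(8, 8))
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["0"], init_lora_weights=False)
lora_model = inject_adapter_in_model(lora_config, base)
assert check_if_lora_correctly_set(lora_model)  # the Linear is now wrapped in a BaseTunerLayer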
|
||||
|
||||
|
||||
@is_lora
|
||||
@require_peft_backend
|
||||
class LoraTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA/PEFT functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: lora
|
||||
Use `pytest -m "not lora"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with LoRA enabled"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
assert os.path.isfile(
|
||||
os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
), "LoRA weights file not created"
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(
|
||||
output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
|
||||
), "Output should differ with LoRA enabled"
|
||||
assert torch.allclose(
|
||||
outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4
|
||||
), "Outputs should match before and after save/load"
|
||||
|
||||
def test_lora_wrong_adapter_name_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
|
||||
|
||||
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_lora_adapter(tmpdir)
|
||||
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
# Perturb the metadata in the state dict
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
|
||||
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||
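As a usage sketch (illustrative only): a model test class opts into these LoRA tests by setting `model_class` and providing the two hooks described in the mixin docstring. The tiny UNet config and input shapes below mirror the dummy configs used elsewhere in the suite, and the relative import path of the mixin is an assumption about where the test module lives.

import torch

from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import torch_device

from ..testing_utils.lora import LoraTesterMixin  # path relative to a tests/models/<subpackage> module


class TestUNet2DConditionModelLora(LoraTesterMixin):
    model_class = UNet2DConditionModel

    def get_init_dict(self):
        # Tiny config so the LoRA save/load round-trip stays fast.
        return {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 8,
            "attention_head_dim": 2,
            "in_channels": 4,
            "out_channels": 4,
            "layers_per_block": 1,
            "sample_size": 16,
        }

    def get_dummy_inputs(self):
        # Inputs are created directly on torch_device because the mixin moves the model there.
        return {
            "sample": torch.randn(4, 4, 16, 16, device=torch_device),
            "timestep": torch.tensor([10], device=torch_device),
            "encoder_hidden_states": torch.randn(4, 4, 8, device=torch_device),
        }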
443  tests/models/testing_utils/memory.py  Normal file
@@ -0,0 +1,443 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import glob
import inspect
import tempfile
from functools import wraps

import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes

from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability

from ...testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    backend_synchronize,
    is_cpu_offload,
    is_group_offload,
    is_memory,
    require_accelerator,
    torch_device,
)
from .common import check_device_map_is_respected


def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
    """Helper to cast tensor inputs from one dtype to another."""
    for key, value in inputs_dict.items():
        if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
            inputs_dict[key] = value.to(to_dtype)
    return inputs_dict


def require_offload_support(func):
    """
    Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.model_class._no_split_modules is None:
            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
        return func(self, *args, **kwargs)

    return wrapper


def require_group_offload_support(func):
    """
    Decorator to skip tests if model doesn't support group offloading.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if not self.model_class._supports_group_offloading:
            pytest.skip("Model does not support group offloading.")
        return func(self, *args, **kwargs)

    return wrapper


@is_cpu_offload
class CPUOffloadTesterMixin:
    """
    Mixin class for testing CPU offloading functionality.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - model_split_percents: List of percentages for splitting model across devices

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cpu_offload
    Use `pytest -m "not cpu_offload"` to skip these tests
    """

    model_split_percents = [0.5, 0.7]

    @require_offload_support
    def test_cpu_offload(self):
        config = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**config).eval()

        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure part of the model will actually end up offloaded
                assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"

                check_device_map_is_respected(new_model, new_model.hf_device_map)
                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                assert torch.allclose(
                    base_output[0], new_output[0], atol=1e-5
                ), "Output should match with CPU offloading"

    @require_offload_support
    def test_disk_offload_without_safetensors(self):
        config = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**config).eval()

        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        max_size = int(self.model_split_percents[0] * model_size)
        # Force disk offload by setting very small CPU memory
        max_memory = {0: max_size, "cpu": int(0.1 * max_size)}

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
            # This errors out because it's missing an offload folder
            with pytest.raises(ValueError):
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
            )

            check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"

    @require_offload_support
    def test_disk_offload_with_safetensors(self):
        config = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**config).eval()

        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            max_size = int(self.model_split_percents[0] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
            )

            check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            assert torch.allclose(
                base_output[0], new_output[0], atol=1e-5
            ), "Output should match with disk offloading (safetensors)"


@is_group_offload
class GroupOffloadTesterMixin:
    """
    Mixin class for testing group offloading functionality.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: group_offload
    Use `pytest -m "not group_offload"` to skip these tests
    """

    @require_group_offload_support
    def test_group_offloading(self, record_stream=False):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        torch.manual_seed(0)

        @torch.no_grad()
        def run_forward(model):
            assert all(
                module._diffusers_hook.get_hook("group_offloading") is not None
                for module in model.modules()
                if hasattr(module, "_diffusers_hook")
            ), "Group offloading hook should be set"
            model.eval()
            return model(**inputs_dict)[0]

        model = self.model_class(**init_dict)

        model.to(torch_device)
        output_without_group_offloading = run_forward(model)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
        output_with_group_offloading1 = run_forward(model)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
        output_with_group_offloading2 = run_forward(model)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.enable_group_offload(torch_device, offload_type="leaf_level")
        output_with_group_offloading3 = run_forward(model)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.enable_group_offload(
            torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
        )
        output_with_group_offloading4 = run_forward(model)

        assert torch.allclose(
            output_without_group_offloading, output_with_group_offloading1, atol=1e-5
        ), "Output should match with block-level offloading"
        assert torch.allclose(
            output_without_group_offloading, output_with_group_offloading2, atol=1e-5
        ), "Output should match with non-blocking block-level offloading"
        assert torch.allclose(
            output_without_group_offloading, output_with_group_offloading3, atol=1e-5
        ), "Output should match with leaf-level offloading"
        assert torch.allclose(
            output_without_group_offloading, output_with_group_offloading4, atol=1e-5
        ), "Output should match with leaf-level offloading with stream"

    @require_group_offload_support
    @torch.no_grad()
    def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
        torch.manual_seed(0)
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)

        model.to(torch_device)
        model.eval()
        _ = model(**inputs_dict)[0]

        torch.manual_seed(0)
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        storage_dtype, compute_dtype = torch.float16, torch.float32
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        model = self.model_class(**init_dict)
        model.eval()
        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
        model.enable_group_offload(
            torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
        )
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        _ = model(**inputs_dict)[0]

    @require_group_offload_support
    @torch.no_grad()
    @torch.inference_mode()
    def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
        def _has_generator_arg(model):
            sig = inspect.signature(model.forward)
            params = sig.parameters
            return "generator" in params

        def _run_forward(model, inputs_dict):
            accepts_generator = _has_generator_arg(model)
            if accepts_generator:
                inputs_dict["generator"] = torch.manual_seed(0)
            torch.manual_seed(0)
            return model(**inputs_dict)[0]

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        torch.manual_seed(0)
        model = self.model_class(**init_dict)

        model.eval()
        model.to(torch_device)
        output_without_group_offloading = _run_forward(model, inputs_dict)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.eval()

        num_blocks_per_group = None if offload_type == "leaf_level" else 1
        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
        with tempfile.TemporaryDirectory() as tmpdir:
            model.enable_group_offload(
                torch_device,
                offload_type=offload_type,
                offload_to_disk_path=tmpdir,
                use_stream=True,
                record_stream=record_stream,
                **additional_kwargs,
            )
            has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
            assert has_safetensors, "No safetensors found in the directory."

            # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
            # in nature. So, skip it.
            if offload_type != "leaf_level":
                is_correct, extra_files, missing_files = _check_safetensors_serialization(
                    module=model,
                    offload_to_disk_path=tmpdir,
                    offload_type=offload_type,
                    num_blocks_per_group=num_blocks_per_group,
                )
                if not is_correct:
                    if extra_files:
                        raise ValueError(f"Found extra files: {', '.join(extra_files)}")
                    elif missing_files:
                        raise ValueError(f"Following files are missing: {', '.join(missing_files)}")

            output_with_group_offloading = _run_forward(model, inputs_dict)
            assert torch.allclose(
                output_without_group_offloading, output_with_group_offloading, atol=atol
            ), "Output should match with disk-based group offloading"


class LayerwiseCastingTesterMixin:
    """
    Mixin class for testing layerwise dtype casting for memory optimization.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    @torch.no_grad()
    def test_layerwise_casting_memory(self):
        MB_TOLERANCE = 0.2
        LEAST_COMPUTE_CAPABILITY = 8.0

        def reset_memory_stats():
            gc.collect()
            backend_synchronize(torch_device)
            backend_empty_cache(torch_device)
            backend_reset_peak_memory_stats(torch_device)

        def get_memory_usage(storage_dtype, compute_dtype):
            torch.manual_seed(0)
            config = self.get_init_dict()
            inputs_dict = self.get_dummy_inputs()
            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
            model = self.model_class(**config).eval()
            model = model.to(torch_device, dtype=compute_dtype)
            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)

            reset_memory_stats()
            model(**inputs_dict)
            model_memory_footprint = model.get_memory_footprint()
            peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2

            return model_memory_footprint, peak_inference_memory_allocated_mb

        fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
        fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
        fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
            torch.float8_e4m3fn, torch.bfloat16
        )

        compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
        assert (
            fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint
        ), "Memory footprint should decrease with lower precision storage"

        # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
        if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
            assert (
                fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory
            ), "Peak memory should be lower with bf16 compute on newer GPUs"

        # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
        # bytes. This only happens for some models, so we allow a small tolerance.
        # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
        assert (
            fp8_e4m3_fp32_max_memory < fp32_max_memory
            or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
        ), "Peak memory should be lower or within tolerance with fp8 storage"


@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
    """
    Combined mixin class for all memory optimization tests including CPU/disk offloading,
    group offloading, and layerwise dtype casting.

    This mixin inherits from:
    - CPUOffloadTesterMixin: CPU and disk offloading tests
    - GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
    - LayerwiseCastingTesterMixin: Layerwise dtype casting tests

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: memory
    Use `pytest -m "not memory"` to skip these tests
    """

    pass
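The combined memory mixin is wired up the same way; a sketch, assuming the same tiny UNet helpers as in the LoRA example above (the relative import path is again an assumption). The registered markers can then be used to select or skip subsets from the command line.

from diffusers import UNet2DConditionModel

from ..testing_utils.memory import MemoryTesterMixin  # relative path is indicative


class TestUNet2DConditionModelMemory(MemoryTesterMixin):
    model_class = UNet2DConditionModel
    model_split_percents = [0.5, 0.7]

    # Reuse the tiny config/input helpers sketched for the LoRA mixin above.
    get_init_dict = TestUNet2DConditionModelLora.get_init_dict
    get_dummy_inputs = TestUNet2DConditionModelLora.get_dummy_inputs


# Example selections using the markers registered in conftest.py:
#   pytest tests/models -m "memory and not group_offload"
#   pytest tests/models -m "not memory"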
833  tests/models/testing_utils/quantization.py  Normal file
@@ -0,0 +1,833 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
|
||||
from diffusers.utils.import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_gguf_available,
|
||||
is_nvidia_modelopt_available,
|
||||
is_optimum_quanto_available,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_bitsandbytes,
|
||||
is_gguf,
|
||||
is_modelopt,
|
||||
is_quanto,
|
||||
is_torchao,
|
||||
nightly,
|
||||
require_accelerate,
|
||||
require_accelerator,
|
||||
require_bitsandbytes_version_greater,
|
||||
require_gguf_version_greater_or_equal,
|
||||
require_quanto,
|
||||
require_torchao_version_greater_or_equal,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_nvidia_modelopt_available():
|
||||
import modelopt.torch.quantization as mtq
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QLinear
|
||||
|
||||
if is_gguf_available():
|
||||
pass
|
||||
|
||||
if is_torchao_available():
|
||||
|
||||
if is_torchao_version(">=", "0.9.0"):
|
||||
pass
|
||||
|
||||
|
||||
@require_accelerator
|
||||
class QuantizationTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for quantization testing.
|
||||
|
||||
Backend-specific mixins should:
|
||||
1. Implement _create_quantized_model(config_kwargs)
|
||||
2. Implement _verify_if_layer_quantized(name, module, config_kwargs)
|
||||
3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.)
|
||||
4. Use @pytest.mark.parametrize to create tests that call the common test methods below
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
"""
|
||||
Create a quantized model with the given config kwargs.
|
||||
|
||||
Args:
|
||||
config_kwargs: Quantization config parameters
|
||||
**extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder)
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _create_quantized_model")
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
|
||||
|
||||
def _is_module_quantized(self, module):
|
||||
"""
|
||||
Check if a module is quantized. Returns True if quantized, False otherwise.
|
||||
Default implementation tries _verify_if_layer_quantized and catches exceptions.
|
||||
Subclasses can override for more efficient checking.
|
||||
"""
|
||||
try:
|
||||
self._verify_if_layer_quantized("", module, {})
|
||||
return True
|
||||
except (AssertionError, AttributeError):
|
||||
return False
|
||||
|
||||
def _load_unquantized_model(self):
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {})
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _test_quantization_num_parameters(self, config_kwargs):
|
||||
model = self._load_unquantized_model()
|
||||
num_params = model.num_parameters()
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
num_params_quantized = model_quantized.num_parameters()
|
||||
|
||||
assert (
|
||||
num_params == num_params_quantized
|
||||
), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
|
||||
|
||||
def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
|
||||
model = self._load_unquantized_model()
|
||||
mem = model.get_memory_footprint()
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
mem_quantized = model_quantized.get_memory_footprint()
|
||||
|
||||
ratio = mem / mem_quantized
|
||||
assert (
|
||||
ratio >= expected_memory_reduction
|
||||
), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
|
||||
|
||||
def _test_quantization_inference(self, config_kwargs):
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model_quantized(**inputs)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN"
|
||||
|
||||
def _test_quantization_dtype_assignment(self, config_kwargs):
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.to(torch.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
device_0 = f"{torch_device}:0"
|
||||
model.to(device=device_0, dtype=torch.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.float()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.half()
|
||||
|
||||
model.to(torch_device)
|
||||
|
||||
def _test_quantization_lora_inference(self, config_kwargs):
|
||||
try:
|
||||
from peft import LoraConfig
|
||||
except ImportError:
|
||||
pytest.skip("peft is not available")
|
||||
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})")
|
||||
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None with LoRA"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
|
||||
|
||||
def _test_quantization_serialization(self, config_kwargs):
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.save_pretrained(tmpdir, safe_serialization=True)
|
||||
|
||||
model_loaded = self.model_class.from_pretrained(tmpdir)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model_loaded(**inputs)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert not torch.isnan(output).any(), "Loaded model output contains NaN"
|
||||
|
||||
def _test_quantized_layers(self, config_kwargs):
|
||||
model_fp = self._load_unquantized_model()
|
||||
num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear))
|
||||
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
num_fp32_modules = 0
|
||||
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
|
||||
for name, module in model_quantized.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
|
||||
num_fp32_modules += 1
|
||||
|
||||
expected_quantized_layers = num_linear_layers - num_fp32_modules
|
||||
|
||||
num_quantized_layers = 0
|
||||
for name, module in model_quantized.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
|
||||
if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
|
||||
continue
|
||||
self._verify_if_layer_quantized(name, module, config_kwargs)
|
||||
num_quantized_layers += 1
|
||||
|
||||
assert (
|
||||
num_quantized_layers > 0
|
||||
), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
|
||||
assert (
|
||||
num_quantized_layers == expected_quantized_layers
|
||||
), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
|
||||
|
||||
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
|
||||
"""
|
||||
Test that modules specified in modules_to_not_convert are not quantized.
|
||||
|
||||
Args:
|
||||
config_kwargs: Base quantization config kwargs
|
||||
modules_to_not_convert: List of module names to exclude from quantization
|
||||
"""
|
||||
# Create config with modules_to_not_convert
|
||||
config_kwargs_with_exclusion = config_kwargs.copy()
|
||||
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
|
||||
|
||||
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
|
||||
|
||||
# Find a module that should NOT be quantized
|
||||
found_excluded = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is in the exclusion list
|
||||
if any(excluded in name for excluded in modules_to_not_convert):
|
||||
found_excluded = True
|
||||
# This module should NOT be quantized
|
||||
assert not self._is_module_quantized(
|
||||
module
|
||||
), f"Module {name} should not be quantized but was found to be quantized"
|
||||
|
||||
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
|
||||
|
||||
# Find a module that SHOULD be quantized (not in exclusion list)
|
||||
found_quantized = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is NOT in the exclusion list
|
||||
if not any(excluded in name for excluded in modules_to_not_convert):
|
||||
if self._is_module_quantized(module):
|
||||
found_quantized = True
|
||||
break
|
||||
|
||||
assert found_quantized, "No quantized layers found outside of excluded modules"
|
||||
|
||||
# Compare memory footprint with fully quantized model
|
||||
model_fully_quantized = self._create_quantized_model(config_kwargs)
|
||||
|
||||
mem_with_exclusion = model_with_exclusion.get_memory_footprint()
|
||||
mem_fully_quantized = model_fully_quantized.get_memory_footprint()
|
||||
|
||||
assert (
|
||||
mem_with_exclusion > mem_fully_quantized
|
||||
), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
|
||||
|
||||
def _test_quantization_device_map(self, config_kwargs):
|
||||
"""
|
||||
Test that quantized models work correctly with device_map="auto".
|
||||
|
||||
Args:
|
||||
config_kwargs: Base quantization config kwargs
|
||||
"""
|
||||
model = self._create_quantized_model(config_kwargs, device_map="auto")
|
||||
|
||||
# Verify device map is set
|
||||
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
|
||||
assert model.hf_device_map is not None, "hf_device_map should not be None"
|
||||
|
||||
# Verify inference works
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
assert output is not None, "Model output is None"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN"
|
||||
|
||||
|
||||
@is_bitsandbytes
|
||||
@nightly
|
||||
@require_accelerator
|
||||
@require_bitsandbytes_version_greater("0.43.2")
|
||||
@require_accelerate
|
||||
class BitsAndBytesTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing BitsAndBytes quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test
|
||||
|
||||
Pytest mark: bitsandbytes
|
||||
Use `pytest -m "not bitsandbytes"` to skip these tests
|
||||
"""
|
||||
|
||||
# Standard BnB configs tested for all models
|
||||
# Subclasses can override to add or modify configs
|
||||
BNB_CONFIGS = {
|
||||
"4bit_nf4": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_compute_dtype": torch.float16,
|
||||
},
|
||||
"4bit_fp4": {
|
||||
"load_in_4bit": True,
|
||||
"bnb_4bit_quant_type": "fp4",
|
||||
"bnb_4bit_compute_dtype": torch.float16,
|
||||
},
|
||||
"8bit": {
|
||||
"load_in_8bit": True,
|
||||
},
|
||||
}
|
||||
|
||||
BNB_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"4bit_nf4": 3.0,
|
||||
"4bit_fp4": 3.0,
|
||||
"8bit": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = BitsAndBytesConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
|
||||
assert (
|
||||
module.weight.__class__ == expected_weight_class
|
||||
), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_num_parameters(self, config_name):
|
||||
self._test_quantization_num_parameters(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_memory_footprint(self, config_name):
|
||||
expected = self.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
|
||||
self._test_quantization_memory_footprint(self.BNB_CONFIGS[config_name], expected_memory_reduction=expected)
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_inference(self, config_name):
|
||||
self._test_quantization_inference(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_dtype_assignment(self, config_name):
|
||||
self._test_quantization_dtype_assignment(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_lora_inference(self, config_name):
|
||||
self._test_quantization_lora_inference(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4"])
|
||||
def test_bnb_quantization_serialization(self, config_name):
|
||||
self._test_quantization_serialization(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantized_layers(self, config_name):
|
||||
self._test_quantized_layers(self.BNB_CONFIGS[config_name])
|
||||
|
||||
@pytest.mark.parametrize("config_name", list(BNB_CONFIGS.keys()))
|
||||
def test_bnb_quantization_config_serialization(self, config_name):
|
||||
model = self._create_quantized_model(self.BNB_CONFIGS[config_name])
|
||||
|
||||
assert "quantization_config" in model.config, "Missing quantization_config"
|
||||
_ = model.config["quantization_config"].to_dict()
|
||||
_ = model.config["quantization_config"].to_diff_dict()
|
||||
_ = model.config["quantization_config"].to_json_string()
|
||||
|
||||
def test_bnb_original_dtype(self):
|
||||
config_name = list(self.BNB_CONFIGS.keys())[0]
|
||||
config_kwargs = self.BNB_CONFIGS[config_name]
|
||||
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype"
|
||||
assert model.config["_pre_quantization_dtype"] in [
|
||||
torch.float16,
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"
|
||||
|
||||
def test_bnb_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
|
||||
config_kwargs = self.BNB_CONFIGS["4bit_nf4"]
|
||||
|
||||
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert (
|
||||
module.weight.dtype == torch.float32
|
||||
), f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
else:
|
||||
assert (
|
||||
module.weight.dtype == torch.uint8
|
||||
), f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
finally:
|
||||
if original_fp32_modules is not None:
|
||||
self.model_class._keep_in_fp32_modules = original_fp32_modules
|
||||
|
||||
def test_bnb_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.BNB_CONFIGS["4bit_nf4"], modules_to_exclude)
|
||||
|
||||
def test_bnb_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.BNB_CONFIGS["4bit_nf4"])
|
||||
|
||||
|
||||
@is_quanto
|
||||
@nightly
|
||||
@require_quanto
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
class QuantoTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing Quanto quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype
|
||||
|
||||
Pytest mark: quanto
|
||||
Use `pytest -m "not quanto"` to skip these tests
|
||||
"""
|
||||
|
||||
QUANTO_WEIGHT_TYPES = {
|
||||
"float8": {"weights_dtype": "float8"},
|
||||
"int8": {"weights_dtype": "int8"},
|
||||
"int4": {"weights_dtype": "int4"},
|
||||
"int2": {"weights_dtype": "int2"},
|
||||
}
|
||||
|
||||
QUANTO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"float8": 1.5,
|
||||
"int8": 1.5,
|
||||
"int4": 3.0,
|
||||
"int2": 7.0,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = QuantoConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_num_parameters(self, weight_type_name):
|
||||
self._test_quantization_num_parameters(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_memory_footprint(self, weight_type_name):
|
||||
expected = self.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", list(QUANTO_WEIGHT_TYPES.keys()))
|
||||
def test_quanto_quantization_inference(self, weight_type_name):
|
||||
self._test_quantization_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantized_layers(self, weight_type_name):
|
||||
self._test_quantized_layers(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantization_lora_inference(self, weight_type_name):
|
||||
self._test_quantization_lora_inference(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
@pytest.mark.parametrize("weight_type_name", ["int8"])
|
||||
def test_quanto_quantization_serialization(self, weight_type_name):
|
||||
self._test_quantization_serialization(self.QUANTO_WEIGHT_TYPES[weight_type_name])
|
||||
|
||||
def test_quanto_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(self.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude)
|
||||
|
||||
def test_quanto_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.QUANTO_WEIGHT_TYPES["int8"])
|
||||
|
||||
|
||||
@is_torchao
|
||||
@require_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing TorchAO quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional class attributes:
|
||||
- TORCHAO_QUANT_TYPES: Dict of quantization type strings to test
|
||||
|
||||
Pytest mark: torchao
|
||||
Use `pytest -m "not torchao"` to skip these tests
|
||||
"""
|
||||
|
||||
TORCHAO_QUANT_TYPES = {
|
||||
"int4wo": {"quant_type": "int4_weight_only"},
|
||||
"int8wo": {"quant_type": "int8_weight_only"},
|
||||
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
|
||||
}
|
||||
|
||||
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
"int4wo": 3.0,
|
||||
"int8wo": 1.5,
|
||||
"int8dq": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = TorchAoConfig(**config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_num_parameters(self, quant_type):
|
||||
self._test_quantization_num_parameters(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_memory_footprint(self, quant_type):
|
||||
expected = self.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
|
||||
self._test_quantization_memory_footprint(
|
||||
self.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("quant_type", list(TORCHAO_QUANT_TYPES.keys()))
|
||||
def test_torchao_quantization_inference(self, quant_type):
|
||||
self._test_quantization_inference(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantized_layers(self, quant_type):
|
||||
self._test_quantized_layers(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantization_lora_inference(self, quant_type):
|
||||
self._test_quantization_lora_inference(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
@pytest.mark.parametrize("quant_type", ["int8wo"])
|
||||
def test_torchao_quantization_serialization(self, quant_type):
|
||||
self._test_quantization_serialization(self.TORCHAO_QUANT_TYPES[quant_type])
|
||||
|
||||
def test_torchao_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
# Get a module name that exists in the model - this needs to be set by test classes
|
||||
# For now, use a generic pattern that should work with transformer models
|
||||
modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(
|
||||
self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
|
||||
)
|
||||
|
||||
def test_torchao_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(self.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
|
||||
@is_gguf
|
||||
@nightly
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
@require_gguf_version_greater_or_equal("0.10.0")
|
||||
class GGUFTesterMixin(QuantizationTesterMixin):
|
||||
"""
|
||||
Mixin class for testing GGUF quantization on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- gguf_filename: URL or path to the GGUF file
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: gguf
|
||||
Use `pytest -m "not gguf"` to skip these tests
|
||||
"""
|
||||
|
||||
gguf_filename = None
|
||||
|
||||
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
|
||||
if config_kwargs is None:
|
||||
config_kwargs = {"compute_dtype": torch.bfloat16}
|
||||
|
||||
config = GGUFQuantizationConfig(**config_kwargs)
|
||||
kwargs = {
|
||||
"quantization_config": config,
|
||||
"torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
|
||||
}
|
||||
kwargs.update(extra_kwargs)
|
||||
return self.model_class.from_single_file(self.gguf_filename, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
|
||||
from diffusers.quantizers.gguf.utils import GGUFParameter
|
||||
|
||||
assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
|
||||
assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
|
||||
assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"
|
||||
|
||||
def test_gguf_quantization_inference(self):
|
||||
self._test_quantization_inference({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
|
||||
_keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model()
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
|
||||
finally:
|
||||
self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules
|
||||
|
||||
def test_gguf_quantization_dtype_assignment(self):
|
||||
self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_quantization_lora_inference(self):
|
||||
self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_dequantize_model(self):
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
||||
|
||||
model = self._create_quantized_model()
|
||||
model.dequantize()
|
||||
|
||||
def _check_for_gguf_linear(model):
|
||||
has_children = list(model.children())
|
||||
if not has_children:
|
||||
return
|
||||
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
|
||||
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
|
||||
|
||||
for name, module in model.named_children():
|
||||
_check_for_gguf_linear(module)
|
||||
|
||||
def test_gguf_quantized_layers(self):
|
||||
self._test_quantized_layers({"compute_dtype": torch.bfloat16})

@is_modelopt
@nightly
@require_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptTesterMixin(QuantizationTesterMixin):
    """
    Mixin class for testing NVIDIA ModelOpt quantization on models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Optional class attributes:
    - MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test

    Pytest mark: modelopt
    Use `pytest -m "not modelopt"` to skip these tests
    """

    MODELOPT_CONFIGS = {
        "fp8": {"quant_type": "FP8"},
        "int8": {"quant_type": "INT8"},
        "int4": {"quant_type": "INT4"},
    }

    MODELOPT_EXPECTED_MEMORY_REDUCTIONS = {
        "fp8": 1.5,
        "int8": 1.5,
        "int4": 3.0,
    }

    def _create_quantized_model(self, config_kwargs, **extra_kwargs):
        config = NVIDIAModelOptConfig(**config_kwargs)
        kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
        kwargs["quantization_config"] = config
        kwargs.update(extra_kwargs)
        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"

    @pytest.mark.parametrize("config_name", ["fp8"])
    def test_modelopt_quantization_num_parameters(self, config_name):
        self._test_quantization_num_parameters(self.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
    def test_modelopt_quantization_memory_footprint(self, config_name):
        expected = self.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
        self._test_quantization_memory_footprint(
            self.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
        )

    @pytest.mark.parametrize("config_name", list(MODELOPT_CONFIGS.keys()))
    def test_modelopt_quantization_inference(self, config_name):
        self._test_quantization_inference(self.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"])
    def test_modelopt_quantization_dtype_assignment(self, config_name):
        self._test_quantization_dtype_assignment(self.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"])
    def test_modelopt_quantization_lora_inference(self, config_name):
        self._test_quantization_lora_inference(self.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"])
    def test_modelopt_quantization_serialization(self, config_name):
        self._test_quantization_serialization(self.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"])
    def test_modelopt_quantized_layers(self, config_name):
        self._test_quantized_layers(self.MODELOPT_CONFIGS[config_name])

    def test_modelopt_modules_to_not_convert(self):
        """Test that modules_to_not_convert parameter works correctly."""
        modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
        if modules_to_exclude is None:
            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(self.MODELOPT_CONFIGS["fp8"], modules_to_exclude)

    def test_modelopt_device_map(self):
        """Test that device_map='auto' works correctly with quantization."""
        self._test_quantization_device_map(self.MODELOPT_CONFIGS["fp8"])
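
# Illustrative sketch of the minimal surface a test module provides when reusing ModelOptTesterMixin,
# including the optional `modules_to_not_convert_for_test` hook that enables
# test_modelopt_modules_to_not_convert. The checkpoint id, input shapes, and the "proj_out" module
# name are assumptions for demonstration only; the "Example" prefix keeps pytest from collecting it.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor


class ExampleFluxTransformerModelOptTest(ModelOptTesterMixin):
    model_class = FluxTransformer2DModel
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"  # assumed tiny checkpoint
    pretrained_model_kwargs = {"subfolder": "transformer"}
    # Optional: without this attribute, test_modelopt_modules_to_not_convert is skipped.
    modules_to_not_convert_for_test = ["proj_out"]  # assumed module name to keep un-quantized

    def get_dummy_inputs(self):
        # Placeholder shapes; a real subclass must match its checkpoint's forward signature.
        return {
            "hidden_states": randn_tensor((1, 16, 4)),
            "encoder_hidden_states": randn_tensor((1, 24, 8)),
            "pooled_projections": randn_tensor((1, 8)),
            "img_ids": randn_tensor((16, 3)),
            "txt_ids": randn_tensor((24, 3)),
            "timestep": torch.tensor([1.0]),
        }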
tests/models/testing_utils/single_file.py (new file, 247 lines)
@@ -0,0 +1,247 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile

import torch
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

from ...testing_utils import (
    backend_empty_cache,
    is_single_file,
    nightly,
    require_torch_accelerator,
    torch_device,
)


def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
    """Download a single file checkpoint from the Hub to a temporary directory."""
    path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
    return path


def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
    """Download diffusers config files (excluding weights) from a repository."""
    path = snapshot_download(
        pretrained_model_name_or_path,
        ignore_patterns=[
            "**/*.ckpt",
            "*.ckpt",
            "**/*.bin",
            "*.bin",
            "**/*.pt",
            "*.pt",
            "**/*.safetensors",
            "*.safetensors",
        ],
        allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
        local_dir=tmpdir,
    )
    return path
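
# A minimal sketch of the repo-id/filename split the local_files_only tests below rely on
# `_extract_repo_id_and_weights_name` to perform. The URL comes from the Flux single file test
# added later in this PR; the expected values are an assumption for illustration, not a check
# performed anywhere in this file.
def _example_extract_repo_id_and_weights_name():
    repo_id, weights_name = _extract_repo_id_and_weights_name(
        "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    )
    # Expected (assumed): repo_id == "black-forest-labs/FLUX.1-dev", weights_name == "flux1-dev.safetensors"
    return repo_id, weights_name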

@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
    """
    Mixin class for testing single file loading for models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - ckpt_path: Path or Hub path to the single file checkpoint
    - subfolder: (Optional) Subfolder within the repo
    - torch_dtype: (Optional) torch dtype to use for testing

    Pytest mark: single_file
    Use `pytest -m "not single_file"` to skip these tests
    """

    pretrained_model_name_or_path = None
    ckpt_path = None

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_model_config(self):
        pretrained_kwargs = {}
        single_file_kwargs = {}

        pretrained_kwargs["device"] = torch_device
        single_file_kwargs["device"] = torch_device

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading: "
                f"pretrained={model.config[param_name]}, single_file={param_value}"
            )

    def test_single_file_model_parameters(self):
        pretrained_kwargs = {}
        single_file_kwargs = {}

        pretrained_kwargs["device"] = torch_device
        single_file_kwargs["device"] = torch_device

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading. "
            f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
            f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
        )

        for key in state_dict.keys():
            param = state_dict[key]
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_single_file_loading_local_files_only(self):
        single_file_kwargs = {}

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        with tempfile.TemporaryDirectory() as tmpdir:
            pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)

            model_single_file = self.model_class.from_single_file(
                local_ckpt_path, local_files_only=True, **single_file_kwargs
            )

            assert model_single_file is not None, "Failed to load model with local_files_only=True"

    def test_single_file_loading_with_diffusers_config(self):
        single_file_kwargs = {}

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        # Load with config parameter
        model_single_file = self.model_class.from_single_file(
            self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
        )

        # Load pretrained for comparison
        pretrained_kwargs = {}
        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder
        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)

        # Compare configs
        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"

    def test_single_file_loading_with_diffusers_config_local_files_only(self):
        single_file_kwargs = {}

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        with tempfile.TemporaryDirectory() as tmpdir:
            pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)

            model_single_file = self.model_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
            )

            assert model_single_file is not None, "Failed to load model with config and local_files_only=True"

    def test_single_file_loading_dtype(self):
        for dtype in [torch.float32, torch.float16]:
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue

            model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)

            assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"

            # Cleanup
            del model_single_file
            gc.collect()
            backend_empty_cache(torch_device)

    def test_checkpoint_variant_loading(self):
        if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
            return

        for ckpt_path in self.alternate_ckpt_paths:
            backend_empty_cache(torch_device)

            single_file_kwargs = {}
            if hasattr(self, "torch_dtype") and self.torch_dtype:
                single_file_kwargs["torch_dtype"] = self.torch_dtype

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            assert model is not None, f"Failed to load checkpoint from {ckpt_path}"

            del model
            gc.collect()
            backend_empty_cache(torch_device)
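
# A minimal usage sketch: test_checkpoint_variant_loading above only runs when a subclass sets
# `alternate_ckpt_paths`. The model class and checkpoint URLs mirror the Flux single file test
# added later in this PR; the torch_dtype choice is an assumption, and the "Example" prefix keeps
# pytest from collecting the class.
import torch
from diffusers import FluxTransformer2DModel


class ExampleFluxSingleFileTest(SingleFileTesterMixin):
    model_class = FluxTransformer2DModel
    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
    subfolder = "transformer"
    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
    torch_dtype = torch.bfloat16  # assumed; the mixin treats this attribute as optional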
tests/models/testing_utils/training.py (new file, 224 lines)
@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import pytest
import torch

from diffusers.training_utils import EMAModel

from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device


@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
    """
    Mixin class for testing training functionality on models.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Expected properties to be implemented by subclasses:
    - output_shape: Tuple defining the expected output shape

    Pytest mark: training
    Use `pytest -m "not training"` to skip these tests
    """

    def test_training(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_training_with_ema(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model.parameters())

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model.parameters())

    def test_gradient_checkpointing(self):
        if not self.model_class._supports_gradient_checkpointing:
            pytest.skip("Gradient checkpointing is not supported.")

        init_dict = self.get_init_dict()

        # at init model should have gradient checkpointing disabled
        model = self.model_class(**init_dict)
        assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"

        # check enable works
        model.enable_gradient_checkpointing()
        assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"

        # check disable works
        model.disable_gradient_checkpointing()
        assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"

    def test_gradient_checkpointing_is_applied(self, expected_set=None):
        if not self.model_class._supports_gradient_checkpointing:
            pytest.skip("Gradient checkpointing is not supported.")

        if expected_set is None:
            pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")

        init_dict = self.get_init_dict()

        model_class_copy = copy.copy(self.model_class)
        model = model_class_copy(**init_dict)
        model.enable_gradient_checkpointing()

        modules_with_gc_enabled = {}
        for submodule in model.modules():
            if hasattr(submodule, "gradient_checkpointing"):
                assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
                modules_with_gc_enabled[submodule.__class__.__name__] = True

        assert set(modules_with_gc_enabled.keys()) == expected_set, (
            f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} "
            f"do not match expected set {expected_set}"
        )
        assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"

    def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
        if not self.model_class._supports_gradient_checkpointing:
            pytest.skip("Gradient checkpointing is not supported.")

        if skip is None:
            skip = set()

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        inputs_dict_copy = copy.deepcopy(inputs_dict)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict)
        if isinstance(out, dict):
            out = out.sample if hasattr(out, "sample") else out.to_tuple()[0]

        # run the backwards pass on the model
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model now enabling gradient checkpointing
        torch.manual_seed(0)
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict_copy)
        if isinstance(out_2, dict):
            out_2 = out_2.sample if hasattr(out_2, "sample") else out_2.to_tuple()[0]

        # run the backwards pass on the model
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        assert (
            loss - loss_2
        ).abs() < loss_tolerance, f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"

        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())

        for name, param in named_params.items():
            if "post_quant_conv" in name:
                continue
            if name in skip:
                continue
            if param.grad is None:
                continue

            assert torch_all_close(
                param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol
            ), f"Gradient mismatch for {name}"

    def test_mixed_precision_training(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()

        # Test with float16
        if torch.device(torch_device).type != "cpu":
            with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
                output = model(**inputs_dict)

                if isinstance(output, dict):
                    output = output.to_tuple()[0]

                noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
                loss = torch.nn.functional.mse_loss(output, noise)

            loss.backward()

        # Test with bfloat16
        if torch.device(torch_device).type != "cpu":
            model.zero_grad()
            with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
                output = model(**inputs_dict)

                if isinstance(output, dict):
                    output = output.to_tuple()[0]

                noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
                loss = torch.nn.functional.mse_loss(output, noise)

            loss.backward()
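
# A minimal sketch of how a subclass can drive the two parameterized hooks above. The config
# details, the expected module name, and the tolerances are assumptions for illustration; the
# "Example" prefix keeps pytest from collecting the class.
class ExampleTrainingTest(TrainingTesterMixin):
    # A real subclass would also supply model_class, get_init_dict(), get_dummy_inputs(), and
    # output_shape, e.g. via a config base class like the Flux one added later in this PR.

    def test_gradient_checkpointing_is_applied(self):
        # expected_set names the module classes that should report gradient_checkpointing=True.
        super().test_gradient_checkpointing_is_applied(expected_set={"FluxTransformer2DModel"})

    def test_gradient_checkpointing_equivalence(self):
        # Tolerances can be loosened per model; the values here are placeholders.
        super().test_gradient_checkpointing_equivalence(loss_tolerance=1e-4, param_grad_tol=5e-4)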
tests/models/transformers/test_models_transformer_flux_.py (new file, 316 lines)
@@ -0,0 +1,316 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    BitsAndBytesTesterMixin,
    GGUFTesterMixin,
    IPAdapterTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelOptTesterMixin,
    ModelTesterMixin,
    QuantoTesterMixin,
    SingleFileTesterMixin,
    TorchAoTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)


enable_full_determinism()


class FluxTransformerTesterConfig:
    model_class = FluxTransformer2DModel
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    pretrained_model_kwargs = {"subfolder": "transformer"}

    def get_init_dict(self):
        """Return Flux model initialization arguments."""
        return {
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": 1,
            "num_single_layers": 1,
            "attention_head_dim": 16,
            "num_attention_heads": 2,
            "joint_attention_dim": 32,
            "pooled_projection_dim": 32,
            "axes_dims_rope": [4, 4, 8],
        }

    def get_dummy_inputs(self):
        batch_size = 1
        height = width = 4
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 24
        embedding_dim = 8

        return {
            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
            "img_ids": randn_tensor((height * width, num_image_channels)),
            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }

    @property
    def input_shape(self):
        return (16, 4)

    @property
    def output_shape(self):
        return (16, 4)


class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
    def test_deprecated_inputs_img_txt_ids_3d(self):
        """Test that deprecated 3D img_ids and txt_ids still work."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output_1 = model(**inputs_dict).to_tuple()[0]

        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)

        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"

        inputs_dict["txt_ids"] = text_ids_3d
        inputs_dict["img_ids"] = image_ids_3d

        with torch.no_grad():
            output_2 = model(**inputs_dict).to_tuple()[0]

        assert output_1.shape == output_2.shape
        assert torch.allclose(output_1, output_2, atol=1e-5), (
            "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
            "is not equal to the output with 2d inputs"
        )


class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Flux Transformer."""

    pass


class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
    """Training tests for Flux Transformer."""

    pass


class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Flux Transformer."""

    pass


class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
    """IP Adapter tests for Flux Transformer."""

    def create_ip_adapter_state_dict(self, model):
        from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

        ip_cross_attn_state_dict = {}
        key_id = 0

        for name in model.attn_processors.keys():
            if name.startswith("single_transformer_blocks"):
                continue

            joint_attention_dim = model.config["joint_attention_dim"]
            hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
            sd = FluxIPAdapterAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
            ).state_dict()
            ip_cross_attn_state_dict.update(
                {
                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
                    f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
                    f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
                }
            )

            key_id += 1

        image_projection = ImageProjection(
            cross_attention_dim=model.config["joint_attention_dim"],
            image_embed_dim=(
                model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
            ),
            num_image_text_embeds=4,
        )

        ip_image_projection_state_dict = {}
        sd = image_projection.state_dict()
        ip_image_projection_state_dict.update(
            {
                "proj.weight": sd["image_embeds.weight"],
                "proj.bias": sd["image_embeds.bias"],
                "norm.weight": sd["norm.weight"],
                "norm.bias": sd["norm.bias"],
            }
        )

        del sd
        ip_state_dict = {}
        ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
        return ip_state_dict


class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
    """LoRA adapter tests for Flux Transformer."""

    pass


class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    """LoRA hot-swapping tests for Flux Transformer."""

    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height=4, width=4):
        """Override to support dynamic height/width for LoRA hotswap tests."""
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 24
        embedding_dim = 8

        return {
            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
            "img_ids": randn_tensor((height * width, num_image_channels)),
            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }


class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height=4, width=4):
        """Override to support dynamic height/width for compilation tests."""
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        sequence_length = 24
        embedding_dim = 8

        return {
            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
            "img_ids": randn_tensor((height * width, num_image_channels)),
            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }


class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
    subfolder = "transformer"


class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
    def get_dummy_inputs(self):
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
    def get_dummy_inputs(self):
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
    def get_dummy_inputs(self):
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
    gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"

    def get_dummy_inputs(self):
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
    def get_dummy_inputs(self):
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }
@@ -98,9 +98,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
        output_native = linear.forward_native(x)
        output_cuda = linear.forward_cuda(x)

        assert torch.allclose(output_native, output_cuda, 1e-2), (
            f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
        )
        assert torch.allclose(
            output_native, output_cuda, 1e-2
        ), f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"


@nightly

@@ -241,7 +241,6 @@ def parse_flag_from_env(key, default=False):

_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)


def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -282,12 +281,128 @@ def nightly(test_case):

def is_torch_compile(test_case):
    """
    Decorator marking a test that runs compile tests in the diffusers CI.

    Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.

    Decorator marking a test as a torch.compile test. These tests can be filtered using:
    pytest -m "not compile" to skip
    pytest -m compile to run only these tests
    """
    return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
    return pytest.mark.compile(test_case)


def is_single_file(test_case):
    """
    Decorator marking a test as a single file loading test. These tests can be filtered using:
    pytest -m "not single_file" to skip
    pytest -m single_file to run only these tests
    """
    return pytest.mark.single_file(test_case)


def is_lora(test_case):
    """
    Decorator marking a test as a LoRA test. These tests can be filtered using:
    pytest -m "not lora" to skip
    pytest -m lora to run only these tests
    """
    return pytest.mark.lora(test_case)


def is_ip_adapter(test_case):
    """
    Decorator marking a test as an IP Adapter test. These tests can be filtered using:
    pytest -m "not ip_adapter" to skip
    pytest -m ip_adapter to run only these tests
    """
    return pytest.mark.ip_adapter(test_case)


def is_training(test_case):
    """
    Decorator marking a test as a training test. These tests can be filtered using:
    pytest -m "not training" to skip
    pytest -m training to run only these tests
    """
    return pytest.mark.training(test_case)


def is_attention(test_case):
    """
    Decorator marking a test as an attention test. These tests can be filtered using:
    pytest -m "not attention" to skip
    pytest -m attention to run only these tests
    """
    return pytest.mark.attention(test_case)


def is_memory(test_case):
    """
    Decorator marking a test as a memory optimization test. These tests can be filtered using:
    pytest -m "not memory" to skip
    pytest -m memory to run only these tests
    """
    return pytest.mark.memory(test_case)


def is_cpu_offload(test_case):
    """
    Decorator marking a test as a CPU offload test. These tests can be filtered using:
    pytest -m "not cpu_offload" to skip
    pytest -m cpu_offload to run only these tests
    """
    return pytest.mark.cpu_offload(test_case)


def is_group_offload(test_case):
    """
    Decorator marking a test as a group offload test. These tests can be filtered using:
    pytest -m "not group_offload" to skip
    pytest -m group_offload to run only these tests
    """
    return pytest.mark.group_offload(test_case)


def is_bitsandbytes(test_case):
    """
    Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
    pytest -m "not bitsandbytes" to skip
    pytest -m bitsandbytes to run only these tests
    """
    return pytest.mark.bitsandbytes(test_case)


def is_quanto(test_case):
    """
    Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
    pytest -m "not quanto" to skip
    pytest -m quanto to run only these tests
    """
    return pytest.mark.quanto(test_case)


def is_torchao(test_case):
    """
    Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
    pytest -m "not torchao" to skip
    pytest -m torchao to run only these tests
    """
    return pytest.mark.torchao(test_case)


def is_gguf(test_case):
    """
    Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
    pytest -m "not gguf" to skip
    pytest -m gguf to run only these tests
    """
    return pytest.mark.gguf(test_case)


def is_modelopt(test_case):
    """
    Decorator marking a test as an NVIDIA ModelOpt quantization test. These tests can be filtered using:
    pytest -m "not modelopt" to skip
    pytest -m modelopt to run only these tests
    """
    return pytest.mark.modelopt(test_case)


def require_torch(test_case):