Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-19 10:54:34 +08:00)

Compare commits: pipeline-s ... model-test (15 commits)

| SHA1 |
|---|
| c70de2bc37 |
| e82001e40d |
| d9b73ffd51 |
| dcd6026d17 |
| eae7543712 |
| d08e0bb545 |
| c366b5a817 |
| 0fdd9d3a60 |
| 489480b02a |
| fe451c367b |
| 0f1a4e0c14 |
| aa29af8f0e |
| bffa3a9754 |
| 1c558712e8 |
| 1f026ad14e |
@@ -32,6 +32,21 @@ warnings.simplefilter(action="ignore", category=FutureWarning)

def pytest_configure(config):
    config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
    config.addinivalue_line("markers", "training: marks tests for training functionality")
    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
    config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")


def pytest_addoption(parser):
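With these markers registered, whole functional areas can be selected or skipped from the command line, e.g. `pytest tests/models -m attention` or `pytest -m "not compile"`. A minimal sketch of how a marker ends up on a test; the module and test body are hypothetical, only `@pytest.mark.lora` itself is real pytest API:

import pytest


@pytest.mark.lora
def test_lora_adapter_roundtrip():
    # Collected by `pytest -m lora`; deselected by `pytest -m "not lora"`.
    ...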
tests/models/testing_utils/__init__.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from .attention import AttentionTesterMixin
from .cache import (
    CacheTesterMixin,
    FasterCacheConfigMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheConfigMixin,
    FirstBlockCacheTesterMixin,
    PyramidAttentionBroadcastConfigMixin,
    PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
    BitsAndBytesCompileTesterMixin,
    BitsAndBytesConfigMixin,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFConfigMixin,
    GGUFTesterMixin,
    ModelOptCompileTesterMixin,
    ModelOptConfigMixin,
    ModelOptTesterMixin,
    QuantizationCompileTesterMixin,
    QuantizationTesterMixin,
    QuantoCompileTesterMixin,
    QuantoConfigMixin,
    QuantoTesterMixin,
    TorchAoCompileTesterMixin,
    TorchAoConfigMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
    "AttentionTesterMixin",
    "BaseModelTesterConfig",
    "BitsAndBytesCompileTesterMixin",
    "BitsAndBytesConfigMixin",
    "BitsAndBytesTesterMixin",
    "CacheTesterMixin",
    "ContextParallelTesterMixin",
    "CPUOffloadTesterMixin",
    "FasterCacheConfigMixin",
    "FasterCacheTesterMixin",
    "FirstBlockCacheConfigMixin",
    "FirstBlockCacheTesterMixin",
    "GGUFCompileTesterMixin",
    "GGUFConfigMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraHotSwappingForModelTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptCompileTesterMixin",
    "ModelOptConfigMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "PyramidAttentionBroadcastConfigMixin",
    "PyramidAttentionBroadcastTesterMixin",
    "QuantizationCompileTesterMixin",
    "QuantizationTesterMixin",
    "QuantoCompileTesterMixin",
    "QuantoConfigMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoCompileTesterMixin",
    "TorchAoConfigMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
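A minimal sketch of how these exports are meant to be composed, mirroring the BaseModelTesterConfig docstring example in common.py below; `MyModel` and its constructor/forward arguments are placeholders (not a real diffusers class), and the absolute import path is an assumption:

import torch

from tests.models.testing_utils import AttentionTesterMixin, BaseModelTesterConfig, ModelTesterMixin


class MyModelTestConfig(BaseModelTesterConfig):
    model_class = MyModel  # placeholder model class

    def get_init_dict(self):
        return {"in_channels": 3, "out_channels": 3, "sample_size": 32}

    def get_dummy_inputs(self):
        return {"sample": torch.randn(1, 3, 32, 32)}


class TestMyModel(MyModelTestConfig, ModelTesterMixin, AttentionTesterMixin):
    pass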
tests/models/testing_utils/attention.py (new file, 185 lines)
@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import (
    assert_tensors_close,
    is_attention,
    torch_device,
)


@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether model uses custom attention processors (default: False)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
    """

    base_precision = 1e-3

    def test_fuse_unfuse_qkv_projections(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Get output before fusion
        with torch.no_grad():
            output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs match
            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should not change after fusing projections",
            )

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs still match
            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should match original after unfusing projections",
            )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
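Outside the test harness, the processor round-trip exercised above looks roughly like this; `model` stands for any diffusers model built on AttentionModuleMixin (a placeholder, not a specific class):

from diffusers.models.attention_processor import AttnProcessor

processors = model.attn_processors  # dict mapping module paths to processor instances
model.set_attn_processor({name: AttnProcessor() for name in processors})  # per-module override
model.set_attn_processor(AttnProcessor())  # or a single processor applied to every attention module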
tests/models/testing_utils/cache.py (new file, 536 lines)
@@ -0,0 +1,536 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
|
||||
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from diffusers.models.cache_utils import CacheMixin
|
||||
|
||||
from ...testing_utils import backend_empty_cache, is_cache, torch_device
|
||||
|
||||
|
||||
def require_cache_mixin(func):
|
||||
"""Decorator to skip tests if model doesn't use CacheMixin."""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not issubclass(self.model_class, CacheMixin):
|
||||
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CacheTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for cache testing.
|
||||
|
||||
Cache-specific mixins should:
|
||||
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
|
||||
2. Inherit from this mixin
|
||||
3. Define the cache config to use for tests
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_cache_config(self):
|
||||
"""
|
||||
Get the cache config for testing.
|
||||
Should be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_cache_config")
|
||||
|
||||
def _get_hook_names(self):
|
||||
"""
|
||||
Get the hook names to check for this cache type.
|
||||
Should be implemented by subclasses.
|
||||
Returns a list of hook name strings.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_hook_names")
|
||||
|
||||
def _test_cache_enable_disable_state(self):
|
||||
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Initially cache should not be enabled
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled initially."
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
# Enable cache
|
||||
model.enable_cache(config)
|
||||
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
|
||||
|
||||
# Disable cache
|
||||
model.disable_cache()
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
|
||||
|
||||
def _test_cache_double_enable_raises_error(self):
|
||||
"""Test that enabling cache twice raises an error."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Trying to enable again should raise ValueError
|
||||
with pytest.raises(ValueError, match="Caching has already been enabled"):
|
||||
model.enable_cache(config)
|
||||
|
||||
# Cleanup
|
||||
model.disable_cache()
|
||||
|
||||
def _test_cache_hooks_registered(self):
|
||||
"""Test that cache hooks are properly registered and removed."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
hook_names = self._get_hook_names()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Check that at least one hook was registered
|
||||
hook_count = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count += 1
|
||||
|
||||
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
|
||||
|
||||
# Disable and verify hooks are removed
|
||||
model.disable_cache()
|
||||
|
||||
hook_count_after = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count_after += 1
|
||||
|
||||
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with cache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass (vary hidden_states to simulate denoising)
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if "hidden_states" in inputs_dict_step2:
|
||||
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
|
||||
|
||||
# Second pass uses cached attention with different hidden_states (produces approximated output)
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
def _test_cache_context_manager(self):
|
||||
"""Test the cache_context context manager."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Test cache_context works without error
|
||||
with model.cache_context("test_context"):
|
||||
pass
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the cache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Run forward to populate cache state
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Reset should not raise any errors
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastConfigMixin:
|
||||
"""
|
||||
Base mixin providing PyramidAttentionBroadcast cache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default PAB config - can be overridden by subclasses
|
||||
PAB_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
}
|
||||
|
||||
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
|
||||
_current_timestep = 500
|
||||
|
||||
def _get_cache_config(self):
|
||||
config_kwargs = self.PAB_CONFIG.copy()
|
||||
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
|
||||
return PyramidAttentionBroadcastConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing PyramidAttentionBroadcast caching on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FirstBlockCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FBC config - can be overridden by subclasses
|
||||
# Higher threshold makes FBC more aggressive about caching (skips more often)
|
||||
FBC_CONFIG = {
|
||||
"threshold": 1.0,
|
||||
}
|
||||
|
||||
def _get_cache_config(self):
|
||||
return FirstBlockCacheConfig(**self.FBC_CONFIG)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FirstBlockCache on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# FBC requires cache_context to be set for inference
|
||||
with model.cache_context("fbc_test"):
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass (small perturbation keeps residuals similar)
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if "hidden_states" in inputs_dict_step2:
|
||||
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01
|
||||
|
||||
# Second pass - FBC should skip remaining blocks and use cached residuals
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# FBC requires cache_context to be set for inference
|
||||
with model.cache_context("fbc_test"):
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Reset should not raise any errors
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
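As the FirstBlockCache tests above note, FBC only takes effect inside an active cache context. A rough usage sketch, assuming `model` is any CacheMixin model; the threshold value and the forward call signature are illustrative placeholders:

from diffusers.hooks import FirstBlockCacheConfig

model.enable_cache(FirstBlockCacheConfig(threshold=0.2))  # example threshold
with model.cache_context("inference"):
    output = model(hidden_states, timestep)  # placeholder call signature
model.disable_cache()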
@is_cache
|
||||
class FasterCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FasterCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FasterCache config - can be overridden by subclasses
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
}
|
||||
|
||||
# Store timestep for callback - use a list so it can be mutated during test
|
||||
# Starts outside skip range so first pass computes; changed to inside range for subsequent passes
|
||||
_current_timestep = [1000]
|
||||
|
||||
def _get_cache_config(self):
|
||||
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
|
||||
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
|
||||
return FasterCacheConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FasterCache on models.
|
||||
|
||||
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
|
||||
and timestep management. Inference tests are skipped at model level - FasterCache should
|
||||
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FasterCache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass with timestep outside skip range - computes and populates cache
|
||||
self._current_timestep[0] = 1000
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Move timestep inside skip range so subsequent passes use cache
|
||||
self._current_timestep[0] = 500
|
||||
|
||||
# Create modified inputs for second pass
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if "hidden_states" in inputs_dict_step2:
|
||||
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
|
||||
|
||||
# Second pass uses cached attention with different hidden_states
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FasterCache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass with timestep outside skip range
|
||||
self._current_timestep[0] = 1000
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Reset should not raise any errors
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
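For reference, the enable/disable cycle these cache testers drive corresponds to the following user-facing flow; `transformer` is a placeholder for any CacheMixin model and the timestep callback wiring is only illustrative:

from diffusers.hooks import PyramidAttentionBroadcastConfig

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    current_timestep_callback=lambda: current_timestep,  # supplied by the denoising loop
)
transformer.enable_cache(config)
# ... run the denoising loop ...
transformer.disable_cache()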
tests/models/testing_utils/common.py (new file, 649 lines)
@@ -0,0 +1,649 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
|
||||
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||
|
||||
from ...testing_utils import assert_tensors_close, torch_device
|
||||
|
||||
|
||||
def named_persistent_module_tensors(
|
||||
module: nn.Module,
|
||||
recurse: bool = False,
|
||||
):
|
||||
"""
|
||||
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module we want the tensors on.
|
||||
recurse (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||
"""
|
||||
yield from module.named_parameters(recurse=recurse)
|
||||
|
||||
for named_buffer in module.named_buffers(recurse=recurse):
|
||||
name, _ = named_buffer
|
||||
# Get parent by splitting on dots and traversing the model
|
||||
parent = module
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
for part in parent_name.split("."):
|
||||
parent = getattr(parent, part)
|
||||
name = name.split(".")[-1]
|
||||
if name not in parent._non_persistent_buffers_set:
|
||||
yield named_buffer
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: str | torch.device | None = None,
|
||||
special_dtypes: dict[str, str | torch.device] | None = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
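The helper above keys the returned sizes by dotted module path, with the empty string holding the whole-model total; the sharding tests below rely on that convention. A minimal sketch (`model` is a placeholder for any nn.Module):

sizes = compute_module_persistent_sizes(model)
total_bytes = sizes[""]  # parameters + persistent buffers, in bytes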
def calculate_expected_num_shards(index_map_path):
|
||||
"""
|
||||
Calculate expected number of shards from index file.
|
||||
|
||||
Args:
|
||||
index_map_path: Path to the sharded checkpoint index file
|
||||
|
||||
Returns:
|
||||
int: Expected number of shards
|
||||
"""
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
|
||||
|
||||
|
||||
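A worked instance of the filename convention parsed above, using the example name from the comment:

weight_loc = "diffusion_pytorch_model-00001-of-00002.safetensors"
num_shards = int(weight_loc.split("-")[-1].split(".")[0])  # "00002.safetensors" -> "00002" -> 2
assert num_shards == 2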
def check_device_map_is_respected(model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
while len(param_name) > 0 and param_name not in device_map:
|
||||
param_name = ".".join(param_name.split(".")[:-1])
|
||||
if param_name not in device_map:
|
||||
raise ValueError(f"device map is incomplete, it does not contain any device for {param_name}.")
|
||||
|
||||
param_device = device_map[param_name]
|
||||
if param_device in ["cpu", "disk"]:
|
||||
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||
else:
|
||||
assert param.device == torch.device(param_device), (
|
||||
f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||
)
|
||||
|
||||
|
||||
class BaseModelTesterConfig:
|
||||
"""
|
||||
Base class defining the configuration interface for model testing.
|
||||
|
||||
This class defines the contract that all model test classes must implement.
|
||||
It provides a consistent interface for accessing model configuration, initialization
|
||||
parameters, and test inputs across all testing mixins.
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties (can be overridden, have sensible defaults):
|
||||
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
|
||||
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
|
||||
- output_shape: Expected output shape for output validation tests (default: None)
|
||||
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
|
||||
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
|
||||
|
||||
Required methods (must be implemented by subclasses):
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Example usage:
|
||||
class MyModelTestConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return MyModel
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "org/my-model"
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 3, 32, 32)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {"in_channels": 3, "out_channels": 3}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
|
||||
|
||||
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
|
||||
pass
|
||||
"""
|
||||
|
||||
# ==================== Required Properties ====================
|
||||
|
||||
@property
|
||||
def model_class(self) -> Type[nn.Module]:
|
||||
"""The model class to test. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `model_class` property.")
|
||||
|
||||
# ==================== Optional Properties ====================
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self) -> Optional[str]:
|
||||
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self) -> Dict[str, Any]:
|
||||
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Optional[tuple]:
|
||||
"""Expected output shape for output validation tests."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
"""Percentages for model parallelism tests."""
|
||||
return [0.5, 0.7]
|
||||
|
||||
# ==================== Required Methods ====================
|
||||
|
||||
def get_init_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns dict of arguments to initialize the model.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Initialization arguments for the model constructor.
|
||||
|
||||
Example:
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"sample_size": 32,
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
|
||||
|
||||
def get_dummy_inputs(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns dict of inputs to pass to the model forward pass.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Input tensors/values for model.forward().
|
||||
|
||||
Example:
|
||||
return {
|
||||
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
|
||||
"timestep": torch.tensor([1], device=torch_device),
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
"""
|
||||
Base mixin class for model testing with common test methods.
|
||||
|
||||
This mixin expects the test class to also inherit from BaseModelTesterConfig
|
||||
(or implement its interface) which provides:
|
||||
- model_class: The model class to test
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Example:
|
||||
class MyModelTestConfig(BaseModelTesterConfig):
|
||||
model_class = MyModel
|
||||
def get_init_dict(self): ...
|
||||
def get_dummy_inputs(self): ...
|
||||
|
||||
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
|
||||
pass
|
||||
"""
|
||||
|
||||
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path)
|
||||
new_model.to(torch_device)
|
||||
|
||||
# check if all parameters shape are the same
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert param_1.shape == param_2.shape, (
|
||||
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
|
||||
|
||||
# non-variant cannot be loaded
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmp_path)
|
||||
|
||||
# make sure that error message states what keys are missing
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
|
||||
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
|
||||
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self, atol=1e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
# Filter out NaN values before comparison
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||
first_filtered = first_flat[mask]
|
||||
second_filtered = second_flat[mask]
|
||||
|
||||
assert_tensors_close(
|
||||
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
|
||||
)
|
||||
|
||||
def test_output(self, expected_output_shape=None):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
expected_output_shape = expected_output_shape or self.output_shape
assert output.shape == expected_output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
|
||||
# Track progress in https://github.com/pytorch/pytorch/issues/77764
|
||||
device = t.device
|
||||
if device.type == "mps":
|
||||
t = t.to("cpu")
|
||||
t[t != t] = 0
|
||||
return t.to(device)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (list, tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
assert_tensors_close(
|
||||
set_nan_tensor_to_zero(tuple_object),
|
||||
set_nan_tensor_to_zero(dict_object),
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Tuple and dict output are not equal",
|
||||
)
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_getattr_is_correct(self, caplog):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
# save some things to test
|
||||
model.dummy_attribute = 5
|
||||
model.register_to_config(test_attribute=5)
|
||||
|
||||
logger_name = "diffusers.models.modeling_utils"
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
caplog.clear()
|
||||
assert hasattr(model, "dummy_attribute")
|
||||
assert getattr(model, "dummy_attribute") == 5
|
||||
assert model.dummy_attribute == 5
|
||||
|
||||
# no warning should be thrown
|
||||
assert caplog.text == ""
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
caplog.clear()
|
||||
assert hasattr(model, "save_pretrained")
|
||||
fn = model.save_pretrained
|
||||
fn_1 = getattr(model, "save_pretrained")
|
||||
|
||||
assert fn == fn_1
|
||||
|
||||
# no warning should be thrown
|
||||
assert caplog.text == ""
|
||||
|
||||
# warning should be thrown for config attributes accessed directly
|
||||
with pytest.warns(FutureWarning):
|
||||
assert model.test_attribute == 5
|
||||
|
||||
with pytest.warns(FutureWarning):
|
||||
assert getattr(model, "test_attribute") == 5
|
||||
|
||||
with pytest.raises(AttributeError) as error:
|
||||
model.does_not_exist
|
||||
|
||||
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be used with an accelerator",
|
||||
)
|
||||
def test_keep_in_fp32_modules(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
if fp32_modules is None or len(fp32_modules) == 0:
|
||||
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
|
||||
|
||||
# Test with float16
|
||||
model.to(torch_device)
|
||||
model.to(torch.float16)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
|
||||
else:
|
||||
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be use for inference with an accelerator",
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
model.to(dtype).save_pretrained(tmp_path)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == dtype
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints(self, tmp_path):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_path).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
|
||||
)
|
||||
|
||||
@require_accelerator
|
||||
def test_sharded_checkpoints_with_variant(self, tmp_path):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
variant = "fp16"
|
||||
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
|
||||
f"Variant index file {index_filename} should exist"
|
||||
)
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
|
||||
)
|
||||
|
||||
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
|
||||
from diffusers.utils import constants
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
# Save original values to restore after test
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "DEFAULT_HF_PARALLEL_LOADING_WORKERS", None)
|
||||
|
||||
try:
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
# Load without parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_sequential = model_sequential.to(torch_device)
|
||||
|
||||
# Load with parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
def test_model_parallelism(self, tmp_path):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
|
||||
model.cpu().save_pretrained(tmp_path)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism"
|
||||
)
|
||||
160
tests/models/testing_utils/compile.py
Normal file
160
tests/models/testing_utils/compile.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_compile,
|
||||
require_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_torch_compile
|
||||
@require_accelerator
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class TorchCompileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing torch.compile functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: compile
|
||||
Use `pytest -m "not compile"` to skip these tests
|
||||
"""
|
||||
|
||||
different_shapes_for_compilation = None
|
||||
|
||||
def setup_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
if self.model_class._repeated_blocks is None:
|
||||
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
torch.no_grad(),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_with_group_offloading(self):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch_device,
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_works_with_aot(self, tmp_path):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
138
tests/models/testing_utils/ip_adapter.py
Normal file
138
tests/models/testing_utils/ip_adapter.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import is_ip_adapter, torch_device
|
||||
|
||||
|
||||
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
|
||||
"""
|
||||
Check if IP Adapter processors are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if IP Adapter is correctly set, False otherwise
|
||||
"""
|
||||
for module in model.attn_processors.values():
|
||||
if isinstance(module, processor_cls):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_ip_adapter
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
Mixin class for testing IP Adapter functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: ip_adapter
|
||||
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||
"""
|
||||
|
||||
ip_adapter_processor_cls = None
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
|
||||
|
||||
def test_load_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||
"IP Adapter processors not set correctly"
|
||||
)
|
||||
|
||||
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with IP Adapter enabled"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
|
||||
)
|
||||
def test_ip_adapter_scale(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||
|
||||
# Test scale = 0.0 (no effect)
|
||||
model.set_ip_adapter_scale(0.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Test scale = 1.0 (full effect)
|
||||
model.set_ip_adapter_scale(1.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should differ with different scales
|
||||
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with different IP Adapter scales"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
|
||||
)
|
||||
def test_unload_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Save original processors
|
||||
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
|
||||
|
||||
# Unload IP adapter
|
||||
model.unload_ip_adapter()
|
||||
|
||||
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||
"IP Adapter should be unloaded"
|
||||
)
|
||||
|
||||
# Verify processors are restored
|
||||
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||
548
tests/models/testing_utils/lora.py
Normal file
548
tests/models/testing_utils/lora.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusers.utils.import_utils import is_peft_available
|
||||
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_lora,
|
||||
is_torch_compile,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if LoRA layers are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if LoRA is correctly set, False otherwise
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_lora
|
||||
@require_peft_backend
|
||||
class LoraTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA/PEFT functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: lora
|
||||
Use `pytest -m "not lora"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with LoRA enabled"
|
||||
)
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
|
||||
"LoRA weights file not created"
|
||||
)
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}")
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with LoRA enabled"
|
||||
)
|
||||
assert_tensors_close(
|
||||
outputs_with_lora,
|
||||
outputs_with_lora_2,
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
msg="Outputs should match before and after save/load",
|
||||
)
|
||||
|
||||
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
|
||||
|
||||
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
# Perturb the metadata in the state dict
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||
|
||||
|
||||
@is_lora
|
||||
@is_torch_compile
|
||||
@require_peft_backend
|
||||
@require_peft_version_greater("0.14.0")
|
||||
@require_torch_version_greater("2.7.1")
|
||||
@require_torch_accelerator
|
||||
class LoraHotSwappingForModelTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA hot swapping functionality on models.
|
||||
|
||||
Test that hotswapping does not result in recompilation on the model directly.
|
||||
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
|
||||
and is extensively tested there. The goal of this test is specifically to ensure that
|
||||
hotswapping with diffusers does not require recompilation.
|
||||
|
||||
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
|
||||
for the analogous PEFT test.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest marks: lora, torch_compile
|
||||
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
|
||||
"""
|
||||
|
||||
different_shapes_for_compilation = None
|
||||
|
||||
def setup_method(self):
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def teardown_method(self):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=target_modules,
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return lora_config
|
||||
|
||||
def _get_linear_module_name_other_than_attn(self, model):
|
||||
linear_names = [
|
||||
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
|
||||
]
|
||||
return linear_names[0]
|
||||
|
||||
def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None):
|
||||
"""
|
||||
Check that hotswapping works on a model.
|
||||
|
||||
Steps:
|
||||
- create 2 LoRA adapters and save them
|
||||
- load the first adapter
|
||||
- hotswap the second adapter
|
||||
- check that the outputs are correct
|
||||
- optionally compile the model
|
||||
- optionally check if recompilations happen on different shapes
|
||||
|
||||
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||
fine.
|
||||
"""
|
||||
different_shapes = self.different_shapes_for_compilation
|
||||
# create 2 adapters with different ranks and alphas
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
alpha0, alpha1 = rank0, rank1
|
||||
max_rank = max([rank0, rank1])
|
||||
if target_modules1 is None:
|
||||
target_modules1 = target_modules0[:]
|
||||
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
|
||||
|
||||
model.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
with torch.inference_mode():
|
||||
torch.manual_seed(0)
|
||||
output0_before = model(**inputs_dict)["sample"]
|
||||
|
||||
model.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
model.set_adapter("adapter1")
|
||||
with torch.inference_mode():
|
||||
torch.manual_seed(0)
|
||||
output1_before = model(**inputs_dict)["sample"]
|
||||
|
||||
# sanity checks:
|
||||
tol = 5e-3
|
||||
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
|
||||
assert not (output0_before == 0).all()
|
||||
assert not (output1_before == 0).all()
|
||||
|
||||
# save the adapter checkpoints
|
||||
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||
del model
|
||||
|
||||
# load the first adapter
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if do_compile or (rank0 != rank1):
|
||||
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||
model.enable_lora_hotswap(target_rank=max_rank)
|
||||
|
||||
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
|
||||
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
|
||||
if do_compile:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
|
||||
|
||||
with torch.inference_mode():
|
||||
# additionally check if dynamic compilation works.
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
assert_tensors_close(
|
||||
output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0"
|
||||
)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
|
||||
# we need to call forward to potentially trigger recompilation
|
||||
with torch.inference_mode():
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
assert_tensors_close(
|
||||
output1_before,
|
||||
output1_after,
|
||||
atol=tol,
|
||||
rtol=tol,
|
||||
msg="Output mismatch after hotswapping to adapter1",
|
||||
)
|
||||
|
||||
# check error when not passing valid adapter name
|
||||
name = "does-not-exist"
|
||||
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_model(self, tmp_path, rank0, rank1):
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
pytest.skip("Test only applies to UNet.")
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
pytest.skip("Test only applies to UNet.")
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
|
||||
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
|
||||
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
|
||||
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
|
||||
# block.
|
||||
target_modules = ["to_q"]
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
target_modules.append(self._get_linear_module_name_other_than_attn(model))
|
||||
del model
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
with pytest.raises(RuntimeError, match=msg):
|
||||
model.enable_lora_hotswap(target_rank=32)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
|
||||
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
|
||||
different_shapes_for_compilation = self.different_shapes_for_compilation
|
||||
if different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
|
||||
# variable to represent input sizes that are the same. For more details,
|
||||
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=rank0,
|
||||
rank1=rank1,
|
||||
target_modules0=target_modules,
|
||||
)
|
||||
493
tests/models/testing_utils/memory.py
Normal file
493
tests/models/testing_utils/memory.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import compute_module_sizes
|
||||
|
||||
from diffusers.utils.testing_utils import _check_safetensors_serialization
|
||||
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
backend_synchronize,
|
||||
is_cpu_offload,
|
||||
is_group_offload,
|
||||
is_memory,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import check_device_map_is_respected
|
||||
|
||||
|
||||
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
|
||||
"""Helper to cast tensor inputs from one dtype to another."""
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
|
||||
inputs_dict[key] = value.to(to_dtype)
|
||||
return inputs_dict
|
||||
|
||||
|
||||
def require_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_group_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support group offloading.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@is_cpu_offload
|
||||
class CPUOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing CPU offloading functionality.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
- model_split_percents: List of percentages for splitting model across devices
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cpu_offload
|
||||
Use `pytest -m "not cpu_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
@require_offload_support
|
||||
def test_cpu_offload(self, tmp_path):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_without_safetensors(self, tmp_path):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_with_safetensors(self, tmp_path):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0],
|
||||
new_output[0],
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with disk offloading (safetensors)",
|
||||
)
|
||||
|
||||
|
||||
@is_group_offload
|
||||
class GroupOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing group offloading functionality.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: group_offload
|
||||
Use `pytest -m "not group_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_group_offload_support
|
||||
def test_group_offloading(self, record_stream=False):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
assert all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
), "Group offloading hook should be set"
|
||||
model.eval()
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading1 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||
output_with_group_offloading2 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading3 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
|
||||
)
|
||||
output_with_group_offloading4 = run_forward(model)
|
||||
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading1,
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with block-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading2,
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with non-blocking block-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading3,
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with leaf-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading4,
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with leaf-level offloading with stream",
|
||||
)
|
||||
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||
)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5):
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
return "generator" in params
|
||||
|
||||
def _run_forward(model, inputs_dict):
|
||||
accepts_generator = _has_generator_arg(model)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = _run_forward(model, inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||
tmpdir = str(tmp_path)
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading,
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
msg="Output should match with disk-based group offloading",
|
||||
)
|
||||
|
||||
|
||||
class LayerwiseCastingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing layerwise dtype casting for memory optimization.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_memory(self):
|
||||
MB_TOLERANCE = 0.2
|
||||
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||
|
||||
def reset_memory_stats():
|
||||
gc.collect()
|
||||
backend_synchronize(torch_device)
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
def get_memory_usage(storage_dtype, compute_dtype):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
|
||||
reset_memory_stats()
|
||||
model(**inputs_dict)
|
||||
model_memory_footprint = model.get_memory_footprint()
|
||||
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
|
||||
|
||||
return model_memory_footprint, peak_inference_memory_allocated_mb
|
||||
|
||||
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
|
||||
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
|
||||
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
|
||||
torch.float8_e4m3fn, torch.bfloat16
|
||||
)
|
||||
|
||||
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
|
||||
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
|
||||
"Memory footprint should decrease with lower precision storage"
|
||||
)
|
||||
|
||||
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
|
||||
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
|
||||
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
|
||||
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
|
||||
"Peak memory should be lower with bf16 compute on newer GPUs"
|
||||
)
|
||||
|
||||
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
|
||||
# bytes. This only happens for some models, so we allow a small tolerance.
|
||||
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
|
||||
assert (
|
||||
fp8_e4m3_fp32_max_memory < fp32_max_memory
|
||||
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||
), "Peak memory should be lower or within tolerance with fp8 storage"

def test_layerwise_casting_training(self):
    def test_fn(storage_dtype, compute_dtype):
        if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
            pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")

        model = self.model_class(**self.get_init_dict())
        model = model.to(torch_device, dtype=compute_dtype)
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        model.train()

        inputs_dict = self.get_dummy_inputs()
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        with torch.amp.autocast(device_type=torch.device(torch_device).type):
            output = model(**inputs_dict, return_dict=False)[0]

        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
        loss = torch.nn.functional.mse_loss(output, noise)

        loss.backward()

    test_fn(torch.float16, torch.float32)
    test_fn(torch.float8_e4m3fn, torch.float32)
    test_fn(torch.float8_e5m2, torch.float32)
    test_fn(torch.float8_e4m3fn, torch.bfloat16)
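# Illustrative sketch (not part of this diff) of the feature exercised above, applied to a
# real checkpoint. ``enable_layerwise_casting`` keeps the weights in the low-precision
# storage dtype and upcasts them per layer to the compute dtype during the forward pass.
# The repo id below is only an example:
#
#     import torch
#     from diffusers import FluxTransformer2DModel
#
#     transformer = FluxTransformer2DModel.from_pretrained(
#         "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
#     )
#     transformer.enable_layerwise_casting(
#         storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
#     )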

@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
    """
    Combined mixin class for all memory optimization tests including CPU/disk offloading,
    group offloading, and layerwise dtype casting.

    This mixin inherits from:
    - CPUOffloadTesterMixin: CPU and disk offloading tests
    - GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
    - LayerwiseCastingTesterMixin: Layerwise dtype casting tests

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: memory
    Use `pytest -m "not memory"` to skip these tests
    """

    pass
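# Illustrative usage sketch (not part of this diff): a model test suite opts into these
# memory tests by combining its tester config with the mixin, exactly as the Flux tests
# later in this diff do. ``TinyModelTesterConfig`` and ``TinyModel`` are hypothetical
# placeholders; the config must provide ``model_class``, ``get_init_dict()`` and
# ``get_dummy_inputs()``:
#
#     class TinyModelTesterConfig(BaseModelTesterConfig):
#         @property
#         def model_class(self):
#             return TinyModel
#
#         def get_init_dict(self):
#             return {"in_channels": 4, "num_layers": 1}
#
#         def get_dummy_inputs(self):
#             return {"hidden_states": torch.randn(1, 16, 4).to(torch_device)}
#
#
#     class TestTinyModelMemory(TinyModelTesterConfig, MemoryTesterMixin):
#         pass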
102
tests/models/testing_utils/parallelism.py
Normal file
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
import torch.multiprocessing as mp

from diffusers.models._modeling_parallel import ContextParallelConfig

from ...testing_utils import (
    is_context_parallel,
    require_torch_multi_accelerator,
)


def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
    try:
        # Setup distributed environment
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        model = model_class(**init_dict)
        model.to(device)
        model.eval()

        inputs_on_device = {}
        for key, value in inputs_dict.items():
            if isinstance(value, torch.Tensor):
                inputs_on_device[key] = value.to(device)
            else:
                inputs_on_device[key] = value

        cp_config = ContextParallelConfig(**cp_dict)
        model.enable_parallelism(config=cp_config)

        with torch.no_grad():
            output = model(**inputs_on_device, return_dict=False)[0]

        if rank == 0:
            result_queue.put(("success", output.shape))

    except Exception as e:
        if rank == 0:
            result_queue.put(("error", str(e)))
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
    base_precision = 1e-3

    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
    def test_context_parallel_inference(self, cp_type):
        if not torch.distributed.is_available():
            pytest.skip("torch.distributed is not available.")

        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

        world_size = 2
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        cp_dict = {cp_type: world_size}

        ctx = mp.get_context("spawn")
        result_queue = ctx.Queue()

        mp.spawn(
            _context_parallel_worker,
            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
            nprocs=world_size,
            join=True,
        )

        status, result = result_queue.get(timeout=60)
        assert status == "success", f"Context parallel inference failed: {result}"
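# Usage note (not part of this diff): these tests spawn their own process group via
# ``mp.spawn``, so they are run through plain pytest on a multi-accelerator host rather
# than under torchrun, and can be selected or skipped with the registered marker:
#
#     pytest tests/models -m context_parallel        # only context-parallel tests
#     pytest tests/models -m "not context_parallel"  # skip them
#
# The fixed MASTER_PORT above (12355) is a simplification; concurrent test sessions on the
# same host would need distinct ports.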
1291
tests/models/testing_utils/quantization.py
Normal file
File diff suppressed because it is too large
244
tests/models/testing_utils/single_file.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_single_file,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
|
||||
"""Download a single file checkpoint from the Hub to a temporary directory."""
|
||||
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
|
||||
return path
|
||||
|
||||
|
||||
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
|
||||
"""Download diffusers config files (excluding weights) from a repository."""
|
||||
path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_patterns=[
|
||||
"**/*.ckpt",
|
||||
"*.ckpt",
|
||||
"**/*.bin",
|
||||
"*.bin",
|
||||
"**/*.pt",
|
||||
"*.pt",
|
||||
"**/*.safetensors",
|
||||
"*.safetensors",
|
||||
],
|
||||
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||
local_dir=tmpdir,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@is_single_file
|
||||
class SingleFileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing single file loading for models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- ckpt_path: Path or Hub path to the single file checkpoint
|
||||
- subfolder: (Optional) Subfolder within the repo
|
||||
- torch_dtype: (Optional) torch dtype to use for testing
|
||||
|
||||
Pytest mark: single_file
|
||||
Use `pytest -m "not single_file"` to skip these tests
|
||||
"""
|
||||
|
||||
pretrained_model_name_or_path = None
|
||||
ckpt_path = None
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_model_config(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
pretrained_kwargs["device"] = torch_device
|
||||
single_file_kwargs["device"] = torch_device
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading: "
|
||||
f"pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_model_parameters(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
pretrained_kwargs["device"] = torch_device
|
||||
single_file_kwargs["device"] = torch_device
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict_single_file = model_single_file.state_dict()
|
||||
|
||||
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||
"Model parameters keys differ between pretrained and single file loading. "
|
||||
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
|
||||
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key in state_dict.keys():
|
||||
param = state_dict[key]
|
||||
param_single_file = state_dict_single_file[key]
|
||||
|
||||
assert param.shape == param_single_file.shape, (
|
||||
f"Parameter shape mismatch for {key}: "
|
||||
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||
)
|
||||
|
||||
assert_tensors_close(
|
||||
param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
# Load with config parameter
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
|
||||
)
|
||||
|
||||
# Load pretrained for comparison
|
||||
pretrained_kwargs = {}
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
|
||||
# Compare configs
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
|
||||
def test_single_file_loading_dtype(self):
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
    if torch_device == "mps" and dtype == torch.bfloat16:
        continue
|
||||
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
|
||||
|
||||
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
|
||||
|
||||
# Cleanup
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_checkpoint_variant_loading(self):
|
||||
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
|
||||
return
|
||||
|
||||
for ckpt_path in self.alternate_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
single_file_kwargs = {}
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
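# Illustrative usage sketch (not part of this diff): a concrete suite points the mixin at a
# Hub repo and a single-file checkpoint, as the Flux tests later in this diff do.
# ``MyModelTesterConfig`` and the repo/checkpoint ids are hypothetical placeholders:
#
#     class TestMyModelSingleFile(MyModelTesterConfig, SingleFileTesterMixin):
#         pretrained_model_name_or_path = "org/my-model"
#         subfolder = "transformer"
#         ckpt_path = "https://huggingface.co/org/my-model/blob/main/model.safetensors"
#         torch_dtype = torch.bfloat16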
207
tests/models/testing_utils/training.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
||||
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
|
||||
|
||||
|
||||
@is_training
|
||||
@require_torch_accelerator_with_training
|
||||
class TrainingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing training functionality on models.
|
||||
|
||||
Expected class attributes to be set by subclasses:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Expected properties to be implemented by subclasses:
|
||||
- output_shape: Tuple defining the expected output shape
|
||||
|
||||
Pytest mark: training
|
||||
Use `pytest -m "not training"` to skip these tests
|
||||
"""
|
||||
|
||||
def test_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
def test_training_with_ema(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
ema_model = EMAModel(model.parameters())
|
||||
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
ema_model.step(model.parameters())
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = self.model_class(**init_dict)
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
|
||||
|
||||
# check enable works
|
||||
model.enable_gradient_checkpointing()
|
||||
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
|
||||
|
||||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self, expected_set=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if expected_set is None:
|
||||
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
for submodule in model.modules():
|
||||
if hasattr(submodule, "gradient_checkpointing"):
|
||||
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
|
||||
modules_with_gc_enabled[submodule.__class__.__name__] = True
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == expected_set, (
|
||||
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
|
||||
)
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
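# Illustrative sketch (not part of this diff): concrete suites supply ``expected_set`` by
# overriding this test, as the Flux tests later in this diff do. ``MyModelTesterConfig``
# and ``MyModel`` are hypothetical placeholders:
#
#     class TestMyModelTraining(MyModelTesterConfig, TrainingTesterMixin):
#         def test_gradient_checkpointing_is_applied(self):
#             super().test_gradient_checkpointing_is_applied(expected_set={"MyModel"})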
|
||||
|
||||
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if skip is None:
|
||||
skip = set()
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
assert (loss - loss_2).abs() < loss_tolerance, (
|
||||
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
|
||||
)
|
||||
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
if param.grad is None:
|
||||
continue
|
||||
|
||||
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
|
||||
f"Gradient mismatch for {name}"
|
||||
)
|
||||
|
||||
def test_mixed_precision_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# Test with float16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Test with bfloat16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
model.zero_grad()
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
@@ -13,23 +13,51 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def create_flux_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
# TODO: This standalone function maintains backward compatibility with pipeline tests
|
||||
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
|
||||
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
@@ -39,7 +67,7 @@ def create_flux_ip_adapter_state_dict(model):
|
||||
|
||||
joint_attention_dim = model.config["joint_attention_dim"]
|
||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||
sd = FluxIPAdapterJointAttnProcessor2_0(
|
||||
sd = FluxIPAdapterAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
@@ -50,11 +78,8 @@ def create_flux_ip_adapter_state_dict(model):
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=(
|
||||
@@ -75,57 +100,37 @@ def create_flux_ip_adapter_state_dict(model):
|
||||
)
|
||||
|
||||
del sd
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||
|
||||
|
||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return FluxTransformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
"""Return Flux model initialization arguments."""
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -137,11 +142,32 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator
|
||||
),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
"""Test that deprecated 3D img_ids and txt_ids still work."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -162,63 +188,267 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(output_1, output_2, atol=1e-5),
|
||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# The test exists for cases like
|
||||
# https://github.com/huggingface/diffusers/issues/11874
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_exclude_modules(self):
|
||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
||||
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux Transformer."""
|
||||
|
||||
lora_rank = 4
|
||||
target_module = "single_transformer_blocks.0.proj_out"
|
||||
adapter_name = "foo"
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
pass
|
||||
|
||||
state_dict = model.state_dict()
|
||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
||||
lora_state_dict = {
|
||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
||||
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
torch.manual_seed(0)
|
||||
# Create dummy image embeds for IP adapter
|
||||
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
|
||||
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||
return create_flux_ip_adapter_state_dict(model)
|
||||
|
||||
|
||||
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux Transformer."""
|
||||
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
||||
config = LoraConfig(
|
||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
||||
)
|
||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
||||
|
||||
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 8
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
||||
subfolder = "transformer"
|
||||
pass
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
|
||||
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
|
||||
"""FasterCache tests for Flux Transformer."""
|
||||
|
||||
# Flux is guidance distilled, so we can test at model level without CFG batch handling
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
"is_guidance_distilled": True,
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import (
|
||||
is_gguf_available,
|
||||
is_kernels_available,
|
||||
is_note_seq_available,
|
||||
is_nvidia_modelopt_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
@@ -130,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
|
||||
return True
|
||||
|
||||
|
||||
def assert_tensors_close(
|
||||
actual: "torch.Tensor",
|
||||
expected: "torch.Tensor",
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-5,
|
||||
msg: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Assert that two tensors are close within tolerance.
|
||||
|
||||
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
|
||||
Provides concise, actionable error messages without dumping full tensors.
|
||||
|
||||
Args:
|
||||
actual: The actual tensor from the computation.
|
||||
expected: The expected tensor to compare against.
|
||||
atol: Absolute tolerance.
|
||||
rtol: Relative tolerance.
|
||||
msg: Optional message prefix for the assertion error.
|
||||
|
||||
Raises:
|
||||
AssertionError: If tensors have different shapes or values exceed tolerance.
|
||||
|
||||
Example:
|
||||
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
|
||||
"""
|
||||
if not is_torch_available():
|
||||
raise ValueError("PyTorch needs to be installed to use this function.")
|
||||
|
||||
if actual.shape != expected.shape:
|
||||
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
|
||||
|
||||
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
|
||||
abs_diff = (actual - expected).abs()
|
||||
max_diff = abs_diff.max().item()
|
||||
|
||||
flat_idx = abs_diff.argmax().item()
max_idx = tuple(int(i) for i in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
|
||||
|
||||
threshold = atol + rtol * expected.abs()
|
||||
mismatched = (abs_diff > threshold).sum().item()
|
||||
total = actual.numel()
|
||||
|
||||
raise AssertionError(
|
||||
f"{msg}\n"
|
||||
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
|
||||
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
|
||||
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
|
||||
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
|
||||
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
|
||||
)
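# Illustrative example (not part of this diff) of the failure report produced above; the
# exact numbers of course depend on the tensors being compared:
#
#     >>> actual = torch.zeros(2, 2)
#     >>> expected = actual.clone()
#     >>> expected[0, 1] = 1.0
#     >>> assert_tensors_close(actual, expected, atol=1e-5, rtol=1e-5, msg="demo")
#     AssertionError: demo
#     Tensors not close! Mismatched elements: 1/4 (25.0%)
#       Max diff: 1.000000e+00 at index (0, 1)
#       Actual: 0.000000e+00
#       Expected: 1.000000e+00
#       atol: 1.000000e-05, rtol: 1.000000e-05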
|
||||
|
||||
|
||||
def numpy_cosine_similarity_distance(a, b):
|
||||
similarity = np.dot(a, b) / (norm(a) * norm(b))
|
||||
distance = 1.0 - similarity.mean()
|
||||
@@ -241,7 +295,6 @@ def parse_flag_from_env(key, default=False):
|
||||
|
||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
|
||||
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
@@ -282,12 +335,155 @@ def nightly(test_case):
|
||||
|
||||
def is_torch_compile(test_case):
|
||||
"""
|
||||
Decorator marking a test that runs compile tests in the diffusers CI.
|
||||
|
||||
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
|
||||
|
||||
Decorator marking a test as a torch.compile test. These tests can be filtered using:
|
||||
pytest -m "not compile" to skip
|
||||
pytest -m compile to run only these tests
|
||||
"""
|
||||
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
|
||||
return pytest.mark.compile(test_case)
|
||||
|
||||
|
||||
def is_single_file(test_case):
|
||||
"""
|
||||
Decorator marking a test as a single file loading test. These tests can be filtered using:
|
||||
pytest -m "not single_file" to skip
|
||||
pytest -m single_file to run only these tests
|
||||
"""
|
||||
return pytest.mark.single_file(test_case)
|
||||
|
||||
|
||||
def is_lora(test_case):
|
||||
"""
|
||||
Decorator marking a test as a LoRA test. These tests can be filtered using:
|
||||
pytest -m "not lora" to skip
|
||||
pytest -m lora to run only these tests
|
||||
"""
|
||||
return pytest.mark.lora(test_case)
|
||||
|
||||
|
||||
def is_ip_adapter(test_case):
|
||||
"""
|
||||
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
|
||||
pytest -m "not ip_adapter" to skip
|
||||
pytest -m ip_adapter to run only these tests
|
||||
"""
|
||||
return pytest.mark.ip_adapter(test_case)
|
||||
|
||||
|
||||
def is_training(test_case):
|
||||
"""
|
||||
Decorator marking a test as a training test. These tests can be filtered using:
|
||||
pytest -m "not training" to skip
|
||||
pytest -m training to run only these tests
|
||||
"""
|
||||
return pytest.mark.training(test_case)
|
||||
|
||||
|
||||
def is_attention(test_case):
|
||||
"""
|
||||
Decorator marking a test as an attention test. These tests can be filtered using:
|
||||
pytest -m "not attention" to skip
|
||||
pytest -m attention to run only these tests
|
||||
"""
|
||||
return pytest.mark.attention(test_case)
|
||||
|
||||
|
||||
def is_memory(test_case):
|
||||
"""
|
||||
Decorator marking a test as a memory optimization test. These tests can be filtered using:
|
||||
pytest -m "not memory" to skip
|
||||
pytest -m memory to run only these tests
|
||||
"""
|
||||
return pytest.mark.memory(test_case)
|
||||
|
||||
|
||||
def is_cpu_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a CPU offload test. These tests can be filtered using:
|
||||
pytest -m "not cpu_offload" to skip
|
||||
pytest -m cpu_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.cpu_offload(test_case)
|
||||
|
||||
|
||||
def is_group_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a group offload test. These tests can be filtered using:
|
||||
pytest -m "not group_offload" to skip
|
||||
pytest -m group_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.group_offload(test_case)
|
||||
|
||||
|
||||
def is_quantization(test_case):
|
||||
"""
|
||||
Decorator marking a test as a quantization test. These tests can be filtered using:
|
||||
pytest -m "not quantization" to skip
|
||||
pytest -m quantization to run only these tests
|
||||
"""
|
||||
return pytest.mark.quantization(test_case)
|
||||
|
||||
|
||||
def is_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
|
||||
pytest -m "not bitsandbytes" to skip
|
||||
pytest -m bitsandbytes to run only these tests
|
||||
"""
|
||||
return pytest.mark.bitsandbytes(test_case)
|
||||
|
||||
|
||||
def is_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
|
||||
pytest -m "not quanto" to skip
|
||||
pytest -m quanto to run only these tests
|
||||
"""
|
||||
return pytest.mark.quanto(test_case)
|
||||
|
||||
|
||||
def is_torchao(test_case):
|
||||
"""
|
||||
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
|
||||
pytest -m "not torchao" to skip
|
||||
pytest -m torchao to run only these tests
|
||||
"""
|
||||
return pytest.mark.torchao(test_case)
|
||||
|
||||
|
||||
def is_gguf(test_case):
|
||||
"""
|
||||
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
|
||||
pytest -m "not gguf" to skip
|
||||
pytest -m gguf to run only these tests
|
||||
"""
|
||||
return pytest.mark.gguf(test_case)
|
||||
|
||||
|
||||
def is_modelopt(test_case):
|
||||
"""
|
||||
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
|
||||
pytest -m "not modelopt" to skip
|
||||
pytest -m modelopt to run only these tests
|
||||
"""
|
||||
return pytest.mark.modelopt(test_case)
|
||||
|
||||
|
||||
def is_context_parallel(test_case):
|
||||
"""
|
||||
Decorator marking a test as a context parallel inference test. These tests can be filtered using:
|
||||
pytest -m "not context_parallel" to skip
|
||||
pytest -m context_parallel to run only these tests
|
||||
"""
|
||||
return pytest.mark.context_parallel(test_case)
|
||||
|
||||
|
||||
def is_cache(test_case):
|
||||
"""
|
||||
Decorator marking a test as a cache test. These tests can be filtered using:
|
||||
pytest -m "not cache" to skip
|
||||
pytest -m cache to run only these tests
|
||||
"""
|
||||
return pytest.mark.cache(test_case)
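# Usage note (not part of this diff): the markers applied by the decorators above compose
# with pytest's ``-m`` expressions, for example:
#
#     pytest tests/models -m "lora and not compile"
#     pytest tests/models -m "memory or group_offload"
#     pytest tests/models -m "not quantization"
#
# Custom markers should be registered (e.g. via ``pytest_configure`` in conftest.py) so that
# pytest does not emit PytestUnknownMarkWarning for them.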
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
@@ -650,6 +846,19 @@ def require_kernels_version_greater_or_equal(kernels_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_modelopt_version_greater_or_equal(modelopt_version):
|
||||
def decorator(test_case):
|
||||
correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("modelopt")).base_version
|
||||
) >= version.parse(modelopt_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_nvidia_modelopt_version,
|
||||
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
509
utils/generate_model_tests.py
Normal file
@@ -0,0 +1,509 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utility script to generate test suites for diffusers model classes.
|
||||
|
||||
Usage:
|
||||
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
|
||||
|
||||
This will analyze the model file and generate a test file with appropriate
|
||||
test classes based on the model's mixins and attributes.
|
||||
"""
|
||||
|
||||
import argparse
import ast
import sys
from pathlib import Path


MIXIN_TO_TESTER = {
    "ModelMixin": "ModelTesterMixin",
    "PeftAdapterMixin": "LoraTesterMixin",
}

ATTRIBUTE_TO_TESTER = {
    "_cp_plan": "ContextParallelTesterMixin",
    "_supports_gradient_checkpointing": "TrainingTesterMixin",
}

ALWAYS_INCLUDE_TESTERS = [
    "ModelTesterMixin",
    "MemoryTesterMixin",
    "TorchCompileTesterMixin",
]

# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
    "AttentionMixin",
    "AttentionModuleMixin",
}

OPTIONAL_TESTERS = [
    ("BitsAndBytesTesterMixin", "bnb"),
    ("QuantoTesterMixin", "quanto"),
    ("TorchAoTesterMixin", "torchao"),
    ("GGUFTesterMixin", "gguf"),
    ("ModelOptTesterMixin", "modelopt"),
    ("SingleFileTesterMixin", "single_file"),
    ("IPAdapterTesterMixin", "ip_adapter"),
]

class ModelAnalyzer(ast.NodeVisitor):
    def __init__(self):
        self.model_classes = []
        self.current_class = None
        self.imports = set()

    def visit_Import(self, node: ast.Import):
        for alias in node.names:
            self.imports.add(alias.name.split(".")[-1])
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        for alias in node.names:
            self.imports.add(alias.name)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        base_names = []
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_names.append(base.id)
            elif isinstance(base, ast.Attribute):
                base_names.append(base.attr)

        if "ModelMixin" in base_names:
            class_info = {
                "name": node.name,
                "bases": base_names,
                "attributes": {},
                "has_forward": False,
                "init_params": [],
            }

            for item in node.body:
                if isinstance(item, ast.Assign):
                    for target in item.targets:
                        if isinstance(target, ast.Name):
                            attr_name = target.id
                            if attr_name.startswith("_"):
                                class_info["attributes"][attr_name] = self._get_value(item.value)

                elif isinstance(item, ast.FunctionDef):
                    if item.name == "forward":
                        class_info["has_forward"] = True
                        class_info["forward_params"] = self._extract_func_params(item)
                    elif item.name == "__init__":
                        class_info["init_params"] = self._extract_func_params(item)

            self.model_classes.append(class_info)

        self.generic_visit(node)

    def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
        params = []
        args = func_node.args

        num_defaults = len(args.defaults)
        num_args = len(args.args)
        first_default_idx = num_args - num_defaults

        for i, arg in enumerate(args.args):
            if arg.arg == "self":
                continue

            param_info = {"name": arg.arg, "type": None, "default": None}

            if arg.annotation:
                param_info["type"] = self._get_annotation_str(arg.annotation)

            default_idx = i - first_default_idx
            if default_idx >= 0 and default_idx < len(args.defaults):
                param_info["default"] = self._get_value(args.defaults[default_idx])

            params.append(param_info)

        return params

    def _get_annotation_str(self, node) -> str:
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return repr(node.value)
        elif isinstance(node, ast.Subscript):
            base = self._get_annotation_str(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
            else:
                args = self._get_annotation_str(node.slice)
            return f"{base}[{args}]"
        elif isinstance(node, ast.Attribute):
            return f"{self._get_annotation_str(node.value)}.{node.attr}"
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            left = self._get_annotation_str(node.left)
            right = self._get_annotation_str(node.right)
            return f"{left} | {right}"
        elif isinstance(node, ast.Tuple):
            return ", ".join(self._get_annotation_str(el) for el in node.elts)
        return "Any"

    def _get_value(self, node):
        if isinstance(node, ast.Constant):
            return node.value
        elif isinstance(node, ast.Name):
            if node.id == "None":
                return None
            elif node.id == "True":
                return True
            elif node.id == "False":
                return False
            return node.id
        elif isinstance(node, ast.List):
            return [self._get_value(el) for el in node.elts]
        elif isinstance(node, ast.Dict):
            return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
        return "<complex>"


def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
    with open(filepath) as f:
        source = f.read()

    tree = ast.parse(source)
    analyzer = ModelAnalyzer()
    analyzer.visit(tree)

    return analyzer.model_classes, analyzer.imports

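To make the extraction concrete, here is what the analyzer reports for a toy class; the class, its attributes, and their values are invented purely for illustration, and the snippet assumes it runs in the same module so `ModelAnalyzer` is in scope:

import ast

toy_source = '''
class ToyTransformer2DModel(ModelMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True
    _cp_plan = {"blocks": "ring"}

    def __init__(self, hidden_dim: int = 8):
        pass

    def forward(self, hidden_states, timestep=None):
        pass
'''

analyzer = ModelAnalyzer()
analyzer.visit(ast.parse(toy_source))
info = analyzer.model_classes[0]
# info["bases"]       -> ["ModelMixin", "PeftAdapterMixin"]
# info["attributes"]  -> {"_supports_gradient_checkpointing": True, "_cp_plan": {"blocks": "ring"}}
# info["init_params"] -> [{"name": "hidden_dim", "type": "int", "default": 8}]
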
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
    testers = list(ALWAYS_INCLUDE_TESTERS)

    for base in model_info["bases"]:
        if base in MIXIN_TO_TESTER:
            tester = MIXIN_TO_TESTER[base]
            if tester not in testers:
                testers.append(tester)

    for attr, tester in ATTRIBUTE_TO_TESTER.items():
        if attr in model_info["attributes"]:
            value = model_info["attributes"][attr]
            if value is not None and value is not False:
                if tester not in testers:
                    testers.append(tester)

    if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
        if "ContextParallelTesterMixin" not in testers:
            testers.append("ContextParallelTesterMixin")

    # Include AttentionTesterMixin if the model imports attention-related classes
    if imports & ATTENTION_INDICATORS:
        testers.append("AttentionTesterMixin")

    for tester, flag in OPTIONAL_TESTERS:
        if flag in include_optional:
            if tester not in testers:
                testers.append(tester)

    return testers

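Continuing the toy example, the selection composes the always-on testers with mixin-, attribute-, and flag-driven ones. A hand-traced result, not captured output:

testers = determine_testers(info, include_optional=["gguf"], imports=set())
# -> ["ModelTesterMixin", "MemoryTesterMixin", "TorchCompileTesterMixin",
#     "LoraTesterMixin", "ContextParallelTesterMixin", "TrainingTesterMixin", "GGUFTesterMixin"]
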
def generate_config_class(model_info: dict, model_name: str) -> str:
    class_name = f"{model_name}TesterConfig"
    model_class = model_info["name"]
    forward_params = model_info.get("forward_params", [])
    init_params = model_info.get("init_params", [])

    lines = [
        f"class {class_name}:",
        f"    model_class = {model_class}",
        '    pretrained_model_name_or_path = ""',
        '    pretrained_model_kwargs = {"subfolder": "transformer"}',
        "",
        "    @property",
        "    def generator(self):",
        '        return torch.Generator("cpu").manual_seed(0)',
        "",
        "    def get_init_dict(self) -> dict[str, int | list[int]]:",
    ]

    if init_params:
        lines.append("        # __init__ parameters:")
        for param in init_params:
            type_str = f": {param['type']}" if param["type"] else ""
            default_str = f" = {param['default']}" if param["default"] is not None else ""
            lines.append(f"        # {param['name']}{type_str}{default_str}")

    lines.extend(
        [
            "        return {}",
            "",
            "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
        ]
    )

    if forward_params:
        lines.append("        # forward() parameters:")
        for param in forward_params:
            type_str = f": {param['type']}" if param["type"] else ""
            default_str = f" = {param['default']}" if param["default"] is not None else ""
            lines.append(f"        # {param['name']}{type_str}{default_str}")

    lines.extend(
        [
            "        # TODO: Fill in dummy inputs",
            "        return {}",
            "",
            "    @property",
            "    def input_shape(self) -> tuple[int, ...]:",
            "        return (1, 1)",
            "",
            "    @property",
            "    def output_shape(self) -> tuple[int, ...]:",
            "        return (1, 1)",
        ]
    )

    return "\n".join(lines)

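For the toy model, generate_config_class(info, "ToyTransformer") would emit a skeleton along these lines (hand-reconstructed from the template above, not actual script output):

class ToyTransformerTesterConfig:
    model_class = ToyTransformer2DModel
    pretrained_model_name_or_path = ""
    pretrained_model_kwargs = {"subfolder": "transformer"}

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int]]:
        # __init__ parameters:
        # hidden_dim: int = 8
        return {}

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        # forward() parameters:
        # hidden_states
        # timestep
        # TODO: Fill in dummy inputs
        return {}

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (1, 1)

    @property
    def output_shape(self) -> tuple[int, ...]:
        return (1, 1)
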
def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
    tester_short = tester.replace("TesterMixin", "")
    class_name = f"Test{model_name}{tester_short}"

    lines = [f"class {class_name}({config_class}, {tester}):"]

    if tester == "TorchCompileTesterMixin":
        lines.extend(
            [
                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
                "",
                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
                "        # TODO: Implement dynamic input generation",
                "        return {}",
            ]
        )
    elif tester == "IPAdapterTesterMixin":
        lines.extend(
            [
                "    ip_adapter_processor_cls = None  # TODO: Set processor class",
                "",
                "    def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
                "        # TODO: Add IP adapter image embeds to inputs",
                "        return inputs_dict",
                "",
                "    def create_ip_adapter_state_dict(self, model):",
                "        # TODO: Create IP adapter state dict",
                "        return {}",
            ]
        )
    elif tester == "SingleFileTesterMixin":
        lines.extend(
            [
                '    ckpt_path = ""  # TODO: Set checkpoint path',
                "    alternate_keys_ckpt_paths = []",
                '    pretrained_model_name_or_path = ""',
                '    subfolder = "transformer"',
            ]
        )
    elif tester == "GGUFTesterMixin":
        lines.extend(
            [
                '    gguf_filename = ""  # TODO: Set GGUF filename',
                "",
                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
                "        # TODO: Override with larger inputs for quantization tests",
                "        return {}",
            ]
        )
    elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
        lines.extend(
            [
                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
                "        # TODO: Override with larger inputs for quantization tests",
                "        return {}",
            ]
        )
    elif tester == "LoraHotSwappingForModelTesterMixin":
        lines.extend(
            [
                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
                "",
                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
                "        # TODO: Implement dynamic input generation",
                "        return {}",
            ]
        )
    else:
        lines.append("    pass")

    return "\n".join(lines)

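Again for the toy model, the fallback branch and the compile branch produce classes like the following (hand-reconstructed):

class TestToyTransformerModel(ToyTransformerTesterConfig, ModelTesterMixin):
    pass


class TestToyTransformerTorchCompile(ToyTransformerTesterConfig, TorchCompileTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}
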
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
    model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
    testers = determine_testers(model_info, include_optional, imports)
    tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})

    lines = [
        "# coding=utf-8",
        "# Copyright 2025 HuggingFace Inc.",
        "#",
        '# Licensed under the Apache License, Version 2.0 (the "License");',
        "# you may not use this file except in compliance with the License.",
        "# You may obtain a copy of the License at",
        "#",
        "#     http://www.apache.org/licenses/LICENSE-2.0",
        "#",
        "# Unless required by applicable law or agreed to in writing, software",
        '# distributed under the License is distributed on an "AS IS" BASIS,',
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
        "# See the License for the specific language governing permissions and",
        "# limitations under the License.",
        "",
        "import torch",
        "",
        f"from diffusers import {model_info['name']}",
        "from diffusers.utils.torch_utils import randn_tensor",
        "",
        "from ...testing_utils import enable_full_determinism, torch_device",
    ]

    if "LoraTesterMixin" in testers:
        lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")

    lines.extend(
        [
            "from ..testing_utils import (",
            *[f"    {tester}," for tester in sorted(tester_imports)],
            ")",
            "",
            "",
            "enable_full_determinism()",
            "",
            "",
        ]
    )

    config_class = f"{model_name}TesterConfig"
    lines.append(generate_config_class(model_info, model_name))
    lines.append("")
    lines.append("")

    for tester in testers:
        lines.append(generate_test_class(model_name, config_class, tester))
        lines.append("")
        lines.append("")

    if "LoraTesterMixin" in testers:
        lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
        lines.append("")
        lines.append("")

    return "\n".join(lines).rstrip() + "\n"

def get_test_output_path(model_filepath: str) -> str:
    path = Path(model_filepath)
    model_filename = path.stem

    if "transformers" in path.parts:
        return f"tests/models/transformers/test_models_{model_filename}.py"
    elif "unets" in path.parts:
        return f"tests/models/unets/test_models_{model_filename}.py"
    elif "autoencoders" in path.parts:
        return f"tests/models/autoencoders/test_models_{model_filename}.py"
    else:
        return f"tests/models/test_models_{model_filename}.py"

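So the example invocation from the module docstring writes to a mirrored location under tests/:

get_test_output_path("src/diffusers/models/transformers/transformer_flux.py")
# -> "tests/models/transformers/test_models_transformer_flux.py"
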
def main():
    parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
    parser.add_argument(
        "model_filepath",
        type=str,
        help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
    )
    parser.add_argument(
        "--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
    )
    parser.add_argument(
        "--include",
        "-i",
        type=str,
        nargs="*",
        default=[],
        choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"],
        help="Optional testers to include",
    )
    parser.add_argument(
        "--class-name",
        "-c",
        type=str,
        default=None,
        help="Specific model class to generate tests for (default: first model class found)",
    )
    parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")

    args = parser.parse_args()

    if not Path(args.model_filepath).exists():
        print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
        sys.exit(1)

    model_classes, imports = analyze_model_file(args.model_filepath)

    if not model_classes:
        print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
        sys.exit(1)

    if args.class_name:
        model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
        if not model_info:
            available = [m["name"] for m in model_classes]
            print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
            sys.exit(1)
    else:
        model_info = model_classes[0]
        if len(model_classes) > 1:
            print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
            print("Use --class-name to specify a different class", file=sys.stderr)

    include_optional = args.include
    if "all" in include_optional:
        include_optional = [flag for _, flag in OPTIONAL_TESTERS]

    generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)

    if args.dry_run:
        print(generated_code)
    else:
        output_path = args.output or get_test_output_path(args.model_filepath)
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w") as f:
            f.write(generated_code)

        print(f"Generated test file: {output_path}")
        print(f"Model class: {model_info['name']}")
        print(f"Detected attributes: {list(model_info['attributes'].keys())}")


if __name__ == "__main__":
    main()