Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-20 19:34:48 +08:00)

Compare commits: pipeline-s...model-test (15 commits)
| SHA1 |
|---|
| c70de2bc37 |
| e82001e40d |
| d9b73ffd51 |
| dcd6026d17 |
| eae7543712 |
| d08e0bb545 |
| c366b5a817 |
| 0fdd9d3a60 |
| 489480b02a |
| fe451c367b |
| 0f1a4e0c14 |
| aa29af8f0e |
| bffa3a9754 |
| 1c558712e8 |
| 1f026ad14e |
@@ -32,6 +32,21 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
    config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
    config.addinivalue_line("markers", "training: marks tests for training functionality")
    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
    config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")


def pytest_addoption(parser):
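The registered markers let test selections be expressed with pytest's `-m` option. As a quick orientation, the sketch below (the test body is a hypothetical placeholder, not part of this diff) shows how a test opts into a marker and how subsets are then selected:

```python
import pytest


@pytest.mark.lora
def test_lora_marker_sketch():
    # Hypothetical placeholder; the real LoRA tests come from LoraTesterMixin below.
    assert True


# Example selections:
#   pytest -m "lora" tests/models                 # run only LoRA-marked tests
#   pytest -m "not compile and not bitsandbytes"  # skip torch.compile and BitsAndBytes suites
```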
tests/models/testing_utils/__init__.py (new file, 79 lines)
@@ -0,0 +1,79 @@

from .attention import AttentionTesterMixin
from .cache import (
    CacheTesterMixin,
    FasterCacheConfigMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheConfigMixin,
    FirstBlockCacheTesterMixin,
    PyramidAttentionBroadcastConfigMixin,
    PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
    BitsAndBytesCompileTesterMixin,
    BitsAndBytesConfigMixin,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFConfigMixin,
    GGUFTesterMixin,
    ModelOptCompileTesterMixin,
    ModelOptConfigMixin,
    ModelOptTesterMixin,
    QuantizationCompileTesterMixin,
    QuantizationTesterMixin,
    QuantoCompileTesterMixin,
    QuantoConfigMixin,
    QuantoTesterMixin,
    TorchAoCompileTesterMixin,
    TorchAoConfigMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
    "AttentionTesterMixin",
    "BaseModelTesterConfig",
    "BitsAndBytesCompileTesterMixin",
    "BitsAndBytesConfigMixin",
    "BitsAndBytesTesterMixin",
    "CacheTesterMixin",
    "ContextParallelTesterMixin",
    "CPUOffloadTesterMixin",
    "FasterCacheConfigMixin",
    "FasterCacheTesterMixin",
    "FirstBlockCacheConfigMixin",
    "FirstBlockCacheTesterMixin",
    "GGUFCompileTesterMixin",
    "GGUFConfigMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraHotSwappingForModelTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptCompileTesterMixin",
    "ModelOptConfigMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "PyramidAttentionBroadcastConfigMixin",
    "PyramidAttentionBroadcastTesterMixin",
    "QuantizationCompileTesterMixin",
    "QuantizationTesterMixin",
    "QuantoCompileTesterMixin",
    "QuantoConfigMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoCompileTesterMixin",
    "TorchAoConfigMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
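These exports are meant to be composed per model: a config class supplies the model class and its dummy inputs, and each mixin contributes its own pytest-marked tests. The sketch below is an assumption-laden illustration, not part of the diff: the tiny `UNet2DModel` config, the `output_shape`, and the `tests.models.testing_utils` import path are all chosen for illustration and depend on where such a test module would actually live.

```python
import torch

from diffusers import UNet2DModel

# Assumed import path; adjust relative imports to the actual test module location.
from tests.models.testing_utils import AttentionTesterMixin, BaseModelTesterConfig, ModelTesterMixin

torch_device = "cuda" if torch.cuda.is_available() else "cpu"


class UNet2DTestConfig(BaseModelTesterConfig):
    # A plain class attribute shadows the base class property.
    model_class = UNet2DModel
    output_shape = (1, 3, 32, 32)

    def get_init_dict(self):
        # Deliberately tiny, illustrative config so the tests stay fast.
        return {
            "sample_size": 32,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 1,
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 32, 32, device=torch_device),
            "timestep": torch.tensor([1], device=torch_device),
        }


# Each mixin adds its marked tests to the concrete test class.
class TestUNet2DModelSketch(UNet2DTestConfig, ModelTesterMixin, AttentionTesterMixin):
    pass
```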
tests/models/testing_utils/attention.py (new file, 185 lines)
@@ -0,0 +1,185 @@

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import (
    assert_tensors_close,
    is_attention,
    torch_device,
)


@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether model uses custom attention processors (default: False)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
    """

    base_precision = 1e-3

    def test_fuse_unfuse_qkv_projections(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Get output before fusion
        with torch.no_grad():
            output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs match
            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should not change after fusing projections",
            )

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs still match
            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=self.base_precision,
                rtol=0,
                msg="Output should match original after unfusing projections",
            )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
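The processor management exercised by `test_get_set_processor` and `test_attention_processor_dict` maps onto the public `attn_processors` / `set_attn_processor` API. Below is a minimal sketch outside the test harness; the tiny `UNet2DConditionModel` config is an arbitrary assumption chosen for illustration, not taken from this diff.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

# Arbitrary small config; any model exposing attn_processors behaves the same way.
model = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# attn_processors maps attention-module paths to processor instances.
print(len(model.attn_processors))

# Replacing every processor mirrors test_attention_processor_dict; passing a dict with a
# mismatched number of keys is the failure case test_attention_processor_count_mismatch_raises_error checks.
model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors})
assert all(type(proc) is AttnProcessor for proc in model.attn_processors.values())
```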
tests/models/testing_utils/cache.py (new file, 536 lines)
@@ -0,0 +1,536 @@

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import pytest
import torch

from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin

from ...testing_utils import backend_empty_cache, is_cache, torch_device


def require_cache_mixin(func):
    """Decorator to skip tests if model doesn't use CacheMixin."""

    def wrapper(self, *args, **kwargs):
        if not issubclass(self.model_class, CacheMixin):
            pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
        return func(self, *args, **kwargs)

    return wrapper


class CacheTesterMixin:
    """
    Base mixin class providing common test implementations for cache testing.

    Cache-specific mixins should:
    1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
    2. Inherit from this mixin
    3. Define the cache config to use for tests

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods in test classes:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_cache_config(self):
        """
        Get the cache config for testing.
        Should be implemented by subclasses.
        """
        raise NotImplementedError("Subclass must implement _get_cache_config")

    def _get_hook_names(self):
        """
        Get the hook names to check for this cache type.
        Should be implemented by subclasses.
        Returns a list of hook name strings.
        """
        raise NotImplementedError("Subclass must implement _get_hook_names")

    def _test_cache_enable_disable_state(self):
        """Test that cache enable/disable updates the is_cache_enabled state correctly."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        # Initially cache should not be enabled
        assert not model.is_cache_enabled, "Cache should not be enabled initially."

        config = self._get_cache_config()

        # Enable cache
        model.enable_cache(config)
        assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."

        # Disable cache
        model.disable_cache()
        assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."

    def _test_cache_double_enable_raises_error(self):
        """Test that enabling cache twice raises an error."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        config = self._get_cache_config()

        model.enable_cache(config)

        # Trying to enable again should raise ValueError
        with pytest.raises(ValueError, match="Caching has already been enabled"):
            model.enable_cache(config)

        # Cleanup
        model.disable_cache()

    def _test_cache_hooks_registered(self):
        """Test that cache hooks are properly registered and removed."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        config = self._get_cache_config()
        hook_names = self._get_hook_names()

        model.enable_cache(config)

        # Check that at least one hook was registered
        hook_count = 0
        for module in model.modules():
            if hasattr(module, "_diffusers_hook"):
                for hook_name in hook_names:
                    hook = module._diffusers_hook.get_hook(hook_name)
                    if hook is not None:
                        hook_count += 1

        assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"

        # Disable and verify hooks are removed
        model.disable_cache()

        hook_count_after = 0
        for module in model.modules():
            if hasattr(module, "_diffusers_hook"):
                for hook_name in hook_names:
                    hook = module._diffusers_hook.get_hook(hook_name)
                    if hook is not None:
                        hook_count_after += 1

        assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."

    @torch.no_grad()
    def _test_cache_inference(self):
        """Test that model can run inference with cache enabled."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()

        model.enable_cache(config)

        # First pass populates the cache
        _ = model(**inputs_dict, return_dict=False)[0]

        # Create modified inputs for second pass (vary hidden_states to simulate denoising)
        inputs_dict_step2 = inputs_dict.copy()
        if "hidden_states" in inputs_dict_step2:
            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1

        # Second pass uses cached attention with different hidden_states (produces approximated output)
        output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]

        assert output_with_cache is not None, "Model output should not be None with cache enabled."
        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."

        # Run same inputs without cache to compare
        model.disable_cache()
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]

        # Cached output should be different from non-cached output (due to approximation)
        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
            "Cached output should be different from non-cached output due to cache approximation."
        )

    def _test_cache_context_manager(self):
        """Test the cache_context context manager."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        config = self._get_cache_config()

        model.enable_cache(config)

        # Test cache_context works without error
        with model.cache_context("test_context"):
            pass

        model.disable_cache()

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the cache state."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()

        model.enable_cache(config)

        # Run forward to populate cache state
        with torch.no_grad():
            _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()


@is_cache
class PyramidAttentionBroadcastConfigMixin:
    """
    Base mixin providing PyramidAttentionBroadcast cache config.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)
    """

    # Default PAB config - can be overridden by subclasses
    PAB_CONFIG = {
        "spatial_attention_block_skip_range": 2,
    }

    # Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
    _current_timestep = 500

    def _get_cache_config(self):
        config_kwargs = self.PAB_CONFIG.copy()
        config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
        return PyramidAttentionBroadcastConfig(**config_kwargs)

    def _get_hook_names(self):
        return [_PYRAMID_ATTENTION_BROADCAST_HOOK]


@is_cache
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
    """
    Mixin class for testing PyramidAttentionBroadcast caching on models.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cache
    Use `pytest -m "not cache"` to skip these tests
    """

    @require_cache_mixin
    def test_pab_cache_enable_disable_state(self):
        self._test_cache_enable_disable_state()

    @require_cache_mixin
    def test_pab_cache_double_enable_raises_error(self):
        self._test_cache_double_enable_raises_error()

    @require_cache_mixin
    def test_pab_cache_hooks_registered(self):
        self._test_cache_hooks_registered()

    @require_cache_mixin
    def test_pab_cache_inference(self):
        self._test_cache_inference()

    @require_cache_mixin
    def test_pab_cache_context_manager(self):
        self._test_cache_context_manager()

    @require_cache_mixin
    def test_pab_reset_stateful_cache(self):
        self._test_reset_stateful_cache()


@is_cache
class FirstBlockCacheConfigMixin:
    """
    Base mixin providing FirstBlockCache config.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)
    """

    # Default FBC config - can be overridden by subclasses
    # Higher threshold makes FBC more aggressive about caching (skips more often)
    FBC_CONFIG = {
        "threshold": 1.0,
    }

    def _get_cache_config(self):
        return FirstBlockCacheConfig(**self.FBC_CONFIG)

    def _get_hook_names(self):
        return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]


@is_cache
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
    """
    Mixin class for testing FirstBlockCache on models.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cache
    Use `pytest -m "not cache"` to skip these tests
    """

    @torch.no_grad()
    def _test_cache_inference(self):
        """Test that model can run inference with FBC cache enabled (requires cache_context)."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # FBC requires cache_context to be set for inference
        with model.cache_context("fbc_test"):
            # First pass populates the cache
            _ = model(**inputs_dict, return_dict=False)[0]

            # Create modified inputs for second pass (small perturbation keeps residuals similar)
            inputs_dict_step2 = inputs_dict.copy()
            if "hidden_states" in inputs_dict_step2:
                inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01

            # Second pass - FBC should skip remaining blocks and use cached residuals
            output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]

        assert output_with_cache is not None, "Model output should not be None with cache enabled."
        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."

        # Run same inputs without cache to compare
        model.disable_cache()
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]

        # Cached output should be different from non-cached output (due to approximation)
        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
            "Cached output should be different from non-cached output due to cache approximation."
        )

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # FBC requires cache_context to be set for inference
        with model.cache_context("fbc_test"):
            with torch.no_grad():
                _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()

    @require_cache_mixin
    def test_fbc_cache_enable_disable_state(self):
        self._test_cache_enable_disable_state()

    @require_cache_mixin
    def test_fbc_cache_double_enable_raises_error(self):
        self._test_cache_double_enable_raises_error()

    @require_cache_mixin
    def test_fbc_cache_hooks_registered(self):
        self._test_cache_hooks_registered()

    @require_cache_mixin
    def test_fbc_cache_inference(self):
        self._test_cache_inference()

    @require_cache_mixin
    def test_fbc_cache_context_manager(self):
        self._test_cache_context_manager()

    @require_cache_mixin
    def test_fbc_reset_stateful_cache(self):
        self._test_reset_stateful_cache()


@is_cache
class FasterCacheConfigMixin:
    """
    Base mixin providing FasterCache config.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)
    """

    # Default FasterCache config - can be overridden by subclasses
    FASTER_CACHE_CONFIG = {
        "spatial_attention_block_skip_range": 2,
        "spatial_attention_timestep_skip_range": (-1, 901),
        "tensor_format": "BCHW",
    }

    # Store timestep for callback - use a list so it can be mutated during test
    # Starts outside skip range so first pass computes; changed to inside range for subsequent passes
    _current_timestep = [1000]

    def _get_cache_config(self):
        config_kwargs = self.FASTER_CACHE_CONFIG.copy()
        config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
        return FasterCacheConfig(**config_kwargs)

    def _get_hook_names(self):
        return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]


@is_cache
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
    """
    Mixin class for testing FasterCache on models.

    Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
    and timestep management. Inference tests are skipped at model level - FasterCache should
    be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cache
    Use `pytest -m "not cache"` to skip these tests
    """

    @torch.no_grad()
    def _test_cache_inference(self):
        """Test that model can run inference with FasterCache enabled."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()

        model.enable_cache(config)

        # First pass with timestep outside skip range - computes and populates cache
        self._current_timestep[0] = 1000
        _ = model(**inputs_dict, return_dict=False)[0]

        # Move timestep inside skip range so subsequent passes use cache
        self._current_timestep[0] = 500

        # Create modified inputs for second pass
        inputs_dict_step2 = inputs_dict.copy()
        if "hidden_states" in inputs_dict_step2:
            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1

        # Second pass uses cached attention with different hidden_states
        output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]

        assert output_with_cache is not None, "Model output should not be None with cache enabled."
        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."

        # Run same inputs without cache to compare
        model.disable_cache()
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]

        # Cached output should be different from non-cached output (due to approximation)
        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
            "Cached output should be different from non-cached output due to cache approximation."
        )

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FasterCache state."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # First pass with timestep outside skip range
        self._current_timestep[0] = 1000
        with torch.no_grad():
            _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()

    @require_cache_mixin
    def test_faster_cache_enable_disable_state(self):
        self._test_cache_enable_disable_state()

    @require_cache_mixin
    def test_faster_cache_double_enable_raises_error(self):
        self._test_cache_double_enable_raises_error()

    @require_cache_mixin
    def test_faster_cache_hooks_registered(self):
        self._test_cache_hooks_registered()

    @require_cache_mixin
    def test_faster_cache_inference(self):
        self._test_cache_inference()

    @require_cache_mixin
    def test_faster_cache_context_manager(self):
        self._test_cache_context_manager()

    @require_cache_mixin
    def test_faster_cache_reset_stateful_cache(self):
        self._test_reset_stateful_cache()
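For orientation, the configs these mixins build are the same objects a user would pass to `enable_cache()` on any model that inherits `CacheMixin`. The sketch below only constructs configs mirroring the defaults above; the model itself is deliberately elided and described in comments, since any CacheMixin-based transformer could stand in.

```python
from diffusers.hooks import FirstBlockCacheConfig, PyramidAttentionBroadcastConfig

# Mirrors PyramidAttentionBroadcastConfigMixin: skip every other spatial attention block,
# with a constant timestep callback that falls inside the default (100, 800) skip window.
pab_config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    current_timestep_callback=lambda: 500,
)

# Mirrors FirstBlockCacheConfigMixin: a high threshold makes caching very aggressive.
fbc_config = FirstBlockCacheConfig(threshold=1.0)

# On any CacheMixin model (construction elided here):
#   model.enable_cache(pab_config)   # registers the hooks that _test_cache_hooks_registered looks for
#   model.is_cache_enabled           # -> True
#   model._reset_stateful_cache()    # clear cached state between generations
#   model.disable_cache()            # removes the hooks again
```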
tests/models/testing_utils/common.py (new file, 649 lines)
@@ -0,0 +1,649 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||||
|
|
||||||
|
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
|
||||||
|
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||||
|
|
||||||
|
from ...testing_utils import assert_tensors_close, torch_device
|
||||||
|
|
||||||
|
|
||||||
|
def named_persistent_module_tensors(
|
||||||
|
module: nn.Module,
|
||||||
|
recurse: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (`torch.nn.Module`):
|
||||||
|
The module we want the tensors on.
|
||||||
|
recurse (`bool`, *optional`, defaults to `False`):
|
||||||
|
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||||
|
"""
|
||||||
|
yield from module.named_parameters(recurse=recurse)
|
||||||
|
|
||||||
|
for named_buffer in module.named_buffers(recurse=recurse):
|
||||||
|
name, _ = named_buffer
|
||||||
|
# Get parent by splitting on dots and traversing the model
|
||||||
|
parent = module
|
||||||
|
if "." in name:
|
||||||
|
parent_name = name.rsplit(".", 1)[0]
|
||||||
|
for part in parent_name.split("."):
|
||||||
|
parent = getattr(parent, part)
|
||||||
|
name = name.split(".")[-1]
|
||||||
|
if name not in parent._non_persistent_buffers_set:
|
||||||
|
yield named_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def compute_module_persistent_sizes(
|
||||||
|
model: nn.Module,
|
||||||
|
dtype: str | torch.device | None = None,
|
||||||
|
special_dtypes: dict[str, str | torch.device] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||||
|
"""
|
||||||
|
if dtype is not None:
|
||||||
|
dtype = _get_proper_dtype(dtype)
|
||||||
|
dtype_size = dtype_byte_size(dtype)
|
||||||
|
if special_dtypes is not None:
|
||||||
|
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||||
|
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||||
|
module_sizes = defaultdict(int)
|
||||||
|
|
||||||
|
module_list = []
|
||||||
|
|
||||||
|
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||||
|
|
||||||
|
for name, tensor in module_list:
|
||||||
|
if special_dtypes is not None and name in special_dtypes:
|
||||||
|
size = tensor.numel() * special_dtypes_size[name]
|
||||||
|
elif dtype is None:
|
||||||
|
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||||
|
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||||
|
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||||
|
# so use their original size here
|
||||||
|
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||||
|
else:
|
||||||
|
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||||
|
name_parts = name.split(".")
|
||||||
|
for idx in range(len(name_parts) + 1):
|
||||||
|
module_sizes[".".join(name_parts[:idx])] += size
|
||||||
|
|
||||||
|
return module_sizes
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_expected_num_shards(index_map_path):
|
||||||
|
"""
|
||||||
|
Calculate expected number of shards from index file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_map_path: Path to the sharded checkpoint index file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Expected number of shards
|
||||||
|
"""
|
||||||
|
with open(index_map_path) as f:
|
||||||
|
weight_map_dict = json.load(f)["weight_map"]
|
||||||
|
first_key = list(weight_map_dict.keys())[0]
|
||||||
|
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||||
|
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||||
|
return expected_num_shards
|
||||||
|
|
||||||
|
|
||||||
|
def check_device_map_is_respected(model, device_map):
|
||||||
|
for param_name, param in model.named_parameters():
|
||||||
|
# Find device in device_map
|
||||||
|
while len(param_name) > 0 and param_name not in device_map:
|
||||||
|
param_name = ".".join(param_name.split(".")[:-1])
|
||||||
|
if param_name not in device_map:
|
||||||
|
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||||
|
|
||||||
|
param_device = device_map[param_name]
|
||||||
|
if param_device in ["cpu", "disk"]:
|
||||||
|
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||||
|
else:
|
||||||
|
assert param.device == torch.device(param_device), (
|
||||||
|
f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelTesterConfig:
|
||||||
|
"""
|
||||||
|
Base class defining the configuration interface for model testing.
|
||||||
|
|
||||||
|
This class defines the contract that all model test classes must implement.
|
||||||
|
It provides a consistent interface for accessing model configuration, initialization
|
||||||
|
parameters, and test inputs across all testing mixins.
|
||||||
|
|
||||||
|
Required properties (must be implemented by subclasses):
|
||||||
|
- model_class: The model class to test
|
||||||
|
|
||||||
|
Optional properties (can be overridden, have sensible defaults):
|
||||||
|
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
|
||||||
|
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
|
||||||
|
- output_shape: Expected output shape for output validation tests (default: None)
|
||||||
|
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
|
||||||
|
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
|
||||||
|
|
||||||
|
Required methods (must be implemented by subclasses):
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
class MyModelTestConfig(BaseModelTesterConfig):
|
||||||
|
@property
|
||||||
|
def model_class(self):
|
||||||
|
return MyModel
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pretrained_model_name_or_path(self):
|
||||||
|
return "org/my-model"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return (1, 3, 32, 32)
|
||||||
|
|
||||||
|
def get_init_dict(self):
|
||||||
|
return {"in_channels": 3, "out_channels": 3}
|
||||||
|
|
||||||
|
def get_dummy_inputs(self):
|
||||||
|
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
|
||||||
|
|
||||||
|
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ==================== Required Properties ====================
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_class(self) -> Type[nn.Module]:
|
||||||
|
"""The model class to test. Must be implemented by subclasses."""
|
||||||
|
raise NotImplementedError("Subclasses must implement the `model_class` property.")
|
||||||
|
|
||||||
|
# ==================== Optional Properties ====================
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pretrained_model_name_or_path(self) -> Optional[str]:
|
||||||
|
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pretrained_model_kwargs(self) -> Dict[str, Any]:
|
||||||
|
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self) -> Optional[tuple]:
|
||||||
|
"""Expected output shape for output validation tests."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_split_percents(self) -> list:
|
||||||
|
"""Percentages for model parallelism tests."""
|
||||||
|
return [0.5, 0.7]
|
||||||
|
|
||||||
|
# ==================== Required Methods ====================
|
||||||
|
|
||||||
|
def get_init_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Returns dict of arguments to initialize the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Initialization arguments for the model constructor.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
return {
|
||||||
|
"in_channels": 3,
|
||||||
|
"out_channels": 3,
|
||||||
|
"sample_size": 32,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
|
||||||
|
|
||||||
|
def get_dummy_inputs(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Returns dict of inputs to pass to the model forward pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Input tensors/values for model.forward().
|
||||||
|
|
||||||
|
Example:
|
||||||
|
return {
|
||||||
|
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
|
||||||
|
"timestep": torch.tensor([1], device=torch_device),
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTesterMixin:
|
||||||
|
"""
|
||||||
|
Base mixin class for model testing with common test methods.
|
||||||
|
|
||||||
|
This mixin expects the test class to also inherit from BaseModelTesterConfig
|
||||||
|
(or implement its interface) which provides:
|
||||||
|
- model_class: The model class to test
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class MyModelTestConfig(BaseModelTesterConfig):
|
||||||
|
model_class = MyModel
|
||||||
|
def get_init_dict(self): ...
|
||||||
|
def get_dummy_inputs(self): ...
|
||||||
|
|
||||||
|
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.save_pretrained(tmp_path)
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path)
|
||||||
|
new_model.to(torch_device)
|
||||||
|
|
||||||
|
# check if all parameters shape are the same
|
||||||
|
for param_name in model.state_dict().keys():
|
||||||
|
param_1 = model.state_dict()[param_name]
|
||||||
|
param_2 = new_model.state_dict()[param_name]
|
||||||
|
assert param_1.shape == param_2.shape, (
|
||||||
|
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||||
|
|
||||||
|
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.save_pretrained(tmp_path, variant="fp16")
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
|
||||||
|
|
||||||
|
# non-variant cannot be loaded
|
||||||
|
with pytest.raises(OSError) as exc_info:
|
||||||
|
self.model_class.from_pretrained(tmp_path)
|
||||||
|
|
||||||
|
# make sure that error message states what keys are missing
|
||||||
|
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||||
|
|
||||||
|
new_model.to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||||
|
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||||
|
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
|
||||||
|
|
||||||
|
model.to(dtype)
|
||||||
|
model.save_pretrained(tmp_path)
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||||
|
assert new_model.dtype == dtype
|
||||||
|
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
|
||||||
|
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
|
||||||
|
assert new_model.dtype == dtype
|
||||||
|
|
||||||
|
def test_determinism(self, atol=1e-5, rtol=0):
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
|
||||||
|
# Filter out NaN values before comparison
|
||||||
|
first_flat = first.flatten()
|
||||||
|
second_flat = second.flatten()
|
||||||
|
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||||
|
first_filtered = first_flat[mask]
|
||||||
|
second_filtered = second_flat[mask]
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_output(self, expected_output_shape=None):
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert output is not None, "Model output is None"
|
||||||
|
assert output[0].shape == expected_output_shape or self.output_shape, (
|
||||||
|
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_outputs_equivalence(self):
|
||||||
|
def set_nan_tensor_to_zero(t):
|
||||||
|
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
|
||||||
|
# Track progress in https://github.com/pytorch/pytorch/issues/77764
|
||||||
|
device = t.device
|
||||||
|
if device.type == "mps":
|
||||||
|
t = t.to("cpu")
|
||||||
|
t[t != t] = 0
|
||||||
|
return t.to(device)
|
||||||
|
|
||||||
|
def recursive_check(tuple_object, dict_object):
|
||||||
|
if isinstance(tuple_object, (list, tuple)):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif isinstance(tuple_object, dict):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif tuple_object is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
assert_tensors_close(
|
||||||
|
set_nan_tensor_to_zero(tuple_object),
|
||||||
|
set_nan_tensor_to_zero(dict_object),
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=0,
|
||||||
|
msg="Tuple and dict output are not equal",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs_dict = model(**self.get_dummy_inputs())
|
||||||
|
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||||
|
|
||||||
|
recursive_check(outputs_tuple, outputs_dict)
|
||||||
|
|
||||||
|
def test_getattr_is_correct(self, caplog):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
# save some things to test
|
||||||
|
model.dummy_attribute = 5
|
||||||
|
model.register_to_config(test_attribute=5)
|
||||||
|
|
||||||
|
logger_name = "diffusers.models.modeling_utils"
|
||||||
|
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||||
|
caplog.clear()
|
||||||
|
assert hasattr(model, "dummy_attribute")
|
||||||
|
assert getattr(model, "dummy_attribute") == 5
|
||||||
|
assert model.dummy_attribute == 5
|
||||||
|
|
||||||
|
# no warning should be thrown
|
||||||
|
assert caplog.text == ""
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||||
|
caplog.clear()
|
||||||
|
assert hasattr(model, "save_pretrained")
|
||||||
|
fn = model.save_pretrained
|
||||||
|
fn_1 = getattr(model, "save_pretrained")
|
||||||
|
|
||||||
|
assert fn == fn_1
|
||||||
|
|
||||||
|
# no warning should be thrown
|
||||||
|
assert caplog.text == ""
|
||||||
|
|
||||||
|
# warning should be thrown for config attributes accessed directly
|
||||||
|
with pytest.warns(FutureWarning):
|
||||||
|
assert model.test_attribute == 5
|
||||||
|
|
||||||
|
with pytest.warns(FutureWarning):
|
||||||
|
assert getattr(model, "test_attribute") == 5
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError) as error:
|
||||||
|
model.does_not_exist
|
||||||
|
|
||||||
|
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
||||||
|
|
||||||
|
@require_accelerator
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
torch_device not in ["cuda", "xpu"],
|
||||||
|
reason="float16 and bfloat16 can only be used with an accelerator",
|
||||||
|
)
|
||||||
|
def test_keep_in_fp32_modules(self):
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
fp32_modules = model._keep_in_fp32_modules
|
||||||
|
|
||||||
|
if fp32_modules is None or len(fp32_modules) == 0:
|
||||||
|
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
|
||||||
|
|
||||||
|
# Test with float16
|
||||||
|
model.to(torch_device)
|
||||||
|
model.to(torch.float16)
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||||
|
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
|
||||||
|
else:
|
||||||
|
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
|
||||||
|
|
||||||
|
@require_accelerator
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
torch_device not in ["cuda", "xpu"],
|
||||||
|
reason="float16 and bfloat16 can only be use for inference with an accelerator",
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||||
|
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model.to(torch_device)
|
||||||
|
fp32_modules = model._keep_in_fp32_modules
|
||||||
|
|
||||||
|
model.to(dtype).save_pretrained(tmp_path)
|
||||||
|
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
|
||||||
|
|
||||||
|
for name, param in model_loaded.named_parameters():
|
||||||
|
if fp32_modules is not None and any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||||
|
assert param.data.dtype == torch.float32
|
||||||
|
else:
|
||||||
|
assert param.data.dtype == dtype
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
|
||||||
|
|
||||||
|
@require_accelerator
|
||||||
|
def test_sharded_checkpoints(self, tmp_path):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
model_size = compute_module_persistent_sizes(model)[""]
|
||||||
|
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
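# 0.75x the total size guarantees the weights cannot fit in a single shard, so at least two shards are written.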
|
||||||
|
|
||||||
|
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||||
|
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||||
|
|
||||||
|
# Check if the right number of shards exists
|
||||||
|
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||||
|
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||||
|
assert actual_num_shards == expected_num_shards, (
|
||||||
|
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path).eval()
|
||||||
|
new_model = new_model.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
inputs_dict_new = self.get_dummy_inputs()
|
||||||
|
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_accelerator
|
||||||
|
def test_sharded_checkpoints_with_variant(self, tmp_path):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
model_size = compute_module_persistent_sizes(model)[""]
|
||||||
|
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||||
|
variant = "fp16"
|
||||||
|
|
||||||
|
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||||
|
|
||||||
|
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||||
|
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
|
||||||
|
f"Variant index file {index_filename} should exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the right number of shards exists
|
||||||
|
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
|
||||||
|
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||||
|
assert actual_num_shards == expected_num_shards, (
|
||||||
|
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
|
||||||
|
new_model = new_model.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
inputs_dict_new = self.get_dummy_inputs()
|
||||||
|
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
|
||||||
|
from diffusers.utils import constants
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
model_size = compute_module_persistent_sizes(model)[""]
|
||||||
|
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||||
|
|
||||||
|
# Save original values to restore after test
|
||||||
|
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||||
|
original_parallel_workers = getattr(constants, "DEFAULT_HF_PARALLEL_LOADING_WORKERS", None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||||
|
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||||
|
|
||||||
|
# Check if the right number of shards exists
|
||||||
|
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||||
|
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||||
|
assert actual_num_shards == expected_num_shards, (
|
||||||
|
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load without parallel loading
|
||||||
|
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||||
|
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
|
||||||
|
model_sequential = model_sequential.to(torch_device)
|
||||||
|
|
||||||
|
# Load with parallel loading
|
||||||
|
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||||
|
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
|
||||||
|
model_parallel = model_parallel.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
inputs_dict_parallel = self.get_dummy_inputs()
|
||||||
|
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading"
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original values
|
||||||
|
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||||
|
if original_parallel_workers is not None:
|
||||||
|
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = original_parallel_workers
|
||||||
|
|
||||||
|
@require_torch_multi_accelerator
|
||||||
|
def test_model_parallelism(self, tmp_path):
|
||||||
|
if self.model_class._no_split_modules is None:
|
||||||
|
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||||
|
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
model_size = compute_module_sizes(model)[""]
|
||||||
|
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||||
|
|
||||||
|
model.cpu().save_pretrained(tmp_path)
|
||||||
|
|
||||||
|
for max_size in max_gpu_sizes:
|
||||||
|
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||||
|
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
|
||||||
|
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||||
|
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||||
|
|
||||||
|
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism"
|
||||||
|
)
|
||||||
160
tests/models/testing_utils/compile.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...testing_utils import (
|
||||||
|
backend_empty_cache,
|
||||||
|
is_torch_compile,
|
||||||
|
require_accelerator,
|
||||||
|
require_torch_version_greater,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_torch_compile
|
||||||
|
@require_accelerator
|
||||||
|
@require_torch_version_greater("2.7.1")
|
||||||
|
class TorchCompileTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing torch.compile functionality on models.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest mark: compile
|
||||||
|
Use `pytest -m "not compile"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
different_shapes_for_compilation = None
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
torch.compiler.reset()
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
torch.compiler.reset()
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def test_torch_compile_recompilation_and_graph_break(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
model = torch.compile(model, fullgraph=True)
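# `fullgraph=True` makes Dynamo raise on graph breaks; `error_on_recompile=True` below ensures the
# second forward pass reuses the compiled graph instead of recompiling.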
|
||||||
|
|
||||||
|
with (
|
||||||
|
torch._inductor.utils.fresh_inductor_cache(),
|
||||||
|
torch._dynamo.config.patch(error_on_recompile=True),
|
||||||
|
torch.no_grad(),
|
||||||
|
):
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
|
def test_torch_compile_repeated_blocks(self):
|
||||||
|
if self.model_class._repeated_blocks is None:
|
||||||
|
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
model.compile_repeated_blocks(fullgraph=True)
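# `compile_repeated_blocks` compiles only the modules listed in `_repeated_blocks`, so a single
# compiled graph is reused across the repeated blocks instead of compiling the whole model.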
|
||||||
|
|
||||||
|
recompile_limit = 1
|
||||||
|
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||||
|
recompile_limit = 2
|
||||||
|
|
||||||
|
with (
|
||||||
|
torch._inductor.utils.fresh_inductor_cache(),
|
||||||
|
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||||
|
torch.no_grad(),
|
||||||
|
):
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
|
def test_compile_with_group_offloading(self):
|
||||||
|
if not self.model_class._supports_group_offloading:
|
||||||
|
pytest.skip("Model does not support group offloading.")
|
||||||
|
|
||||||
|
torch._dynamo.config.cache_size_limit = 10000
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
group_offload_kwargs = {
|
||||||
|
"onload_device": torch_device,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"offload_type": "block_level",
|
||||||
|
"num_blocks_per_group": 1,
|
||||||
|
"use_stream": True,
|
||||||
|
"non_blocking": True,
|
||||||
|
}
|
||||||
|
model.enable_group_offload(**group_offload_kwargs)
|
||||||
|
model.compile()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
|
def test_compile_on_different_shapes(self):
|
||||||
|
if self.different_shapes_for_compilation is None:
|
||||||
|
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||||
|
torch.fx.experimental._config.use_duck_shape = False
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||||
|
|
||||||
|
for height, width in self.different_shapes_for_compilation:
|
||||||
|
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
|
||||||
|
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
|
def test_compile_works_with_aot(self, tmp_path):
|
||||||
|
from torch._inductor.package import load_package
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
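# Ahead-of-time path: compile the exported program into a portable .pt2 package, then reload it
# and use the loaded binary in place of `model.forward`.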
|
||||||
|
|
||||||
|
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
|
||||||
|
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||||
|
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||||
|
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||||
|
|
||||||
|
model.forward = loaded_binary
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
_ = model(**inputs_dict)
|
||||||
138
tests/models/testing_utils/ip_adapter.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...testing_utils import is_ip_adapter, torch_device
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
|
||||||
|
"""
|
||||||
|
Check if IP Adapter processors are correctly set in the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to check
processor_cls: The IP Adapter attention processor class to look for among `model.attn_processors`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if IP Adapter is correctly set, False otherwise
|
||||||
|
"""
|
||||||
|
for module in model.attn_processors.values():
|
||||||
|
if isinstance(module, processor_cls):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@is_ip_adapter
|
||||||
|
class IPAdapterTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing IP Adapter functionality on models.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest mark: ip_adapter
|
||||||
|
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
ip_adapter_processor_cls = None
|
||||||
|
|
||||||
|
def create_ip_adapter_state_dict(self, model):
|
||||||
|
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||||
|
|
||||||
|
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||||
|
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
|
||||||
|
|
||||||
|
def test_load_ip_adapter(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||||
|
|
||||||
|
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||||
|
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||||
|
"IP Adapter processors not set correctly"
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||||
|
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
|
||||||
|
"Output should differ with IP Adapter enabled"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
|
||||||
|
)
|
||||||
|
def test_ip_adapter_scale(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||||
|
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||||
|
|
||||||
|
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||||
|
|
||||||
|
# Test scale = 0.0 (no effect)
|
||||||
|
model.set_ip_adapter_scale(0.0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||||
|
|
||||||
|
# Test scale = 1.0 (full effect)
|
||||||
|
model.set_ip_adapter_scale(1.0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||||
|
|
||||||
|
# Outputs should differ with different scales
|
||||||
|
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
|
||||||
|
"Output should differ with different IP Adapter scales"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
|
||||||
|
)
|
||||||
|
def test_unload_ip_adapter(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
# Save original processors
|
||||||
|
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||||
|
|
||||||
|
# Create and load IP adapter
|
||||||
|
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||||
|
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||||
|
|
||||||
|
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
|
||||||
|
|
||||||
|
# Unload IP adapter
|
||||||
|
model.unload_ip_adapter()
|
||||||
|
|
||||||
|
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||||
|
"IP Adapter should be unloaded"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify processors are restored
|
||||||
|
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||||
|
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||||
548
tests/models/testing_utils/lora.py
Normal file
@@ -0,0 +1,548 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from diffusers.utils.import_utils import is_peft_available
|
||||||
|
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||||
|
|
||||||
|
from ...testing_utils import (
|
||||||
|
assert_tensors_close,
|
||||||
|
backend_empty_cache,
|
||||||
|
is_lora,
|
||||||
|
is_torch_compile,
|
||||||
|
require_peft_backend,
|
||||||
|
require_peft_version_greater,
|
||||||
|
require_torch_accelerator,
|
||||||
|
require_torch_version_greater,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_peft_available():
|
||||||
|
from diffusers.loaders.peft import PeftAdapterMixin
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_lora_correctly_set(model) -> bool:
|
||||||
|
"""
|
||||||
|
Check if LoRA layers are correctly set in the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if LoRA is correctly set, False otherwise
|
||||||
|
"""
|
||||||
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||||
|
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@is_lora
|
||||||
|
@require_peft_backend
|
||||||
|
class LoraTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing LoRA/PEFT functionality on models.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest mark: lora
|
||||||
|
Use `pytest -m "not lora"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||||
|
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||||
|
|
||||||
|
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
|
||||||
|
from peft import LoraConfig
|
||||||
|
from peft.utils import get_peft_model_state_dict
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
denoiser_lora_config = LoraConfig(
|
||||||
|
r=rank,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||||
|
init_lora_weights=False,
|
||||||
|
use_dora=use_dora,
|
||||||
|
)
|
||||||
|
model.add_adapter(denoiser_lora_config)
|
||||||
|
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4), (
|
||||||
|
"Output should differ with LoRA enabled"
|
||||||
|
)
|
||||||
|
|
||||||
|
model.save_lora_adapter(tmp_path)
|
||||||
|
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
|
||||||
|
"LoRA weights file not created"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
|
||||||
|
|
||||||
|
model.unload_lora()
|
||||||
|
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||||
|
|
||||||
|
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||||
|
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||||
|
|
||||||
|
for k in state_dict_loaded:
|
||||||
|
loaded_v = state_dict_loaded[k]
|
||||||
|
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||||
|
assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}")
|
||||||
|
|
||||||
|
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
|
||||||
|
"Output should differ with LoRA enabled"
|
||||||
|
)
|
||||||
|
assert_tensors_close(
|
||||||
|
outputs_with_lora,
|
||||||
|
outputs_with_lora_2,
|
||||||
|
atol=1e-4,
|
||||||
|
rtol=1e-4,
|
||||||
|
msg="Outputs should match before and after save/load",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
denoiser_lora_config = LoraConfig(
|
||||||
|
r=4,
|
||||||
|
lora_alpha=4,
|
||||||
|
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||||
|
init_lora_weights=False,
|
||||||
|
use_dora=False,
|
||||||
|
)
|
||||||
|
model.add_adapter(denoiser_lora_config)
|
||||||
|
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||||
|
|
||||||
|
wrong_name = "foo"
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
|
||||||
|
|
||||||
|
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
denoiser_lora_config = LoraConfig(
|
||||||
|
r=rank,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||||
|
init_lora_weights=False,
|
||||||
|
use_dora=use_dora,
|
||||||
|
)
|
||||||
|
model.add_adapter(denoiser_lora_config)
|
||||||
|
metadata = model.peft_config["default"].to_dict()
|
||||||
|
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||||
|
|
||||||
|
model.save_lora_adapter(tmp_path)
|
||||||
|
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||||
|
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||||
|
|
||||||
|
model.unload_lora()
|
||||||
|
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||||
|
|
||||||
|
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||||
|
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||||
|
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||||
|
|
||||||
|
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
denoiser_lora_config = LoraConfig(
|
||||||
|
r=4,
|
||||||
|
lora_alpha=4,
|
||||||
|
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||||
|
init_lora_weights=False,
|
||||||
|
use_dora=False,
|
||||||
|
)
|
||||||
|
model.add_adapter(denoiser_lora_config)
|
||||||
|
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||||
|
|
||||||
|
model.save_lora_adapter(tmp_path)
|
||||||
|
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||||
|
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||||
|
|
||||||
|
# Perturb the metadata in the state dict
|
||||||
|
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||||
|
metadata = {"format": "pt"}
|
||||||
|
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||||
|
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||||
|
for key, value in lora_adapter_metadata.items():
|
||||||
|
if isinstance(value, set):
|
||||||
|
lora_adapter_metadata[key] = list(value)
|
||||||
|
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||||
|
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||||
|
|
||||||
|
model.unload_lora()
|
||||||
|
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||||
|
|
||||||
|
with pytest.raises(TypeError) as exc_info:
|
||||||
|
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||||
|
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@is_lora
|
||||||
|
@is_torch_compile
|
||||||
|
@require_peft_backend
|
||||||
|
@require_peft_version_greater("0.14.0")
|
||||||
|
@require_torch_version_greater("2.7.1")
|
||||||
|
@require_torch_accelerator
|
||||||
|
class LoraHotSwappingForModelTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing LoRA hot swapping functionality on models.
|
||||||
|
|
||||||
|
Test that hotswapping does not result in recompilation on the model directly.
|
||||||
|
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
|
||||||
|
and is extensively tested there. The goal of this test is specifically to ensure that
|
||||||
|
hotswapping with diffusers does not require recompilation.
|
||||||
|
|
||||||
|
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
|
||||||
|
for the analogous PEFT test.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest marks: lora, torch_compile
|
||||||
|
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
different_shapes_for_compilation = None
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||||
|
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||||
|
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||||
|
torch.compiler.reset()
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=lora_rank,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
target_modules=target_modules,
|
||||||
|
init_lora_weights=False,
|
||||||
|
use_dora=False,
|
||||||
|
)
|
||||||
|
return lora_config
|
||||||
|
|
||||||
|
def _get_linear_module_name_other_than_attn(self, model):
|
||||||
|
linear_names = [
|
||||||
|
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
|
||||||
|
]
|
||||||
|
return linear_names[0]
|
||||||
|
|
||||||
|
def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None):
|
||||||
|
"""
|
||||||
|
Check that hotswapping works on a model.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
- create 2 LoRA adapters and save them
|
||||||
|
- load the first adapter
|
||||||
|
- hotswap the second adapter
|
||||||
|
- check that the outputs are correct
|
||||||
|
- optionally compile the model
|
||||||
|
- optionally check if recompilations happen on different shapes
|
||||||
|
|
||||||
|
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||||
|
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||||
|
fine.
|
||||||
|
"""
|
||||||
|
different_shapes = self.different_shapes_for_compilation
|
||||||
|
# create 2 adapters with different ranks and alphas
|
||||||
|
torch.manual_seed(0)
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
alpha0, alpha1 = rank0, rank1
|
||||||
|
max_rank = max([rank0, rank1])
|
||||||
|
if target_modules1 is None:
|
||||||
|
target_modules1 = target_modules0[:]
|
||||||
|
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
|
||||||
|
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
|
||||||
|
|
||||||
|
model.add_adapter(lora_config0, adapter_name="adapter0")
|
||||||
|
with torch.inference_mode():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output0_before = model(**inputs_dict)["sample"]
|
||||||
|
|
||||||
|
model.add_adapter(lora_config1, adapter_name="adapter1")
|
||||||
|
model.set_adapter("adapter1")
|
||||||
|
with torch.inference_mode():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output1_before = model(**inputs_dict)["sample"]
|
||||||
|
|
||||||
|
# sanity checks:
|
||||||
|
tol = 5e-3
|
||||||
|
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
|
||||||
|
assert not (output0_before == 0).all()
|
||||||
|
assert not (output1_before == 0).all()
|
||||||
|
|
||||||
|
# save the adapter checkpoints
|
||||||
|
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||||
|
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||||
|
del model
|
||||||
|
|
||||||
|
# load the first adapter
|
||||||
|
torch.manual_seed(0)
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
|
||||||
|
if do_compile or (rank0 != rank1):
|
||||||
|
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||||
|
model.enable_lora_hotswap(target_rank=max_rank)
|
||||||
|
|
||||||
|
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
|
||||||
|
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
|
||||||
|
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||||
|
|
||||||
|
if do_compile:
|
||||||
|
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
# additionally check if dynamic compilation works.
|
||||||
|
if different_shapes is not None:
|
||||||
|
for height, width in different_shapes:
|
||||||
|
new_inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||||
|
_ = model(**new_inputs_dict)
|
||||||
|
else:
|
||||||
|
output0_after = model(**inputs_dict)["sample"]
|
||||||
|
assert_tensors_close(
|
||||||
|
output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# hotswap the 2nd adapter
|
||||||
|
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||||
|
|
||||||
|
# we need to call forward to potentially trigger recompilation
|
||||||
|
with torch.inference_mode():
|
||||||
|
if different_shapes is not None:
|
||||||
|
for height, width in different_shapes:
|
||||||
|
new_inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||||
|
_ = model(**new_inputs_dict)
|
||||||
|
else:
|
||||||
|
output1_after = model(**inputs_dict)["sample"]
|
||||||
|
assert_tensors_close(
|
||||||
|
output1_before,
|
||||||
|
output1_after,
|
||||||
|
atol=tol,
|
||||||
|
rtol=tol,
|
||||||
|
msg="Output mismatch after hotswapping to adapter1",
|
||||||
|
)
|
||||||
|
|
||||||
|
# check error when not passing valid adapter name
|
||||||
|
name = "does-not-exist"
|
||||||
|
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
|
||||||
|
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||||
|
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||||
|
def test_hotswapping_model(self, tmp_path, rank0, rank1):
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||||
|
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
|
||||||
|
# It's important to add this context to raise an error on recompilation
|
||||||
|
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||||
|
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||||
|
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
|
||||||
|
if "unet" not in self.model_class.__name__.lower():
|
||||||
|
pytest.skip("Test only applies to UNet.")
|
||||||
|
|
||||||
|
# It's important to add this context to raise an error on recompilation
|
||||||
|
target_modules = ["conv", "conv1", "conv2"]
|
||||||
|
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||||
|
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
|
||||||
|
if "unet" not in self.model_class.__name__.lower():
|
||||||
|
pytest.skip("Test only applies to UNet.")
|
||||||
|
|
||||||
|
# It's important to add this context to raise an error on recompilation
|
||||||
|
target_modules = ["to_q", "conv"]
|
||||||
|
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||||
|
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
|
||||||
|
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
|
||||||
|
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
|
||||||
|
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
|
||||||
|
# block.
|
||||||
|
target_modules = ["to_q"]
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
target_modules.append(self._get_linear_module_name_other_than_attn(model))
|
||||||
|
del model
|
||||||
|
|
||||||
|
# It's important to add this context to raise an error on recompilation
|
||||||
|
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||||
|
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||||
|
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.add_adapter(lora_config)
|
||||||
|
|
||||||
|
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||||
|
with pytest.raises(RuntimeError, match=msg):
|
||||||
|
model.enable_lora_hotswap(target_rank=32)
|
||||||
|
|
||||||
|
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
|
||||||
|
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||||
|
import logging
|
||||||
|
|
||||||
|
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.add_adapter(lora_config)
|
||||||
|
msg = (
|
||||||
|
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||||
|
)
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||||
|
assert any(msg in record.message for record in caplog.records)
|
||||||
|
|
||||||
|
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||||
|
# check possibility to ignore the error/warning
|
||||||
|
import logging
|
||||||
|
|
||||||
|
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.add_adapter(lora_config)
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||||
|
assert len(caplog.records) == 0
|
||||||
|
|
||||||
|
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||||
|
# check that wrong argument value raises an error
|
||||||
|
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
model = self.model_class(**init_dict).to(torch_device)
|
||||||
|
model.add_adapter(lora_config)
|
||||||
|
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||||
|
with pytest.raises(ValueError, match=msg):
|
||||||
|
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||||
|
|
||||||
|
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
|
||||||
|
# check the error and log
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||||
|
target_modules0 = ["to_q"]
|
||||||
|
target_modules1 = ["to_q", "to_k"]
|
||||||
|
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||||
|
with caplog.at_level(logging.ERROR):
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path,
|
||||||
|
do_compile=True,
|
||||||
|
rank0=8,
|
||||||
|
rank1=8,
|
||||||
|
target_modules0=target_modules0,
|
||||||
|
target_modules1=target_modules1,
|
||||||
|
)
|
||||||
|
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||||
|
@require_torch_version_greater("2.7.1")
|
||||||
|
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
|
||||||
|
different_shapes_for_compilation = self.different_shapes_for_compilation
|
||||||
|
if different_shapes_for_compilation is None:
|
||||||
|
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||||
|
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
|
||||||
|
# variable to represent input sizes that are the same. For more details,
|
||||||
|
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||||
|
torch.fx.experimental._config.use_duck_shape = False
|
||||||
|
|
||||||
|
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||||
|
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||||
|
self._check_model_hotswap(
|
||||||
|
tmp_path,
|
||||||
|
do_compile=True,
|
||||||
|
rank0=rank0,
|
||||||
|
rank1=rank1,
|
||||||
|
target_modules0=target_modules,
|
||||||
|
)
|
||||||
493
tests/models/testing_utils/memory.py
Normal file
@@ -0,0 +1,493 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import glob
|
||||||
|
import inspect
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from accelerate.utils.modeling import compute_module_sizes
|
||||||
|
|
||||||
|
from diffusers.utils.testing_utils import _check_safetensors_serialization
|
||||||
|
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
|
||||||
|
|
||||||
|
from ...testing_utils import (
|
||||||
|
assert_tensors_close,
|
||||||
|
backend_empty_cache,
|
||||||
|
backend_max_memory_allocated,
|
||||||
|
backend_reset_peak_memory_stats,
|
||||||
|
backend_synchronize,
|
||||||
|
is_cpu_offload,
|
||||||
|
is_group_offload,
|
||||||
|
is_memory,
|
||||||
|
require_accelerator,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
from .common import check_device_map_is_respected
|
||||||
|
|
||||||
|
|
||||||
|
def cast_maybe_tensor_dtype(inputs_dict, from_dtype, to_dtype):
|
||||||
|
"""Helper to cast tensor inputs from one dtype to another."""
|
||||||
|
for key, value in inputs_dict.items():
|
||||||
|
if isinstance(value, torch.Tensor) and value.dtype == from_dtype:
|
||||||
|
inputs_dict[key] = value.to(to_dtype)
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
def require_offload_support(func):
|
||||||
|
"""
|
||||||
|
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
if self.model_class._no_split_modules is None:
|
||||||
|
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def require_group_offload_support(func):
|
||||||
|
"""
|
||||||
|
Decorator to skip tests if model doesn't support group offloading.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
if not self.model_class._supports_group_offloading:
|
||||||
|
pytest.skip("Model does not support group offloading.")
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@is_cpu_offload
|
||||||
|
class CPUOffloadTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing CPU offloading functionality.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
- model_split_percents: List of percentages for splitting model across devices
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest mark: cpu_offload
|
||||||
|
Use `pytest -m "not cpu_offload"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_split_percents = [0.5, 0.7]
|
||||||
|
|
||||||
|
@require_offload_support
|
||||||
|
def test_cpu_offload(self, tmp_path):
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
base_output = model(**inputs_dict)
|
||||||
|
|
||||||
|
model_size = compute_module_sizes(model)[""]
|
||||||
|
# We test several splits of sizes to make sure it works
|
||||||
|
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||||
|
model.cpu().save_pretrained(str(tmp_path))
|
||||||
|
|
||||||
|
for max_size in max_gpu_sizes:
|
||||||
|
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||||
|
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||||
|
# Making sure part of the model will actually end up offloaded
|
||||||
|
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||||
|
|
||||||
|
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
new_output = new_model(**inputs_dict)
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_offload_support
|
||||||
|
def test_disk_offload_without_safetensors(self, tmp_path):
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
base_output = model(**inputs_dict)
|
||||||
|
|
||||||
|
model_size = compute_module_sizes(model)[""]
|
||||||
|
max_size = int(self.model_split_percents[0] * model_size)
|
||||||
|
# Force disk offload by setting very small CPU memory
|
||||||
|
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||||
|
|
||||||
|
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
|
||||||
|
# This errors out because it's missing an offload folder
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||||
|
|
||||||
|
new_model = self.model_class.from_pretrained(
|
||||||
|
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
new_output = new_model(**inputs_dict)
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_offload_support
|
||||||
|
def test_disk_offload_with_safetensors(self, tmp_path):
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
base_output = model(**inputs_dict)
|
||||||
|
|
||||||
|
model_size = compute_module_sizes(model)[""]
|
||||||
|
model.cpu().save_pretrained(str(tmp_path))
|
||||||
|
|
||||||
|
max_size = int(self.model_split_percents[0] * model_size)
|
||||||
|
max_memory = {0: max_size, "cpu": max_size}
|
||||||
|
new_model = self.model_class.from_pretrained(
|
||||||
|
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
|
||||||
|
)
|
||||||
|
|
||||||
|
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
new_output = new_model(**inputs_dict)
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
base_output[0],
|
||||||
|
new_output[0],
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=0,
|
||||||
|
msg="Output should match with disk offloading (safetensors)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_group_offload
|
||||||
|
class GroupOffloadTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing group offloading functionality.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest mark: group_offload
|
||||||
|
Use `pytest -m "not group_offload"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
@require_group_offload_support
|
||||||
|
def test_group_offloading(self, record_stream=False):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run_forward(model):
|
||||||
|
assert all(
|
||||||
|
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||||
|
for module in model.modules()
|
||||||
|
if hasattr(module, "_diffusers_hook")
|
||||||
|
), "Group offloading hook should be set"
|
||||||
|
model.eval()
|
||||||
|
return model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
output_without_group_offloading = run_forward(model)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||||
|
output_with_group_offloading1 = run_forward(model)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||||
|
output_with_group_offloading2 = run_forward(model)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||||
|
output_with_group_offloading3 = run_forward(model)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.enable_group_offload(
|
||||||
|
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
|
||||||
|
)
|
||||||
|
output_with_group_offloading4 = run_forward(model)
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
output_without_group_offloading,
|
||||||
|
output_with_group_offloading1,
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=0,
|
||||||
|
msg="Output should match with block-level offloading",
|
||||||
|
)
|
||||||
|
assert_tensors_close(
|
||||||
|
output_without_group_offloading,
|
||||||
|
output_with_group_offloading2,
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=0,
|
||||||
|
msg="Output should match with non-blocking block-level offloading",
|
||||||
|
)
|
||||||
|
assert_tensors_close(
|
||||||
|
output_without_group_offloading,
|
||||||
|
output_with_group_offloading3,
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=0,
|
||||||
|
msg="Output should match with leaf-level offloading",
|
||||||
|
)
|
||||||
|
assert_tensors_close(
|
||||||
|
output_without_group_offloading,
|
||||||
|
output_with_group_offloading4,
|
||||||
|
atol=1e-5,
|
||||||
|
rtol=0,
|
||||||
|
msg="Output should match with leaf-level offloading with stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_group_offload_support
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
_ = model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||||
|
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.eval()
|
||||||
|
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||||
|
model.enable_group_offload(
|
||||||
|
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||||
|
)
|
||||||
|
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||||
|
_ = model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
@require_group_offload_support
|
||||||
|
@torch.no_grad()
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5):
|
||||||
|
def _has_generator_arg(model):
|
||||||
|
sig = inspect.signature(model.forward)
|
||||||
|
params = sig.parameters
|
||||||
|
return "generator" in params
|
||||||
|
|
||||||
|
def _run_forward(model, inputs_dict):
|
||||||
|
accepts_generator = _has_generator_arg(model)
|
||||||
|
if accepts_generator:
|
||||||
|
inputs_dict["generator"] = torch.manual_seed(0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
return model(**inputs_dict)[0]
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
model.to(torch_device)
|
||||||
|
output_without_group_offloading = _run_forward(model, inputs_dict)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||||
|
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||||
|
tmpdir = str(tmp_path)
|
||||||
|
model.enable_group_offload(
|
||||||
|
torch_device,
|
||||||
|
offload_type=offload_type,
|
||||||
|
offload_to_disk_path=tmpdir,
|
||||||
|
use_stream=True,
|
||||||
|
record_stream=record_stream,
|
||||||
|
**additional_kwargs,
|
||||||
|
)
|
||||||
|
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||||
|
assert has_safetensors, "No safetensors found in the directory."
|
||||||
|
|
||||||
|
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||||
|
# in nature. So, skip it.
|
||||||
|
if offload_type != "leaf_level":
|
||||||
|
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||||
|
module=model,
|
||||||
|
offload_to_disk_path=tmpdir,
|
||||||
|
offload_type=offload_type,
|
||||||
|
num_blocks_per_group=num_blocks_per_group,
|
||||||
|
)
|
||||||
|
if not is_correct:
|
||||||
|
if extra_files:
|
||||||
|
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||||
|
elif missing_files:
|
||||||
|
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||||
|
|
||||||
|
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||||
|
assert_tensors_close(
|
||||||
|
output_without_group_offloading,
|
||||||
|
output_with_group_offloading,
|
||||||
|
atol=atol,
|
||||||
|
rtol=0,
|
||||||
|
msg="Output should match with disk-based group offloading",
|
||||||
|
)
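Illustrative sketch (not part of this diff): the group-offloading API that the mixin above exercises, applied directly to a model. The checkpoint id is the Flux repo used elsewhere in this change; the argument names mirror the calls made in test_group_offloading and test_group_offloading_with_disk.

import torch

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Block-level: keep one group of transformer blocks on the accelerator at a time.
model.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=1)

# Leaf-level with a transfer stream, optionally spilling offloaded weights to disk:
# model.enable_group_offload(
#     torch.device("cuda"), offload_type="leaf_level", use_stream=True, offload_to_disk_path="/tmp/offload"
# )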
|
||||||
|
|
||||||
|
|
||||||
|
class LayerwiseCastingTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing layerwise dtype casting for memory optimization.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_layerwise_casting_memory(self):
|
||||||
|
MB_TOLERANCE = 0.2
|
||||||
|
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||||
|
|
||||||
|
def reset_memory_stats():
|
||||||
|
gc.collect()
|
||||||
|
backend_synchronize(torch_device)
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
backend_reset_peak_memory_stats(torch_device)
|
||||||
|
|
||||||
|
def get_memory_usage(storage_dtype, compute_dtype):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
config = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||||
|
model = self.model_class(**config).eval()
|
||||||
|
model = model.to(torch_device, dtype=compute_dtype)
|
||||||
|
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||||
|
|
||||||
|
reset_memory_stats()
|
||||||
|
model(**inputs_dict)
|
||||||
|
model_memory_footprint = model.get_memory_footprint()
|
||||||
|
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
|
||||||
|
|
||||||
|
return model_memory_footprint, peak_inference_memory_allocated_mb
|
||||||
|
|
||||||
|
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
|
||||||
|
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
|
||||||
|
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
|
||||||
|
torch.float8_e4m3fn, torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
|
||||||
|
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
|
||||||
|
"Memory footprint should decrease with lower precision storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
|
||||||
|
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
|
||||||
|
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
|
||||||
|
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
|
||||||
|
"Peak memory should be lower with bf16 compute on newer GPUs"
|
||||||
|
)
|
||||||
|
|
||||||
|
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
|
||||||
|
# bytes. This only happens for some models, so we allow a small tolerance.
|
||||||
|
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
|
||||||
|
assert (
|
||||||
|
fp8_e4m3_fp32_max_memory < fp32_max_memory
|
||||||
|
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||||
|
), "Peak memory should be lower or within tolerance with fp8 storage"
|
||||||
|
|
||||||
|
def test_layerwise_casting_training(self):
|
||||||
|
def test_fn(storage_dtype, compute_dtype):
|
||||||
|
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
|
||||||
|
pytest.skip("Skipping test because bfloat16 compute is not reliably supported on CPU.")
|
||||||
|
|
||||||
|
model = self.model_class(**self.get_init_dict())
|
||||||
|
model = model.to(torch_device, dtype=compute_dtype)
|
||||||
|
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||||
|
with torch.amp.autocast(device_type=torch.device(torch_device).type):
|
||||||
|
output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
input_tensor = inputs_dict[self.main_input_name]
|
||||||
|
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
|
||||||
|
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
|
||||||
|
loss = torch.nn.functional.mse_loss(output, noise)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
test_fn(torch.float16, torch.float32)
|
||||||
|
test_fn(torch.float8_e4m3fn, torch.float32)
|
||||||
|
test_fn(torch.float8_e5m2, torch.float32)
|
||||||
|
test_fn(torch.float8_e4m3fn, torch.bfloat16)
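Illustrative sketch (not part of this diff): the layerwise-casting call the tests above make, applied directly to a model — parameters are stored in fp8 and upcast to the compute dtype layer by layer during the forward pass. Model, device, and dtypes are example choices.

import torch

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Same call the tests exercise: fp8 storage, bfloat16 compute.
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)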
|
||||||
|
|
||||||
|
|
||||||
|
@is_memory
|
||||||
|
@require_accelerator
|
||||||
|
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
|
||||||
|
"""
|
||||||
|
Combined mixin class for all memory optimization tests including CPU/disk offloading,
|
||||||
|
group offloading, and layerwise dtype casting.
|
||||||
|
|
||||||
|
This mixin inherits from:
|
||||||
|
- CPUOffloadTesterMixin: CPU and disk offloading tests
|
||||||
|
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
|
||||||
|
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Pytest mark: memory
|
||||||
|
Use `pytest -m "not memory"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
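Illustrative sketch (not part of this diff): wiring a small config class into MemoryTesterMixin. The tiny UNet2DModel config and input shapes are illustrative choices, not taken from this change; only the attribute/method contract documented above (model_class, get_init_dict, get_dummy_inputs) is assumed, and the relative import assumes the sketch lives alongside the other model tests.

import torch

from diffusers import UNet2DModel

from ..testing_utils import MemoryTesterMixin


class TinyUNetTesterConfig:
    model_class = UNet2DModel

    def get_init_dict(self):
        # Deliberately tiny so the offloading/casting tests stay fast.
        return {
            "sample_size": 8,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 1,
            "block_out_channels": (8, 8),
            "norm_num_groups": 4,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        return {"sample": torch.randn(1, 3, 8, 8), "timestep": torch.tensor([1])}


class TestTinyUNetMemory(TinyUNetTesterConfig, MemoryTesterMixin):
    # Selected with `pytest -m memory`; excluded with `pytest -m "not memory"`.
    pass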
|
||||||
102
tests/models/testing_utils/parallelism.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||||
|
|
||||||
|
from ...testing_utils import (
|
||||||
|
is_context_parallel,
|
||||||
|
require_torch_multi_accelerator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
|
||||||
|
try:
|
||||||
|
# Setup distributed environment
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "12355"
|
||||||
|
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
init_method="env://",
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
|
||||||
|
model = model_class(**init_dict)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
inputs_on_device = {}
|
||||||
|
for key, value in inputs_dict.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
inputs_on_device[key] = value.to(device)
|
||||||
|
else:
|
||||||
|
inputs_on_device[key] = value
|
||||||
|
|
||||||
|
cp_config = ContextParallelConfig(**cp_dict)
|
||||||
|
model.enable_parallelism(config=cp_config)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(**inputs_on_device, return_dict=False)[0]
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
result_queue.put(("success", output.shape))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if rank == 0:
|
||||||
|
result_queue.put(("error", str(e)))
|
||||||
|
finally:
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
@is_context_parallel
|
||||||
|
@require_torch_multi_accelerator
|
||||||
|
class ContextParallelTesterMixin:
|
||||||
|
base_precision = 1e-3
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||||
|
def test_context_parallel_inference(self, cp_type):
|
||||||
|
if not torch.distributed.is_available():
|
||||||
|
pytest.skip("torch.distributed is not available.")
|
||||||
|
|
||||||
|
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||||
|
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||||
|
|
||||||
|
world_size = 2
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
cp_dict = {cp_type: world_size}
|
||||||
|
|
||||||
|
ctx = mp.get_context("spawn")
|
||||||
|
result_queue = ctx.Queue()
|
||||||
|
|
||||||
|
mp.spawn(
|
||||||
|
_context_parallel_worker,
|
||||||
|
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
|
||||||
|
nprocs=world_size,
|
||||||
|
join=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
status, result = result_queue.get(timeout=60)
|
||||||
|
assert status == "success", f"Context parallel inference failed: {result}"
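Illustrative sketch (not part of this diff): the context-parallel API that the worker above drives, written as a script meant to be launched with two processes (for example via `torchrun --nproc_per_node=2`). The model is an example; `ContextParallelConfig` and `enable_parallelism` are used exactly as in `_context_parallel_worker`.

import torch
import torch.distributed as dist

from diffusers import FluxTransformer2DModel
from diffusers.models._modeling_parallel import ContextParallelConfig

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to(f"cuda:{rank}")
model.eval()

# Split the sequence dimension across ranks (Ulysses-style); ring attention would pass ring_degree instead.
model.enable_parallelism(config=ContextParallelConfig(ulysses_degree=dist.get_world_size()))

# ... run the usual forward pass on each rank ...

dist.destroy_process_group()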
|
||||||
1291
tests/models/testing_utils/quantization.py
Normal file
File diff suppressed because it is too large
244
tests/models/testing_utils/single_file.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
|
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||||
|
|
||||||
|
from ...testing_utils import (
|
||||||
|
assert_tensors_close,
|
||||||
|
backend_empty_cache,
|
||||||
|
is_single_file,
|
||||||
|
nightly,
|
||||||
|
require_torch_accelerator,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
|
||||||
|
"""Download a single file checkpoint from the Hub to a temporary directory."""
|
||||||
|
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
|
||||||
|
"""Download diffusers config files (excluding weights) from a repository."""
|
||||||
|
path = snapshot_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
ignore_patterns=[
|
||||||
|
"**/*.ckpt",
|
||||||
|
"*.ckpt",
|
||||||
|
"**/*.bin",
|
||||||
|
"*.bin",
|
||||||
|
"**/*.pt",
|
||||||
|
"*.pt",
|
||||||
|
"**/*.safetensors",
|
||||||
|
"*.safetensors",
|
||||||
|
],
|
||||||
|
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||||
|
local_dir=tmpdir,
|
||||||
|
)
|
||||||
|
return path
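Illustrative sketch (not part of this diff): how the two helpers above combine with `_extract_repo_id_and_weights_name` to stage a checkpoint and its diffusers config locally, mirroring the local_files_only tests below. It assumes the helpers above are in scope; the checkpoint URL is the Flux one used elsewhere in this change.

import tempfile

from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

with tempfile.TemporaryDirectory() as tmpdir:
    repo_id, weight_name = _extract_repo_id_and_weights_name(ckpt_path)
    local_ckpt = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
    local_config = download_diffusers_config("black-forest-labs/FLUX.1-dev", tmpdir)

    # Everything needed for loading is now on disk, so no network access is required.
    model = FluxTransformer2DModel.from_single_file(local_ckpt, config=local_config, local_files_only=True)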
|
||||||
|
|
||||||
|
|
||||||
|
@nightly
|
||||||
|
@require_torch_accelerator
|
||||||
|
@is_single_file
|
||||||
|
class SingleFileTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing single file loading for models.
|
||||||
|
|
||||||
|
Expected class attributes:
|
||||||
|
- model_class: The model class to test
|
||||||
|
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||||
|
- ckpt_path: Path or Hub path to the single file checkpoint
|
||||||
|
- subfolder: (Optional) Subfolder within the repo
|
||||||
|
- torch_dtype: (Optional) torch dtype to use for testing
|
||||||
|
|
||||||
|
Pytest mark: single_file
|
||||||
|
Use `pytest -m "not single_file"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
pretrained_model_name_or_path = None
|
||||||
|
ckpt_path = None
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def test_single_file_model_config(self):
|
||||||
|
pretrained_kwargs = {}
|
||||||
|
single_file_kwargs = {}
|
||||||
|
|
||||||
|
pretrained_kwargs["device"] = torch_device
|
||||||
|
single_file_kwargs["device"] = torch_device
|
||||||
|
|
||||||
|
if hasattr(self, "subfolder") and self.subfolder:
|
||||||
|
pretrained_kwargs["subfolder"] = self.subfolder
|
||||||
|
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||||
|
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||||
|
|
||||||
|
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||||
|
for param_name, param_value in model_single_file.config.items():
|
||||||
|
if param_name in PARAMS_TO_IGNORE:
|
||||||
|
continue
|
||||||
|
assert model.config[param_name] == param_value, (
|
||||||
|
f"{param_name} differs between pretrained loading and single file loading: "
|
||||||
|
f"pretrained={model.config[param_name]}, single_file={param_value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_single_file_model_parameters(self):
|
||||||
|
pretrained_kwargs = {}
|
||||||
|
single_file_kwargs = {}
|
||||||
|
|
||||||
|
pretrained_kwargs["device"] = torch_device
|
||||||
|
single_file_kwargs["device"] = torch_device
|
||||||
|
|
||||||
|
if hasattr(self, "subfolder") and self.subfolder:
|
||||||
|
pretrained_kwargs["subfolder"] = self.subfolder
|
||||||
|
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||||
|
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||||
|
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
state_dict_single_file = model_single_file.state_dict()
|
||||||
|
|
||||||
|
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||||
|
"Model parameters keys differ between pretrained and single file loading. "
|
||||||
|
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
|
||||||
|
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for key in state_dict.keys():
|
||||||
|
param = state_dict[key]
|
||||||
|
param_single_file = state_dict_single_file[key]
|
||||||
|
|
||||||
|
assert param.shape == param_single_file.shape, (
|
||||||
|
f"Parameter shape mismatch for {key}: "
|
||||||
|
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_tensors_close(
|
||||||
|
param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_single_file_loading_local_files_only(self, tmp_path):
|
||||||
|
single_file_kwargs = {}
|
||||||
|
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||||
|
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||||
|
|
||||||
|
model_single_file = self.model_class.from_single_file(
|
||||||
|
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||||
|
|
||||||
|
def test_single_file_loading_with_diffusers_config(self):
|
||||||
|
single_file_kwargs = {}
|
||||||
|
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
# Load with config parameter
|
||||||
|
model_single_file = self.model_class.from_single_file(
|
||||||
|
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load pretrained for comparison
|
||||||
|
pretrained_kwargs = {}
|
||||||
|
if hasattr(self, "subfolder") and self.subfolder:
|
||||||
|
pretrained_kwargs["subfolder"] = self.subfolder
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||||
|
|
||||||
|
# Compare configs
|
||||||
|
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||||
|
for param_name, param_value in model_single_file.config.items():
|
||||||
|
if param_name in PARAMS_TO_IGNORE:
|
||||||
|
continue
|
||||||
|
assert model.config[param_name] == param_value, (
|
||||||
|
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
|
||||||
|
single_file_kwargs = {}
|
||||||
|
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||||
|
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||||
|
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
|
||||||
|
|
||||||
|
model_single_file = self.model_class.from_single_file(
|
||||||
|
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||||
|
|
||||||
|
def test_single_file_loading_dtype(self):
|
||||||
|
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
|
||||||
|
|
||||||
|
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
del model_single_file
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
def test_checkpoint_variant_loading(self):
|
||||||
|
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
|
||||||
|
return
|
||||||
|
|
||||||
|
for ckpt_path in self.alternate_ckpt_paths:
|
||||||
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
|
single_file_kwargs = {}
|
||||||
|
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||||
|
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||||
|
|
||||||
|
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||||
|
|
||||||
|
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
|
||||||
|
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
backend_empty_cache(torch_device)
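Illustrative sketch (not part of this diff): the attribute contract a model test supplies to SingleFileTesterMixin. The values mirror the Flux single-file tests further down in this change; the relative import assumes placement alongside the other model tests.

import torch

from diffusers import FluxTransformer2DModel

from ..testing_utils import SingleFileTesterMixin


class TestFluxTransformerSingleFileExample(SingleFileTesterMixin):
    model_class = FluxTransformer2DModel
    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
    subfolder = "transformer"
    torch_dtype = torch.bfloat16
    # Optional: extra checkpoints exercised by test_checkpoint_variant_loading.
    alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]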
|
||||||
207
tests/models/testing_utils/training.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from diffusers.training_utils import EMAModel
|
||||||
|
|
||||||
|
from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
|
||||||
|
|
||||||
|
|
||||||
|
@is_training
|
||||||
|
@require_torch_accelerator_with_training
|
||||||
|
class TrainingTesterMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for testing training functionality on models.
|
||||||
|
|
||||||
|
Expected class attributes to be set by subclasses:
|
||||||
|
- model_class: The model class to test
|
||||||
|
|
||||||
|
Expected methods to be implemented by subclasses:
|
||||||
|
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||||
|
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||||
|
|
||||||
|
Expected properties to be implemented by subclasses:
|
||||||
|
- output_shape: Tuple defining the expected output shape
|
||||||
|
|
||||||
|
Pytest mark: training
|
||||||
|
Use `pytest -m "not training"` to skip these tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_training(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||||
|
loss = torch.nn.functional.mse_loss(output, noise)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def test_training_with_ema(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
ema_model = EMAModel(model.parameters())
|
||||||
|
|
||||||
|
output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||||
|
loss = torch.nn.functional.mse_loss(output, noise)
|
||||||
|
loss.backward()
|
||||||
|
ema_model.step(model.parameters())
|
||||||
|
|
||||||
|
def test_gradient_checkpointing(self):
|
||||||
|
if not self.model_class._supports_gradient_checkpointing:
|
||||||
|
pytest.skip("Gradient checkpointing is not supported.")
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
|
||||||
|
# at init model should have gradient checkpointing disabled
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
|
||||||
|
|
||||||
|
# check enable works
|
||||||
|
model.enable_gradient_checkpointing()
|
||||||
|
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
|
||||||
|
|
||||||
|
# check disable works
|
||||||
|
model.disable_gradient_checkpointing()
|
||||||
|
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_is_applied(self, expected_set=None):
|
||||||
|
if not self.model_class._supports_gradient_checkpointing:
|
||||||
|
pytest.skip("Gradient checkpointing is not supported.")
|
||||||
|
|
||||||
|
if expected_set is None:
|
||||||
|
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
|
||||||
|
model_class_copy = copy.copy(self.model_class)
|
||||||
|
model = model_class_copy(**init_dict)
|
||||||
|
model.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
modules_with_gc_enabled = {}
|
||||||
|
for submodule in model.modules():
|
||||||
|
if hasattr(submodule, "gradient_checkpointing"):
|
||||||
|
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
|
||||||
|
modules_with_gc_enabled[submodule.__class__.__name__] = True
|
||||||
|
|
||||||
|
assert set(modules_with_gc_enabled.keys()) == expected_set, (
|
||||||
|
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
|
||||||
|
)
|
||||||
|
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||||
|
if not self.model_class._supports_gradient_checkpointing:
|
||||||
|
pytest.skip("Gradient checkpointing is not supported.")
|
||||||
|
|
||||||
|
if skip is None:
|
||||||
|
skip = set()
|
||||||
|
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
assert not model.is_gradient_checkpointing and model.training
|
||||||
|
|
||||||
|
out = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
# run the backwards pass on the model
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
labels = torch.randn_like(out)
|
||||||
|
loss = (out - labels).mean()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# re-instantiate the model now enabling gradient checkpointing
|
||||||
|
torch.manual_seed(0)
|
||||||
|
model_2 = self.model_class(**init_dict)
|
||||||
|
# clone model
|
||||||
|
model_2.load_state_dict(model.state_dict())
|
||||||
|
model_2.to(torch_device)
|
||||||
|
model_2.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
assert model_2.is_gradient_checkpointing and model_2.training
|
||||||
|
|
||||||
|
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
|
||||||
|
|
||||||
|
# run the backwards pass on the model
|
||||||
|
model_2.zero_grad()
|
||||||
|
loss_2 = (out_2 - labels).mean()
|
||||||
|
loss_2.backward()
|
||||||
|
|
||||||
|
# compare the output and parameters gradients
|
||||||
|
assert (loss - loss_2).abs() < loss_tolerance, (
|
||||||
|
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
|
||||||
|
)
|
||||||
|
|
||||||
|
named_params = dict(model.named_parameters())
|
||||||
|
named_params_2 = dict(model_2.named_parameters())
|
||||||
|
|
||||||
|
for name, param in named_params.items():
|
||||||
|
if "post_quant_conv" in name:
|
||||||
|
continue
|
||||||
|
if name in skip:
|
||||||
|
continue
|
||||||
|
if param.grad is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
|
||||||
|
f"Gradient mismatch for {name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_mixed_precision_training(self):
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
|
model = self.model_class(**init_dict)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# Test with float16
|
||||||
|
if torch.device(torch_device).type != "cpu":
|
||||||
|
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
|
||||||
|
output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||||
|
loss = torch.nn.functional.mse_loss(output, noise)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# Test with bfloat16
|
||||||
|
if torch.device(torch_device).type != "cpu":
|
||||||
|
model.zero_grad()
|
||||||
|
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
|
||||||
|
output = model(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||||
|
loss = torch.nn.functional.mse_loss(output, noise)
|
||||||
|
|
||||||
|
loss.backward()
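Illustrative sketch (not part of this diff): the extra `output_shape` property TrainingTesterMixin requires on top of the usual config contract. The tiny UNet2DModel setup and shapes are illustrative only.

import torch

from diffusers import UNet2DModel

from ..testing_utils import TrainingTesterMixin


class TestTinyUNetTraining(TrainingTesterMixin):
    model_class = UNet2DModel

    @property
    def output_shape(self):
        # Shape of a single output sample; the mixin builds the regression target from it.
        return (3, 8, 8)

    def get_init_dict(self):
        return {
            "sample_size": 8,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 1,
            "block_out_channels": (8, 8),
            "norm_num_groups": 4,
            "down_block_types": ("DownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        return {"sample": torch.randn(1, 3, 8, 8), "timestep": torch.tensor([1])}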
|
||||||
@@ -13,23 +13,51 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import FluxTransformer2DModel
|
from diffusers import FluxTransformer2DModel
|
||||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
|
||||||
from diffusers.models.embeddings import ImageProjection
|
from diffusers.models.embeddings import ImageProjection
|
||||||
|
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
||||||
|
from diffusers.utils.torch_utils import randn_tensor
|
||||||
|
|
||||||
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
|
from ...testing_utils import enable_full_determinism, torch_device
|
||||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
from ..testing_utils import (
|
||||||
|
AttentionTesterMixin,
|
||||||
|
BaseModelTesterConfig,
|
||||||
|
BitsAndBytesCompileTesterMixin,
|
||||||
|
BitsAndBytesTesterMixin,
|
||||||
|
ContextParallelTesterMixin,
|
||||||
|
FasterCacheTesterMixin,
|
||||||
|
FirstBlockCacheTesterMixin,
|
||||||
|
GGUFCompileTesterMixin,
|
||||||
|
GGUFTesterMixin,
|
||||||
|
IPAdapterTesterMixin,
|
||||||
|
LoraHotSwappingForModelTesterMixin,
|
||||||
|
LoraTesterMixin,
|
||||||
|
MemoryTesterMixin,
|
||||||
|
ModelOptCompileTesterMixin,
|
||||||
|
ModelOptTesterMixin,
|
||||||
|
ModelTesterMixin,
|
||||||
|
PyramidAttentionBroadcastTesterMixin,
|
||||||
|
QuantoCompileTesterMixin,
|
||||||
|
QuantoTesterMixin,
|
||||||
|
SingleFileTesterMixin,
|
||||||
|
TorchAoCompileTesterMixin,
|
||||||
|
TorchAoTesterMixin,
|
||||||
|
TorchCompileTesterMixin,
|
||||||
|
TrainingTesterMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
enable_full_determinism()
|
enable_full_determinism()
|
||||||
|
|
||||||
|
|
||||||
def create_flux_ip_adapter_state_dict(model):
|
# TODO: This standalone function maintains backward compatibility with pipeline tests
|
||||||
# "ip_adapter" (cross-attention weights)
|
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
|
||||||
|
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
|
||||||
ip_cross_attn_state_dict = {}
|
ip_cross_attn_state_dict = {}
|
||||||
key_id = 0
|
key_id = 0
|
||||||
|
|
||||||
@@ -39,7 +67,7 @@ def create_flux_ip_adapter_state_dict(model):
|
|||||||
|
|
||||||
joint_attention_dim = model.config["joint_attention_dim"]
|
joint_attention_dim = model.config["joint_attention_dim"]
|
||||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||||
sd = FluxIPAdapterJointAttnProcessor2_0(
|
sd = FluxIPAdapterAttnProcessor(
|
||||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||||
).state_dict()
|
).state_dict()
|
||||||
ip_cross_attn_state_dict.update(
|
ip_cross_attn_state_dict.update(
|
||||||
@@ -50,11 +78,8 @@ def create_flux_ip_adapter_state_dict(model):
|
|||||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
key_id += 1
|
key_id += 1
|
||||||
|
|
||||||
# "image_proj" (ImageProjection layer weights)
|
|
||||||
|
|
||||||
image_projection = ImageProjection(
|
image_projection = ImageProjection(
|
||||||
cross_attention_dim=model.config["joint_attention_dim"],
|
cross_attention_dim=model.config["joint_attention_dim"],
|
||||||
image_embed_dim=(
|
image_embed_dim=(
|
||||||
@@ -75,57 +100,37 @@ def create_flux_ip_adapter_state_dict(model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
del sd
|
del sd
|
||||||
ip_state_dict = {}
|
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
|
||||||
return ip_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||||
model_class = FluxTransformer2DModel
|
@property
|
||||||
main_input_name = "hidden_states"
|
def model_class(self):
|
||||||
# We override the items here because the transformer under consideration is small.
|
return FluxTransformer2DModel
|
||||||
model_split_percents = [0.7, 0.6, 0.6]
|
|
||||||
|
|
||||||
# Skip setting testing with default: AttnProcessor
|
|
||||||
uses_custom_attn_processor = True
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_input(self):
|
def pretrained_model_name_or_path(self):
|
||||||
return self.prepare_dummy_input()
|
return "hf-internal-testing/tiny-flux-pipe"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_shape(self):
|
def pretrained_model_kwargs(self):
|
||||||
|
return {"subfolder": "transformer"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self) -> tuple[int, int]:
|
||||||
return (16, 4)
|
return (16, 4)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_shape(self):
|
def input_shape(self) -> tuple[int, int]:
|
||||||
return (16, 4)
|
return (16, 4)
|
||||||
|
|
||||||
def prepare_dummy_input(self, height=4, width=4):
|
@property
|
||||||
batch_size = 1
|
def generator(self):
|
||||||
num_latent_channels = 4
|
return torch.Generator("cpu").manual_seed(0)
|
||||||
num_image_channels = 3
|
|
||||||
sequence_length = 48
|
|
||||||
embedding_dim = 32
|
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
|
||||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
|
||||||
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
|
|
||||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
|
||||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
|
||||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
|
||||||
|
|
||||||
|
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||||
|
"""Return Flux model initialization arguments."""
|
||||||
return {
|
return {
|
||||||
"hidden_states": hidden_states,
|
|
||||||
"encoder_hidden_states": encoder_hidden_states,
|
|
||||||
"img_ids": image_ids,
|
|
||||||
"txt_ids": text_ids,
|
|
||||||
"pooled_projections": pooled_prompt_embeds,
|
|
||||||
"timestep": timestep,
|
|
||||||
}
|
|
||||||
|
|
||||||
def prepare_init_args_and_inputs_for_common(self):
|
|
||||||
init_dict = {
|
|
||||||
"patch_size": 1,
|
"patch_size": 1,
|
||||||
"in_channels": 4,
|
"in_channels": 4,
|
||||||
"num_layers": 1,
|
"num_layers": 1,
|
||||||
@@ -137,11 +142,32 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
|||||||
"axes_dims_rope": [4, 4, 8],
|
"axes_dims_rope": [4, 4, 8],
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs_dict = self.dummy_input
|
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||||
return init_dict, inputs_dict
|
batch_size = 1
|
||||||
|
height = width = 4
|
||||||
|
num_latent_channels = 4
|
||||||
|
num_image_channels = 3
|
||||||
|
sequence_length = 48
|
||||||
|
embedding_dim = 32
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
|
||||||
|
"encoder_hidden_states": randn_tensor(
|
||||||
|
(batch_size, sequence_length, embedding_dim), generator=self.generator
|
||||||
|
),
|
||||||
|
"pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
|
||||||
|
"img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
|
||||||
|
"txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
|
||||||
|
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
"""Test that deprecated 3D img_ids and txt_ids still work."""
|
||||||
|
init_dict = self.get_init_dict()
|
||||||
|
inputs_dict = self.get_dummy_inputs()
|
||||||
|
|
||||||
model = self.model_class(**init_dict)
|
model = self.model_class(**init_dict)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -162,63 +188,267 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||||
|
|
||||||
self.assertEqual(output_1.shape, output_2.shape)
|
assert output_1.shape == output_2.shape
|
||||||
self.assertTrue(
|
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||||
torch.allclose(output_1, output_2, atol=1e-5),
|
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
"are not equal as them as 2d inputs"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_gradient_checkpointing_is_applied(self):
|
|
||||||
expected_set = {"FluxTransformer2DModel"}
|
|
||||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
|
||||||
|
|
||||||
# The test exists for cases like
|
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||||
# https://github.com/huggingface/diffusers/issues/11874
|
"""Memory optimization tests for Flux Transformer."""
|
||||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
|
||||||
def test_lora_exclude_modules(self):
|
|
||||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
|
||||||
|
|
||||||
lora_rank = 4
|
pass
|
||||||
target_module = "single_transformer_blocks.0.proj_out"
|
|
||||||
adapter_name = "foo"
|
|
||||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
|
||||||
model = self.model_class(**init_dict).to(torch_device)
|
|
||||||
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||||
lora_state_dict = {
|
"""Training tests for Flux Transformer."""
|
||||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
|
||||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||||
|
"""Attention processor tests for Flux Transformer."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||||
|
"""Context Parallel inference tests for Flux Transformer"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||||
|
"""IP Adapter tests for Flux Transformer."""
|
||||||
|
|
||||||
|
ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
|
||||||
|
|
||||||
|
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
# Create dummy image embeds for IP adapter
|
||||||
|
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||||
|
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||||
|
|
||||||
|
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||||
|
|
||||||
|
return inputs_dict
|
||||||
|
|
||||||
|
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||||
|
return create_flux_ip_adapter_state_dict(model)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
|
||||||
|
"""LoRA adapter tests for Flux Transformer."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||||
|
"""LoRA hot-swapping tests for Flux Transformer."""
|
||||||
|
|
||||||
|
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||||
|
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||||
|
batch_size = 1
|
||||||
|
num_latent_channels = 4
|
||||||
|
num_image_channels = 3
|
||||||
|
sequence_length = 24
|
||||||
|
embedding_dim = 8
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||||
|
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||||
|
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||||
|
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||||
|
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||||
|
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||||
}
|
}
|
||||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
|
||||||
config = LoraConfig(
|
|
||||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
|
||||||
)
|
|
||||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
|
||||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
|
||||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
|
||||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
|
||||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
|
||||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
|
||||||
|
|
||||||
|
|
||||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
|
||||||
model_class = FluxTransformer2DModel
|
|
||||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||||
|
|
||||||
def prepare_init_args_and_inputs_for_common(self):
|
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
"""Override to support dynamic height/width for compilation tests."""
|
||||||
|
batch_size = 1
|
||||||
|
num_latent_channels = 4
|
||||||
|
num_image_channels = 3
|
||||||
|
sequence_length = 24
|
||||||
|
embedding_dim = 8
|
||||||
|
|
||||||
def prepare_dummy_input(self, height, width):
|
return {
|
||||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
|
||||||
|
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
|
||||||
|
"pooled_projections": randn_tensor((batch_size, embedding_dim)),
|
||||||
|
"img_ids": randn_tensor((height * width, num_image_channels)),
|
||||||
|
"txt_ids": randn_tensor((sequence_length, num_image_channels)),
|
||||||
|
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||||
model_class = FluxTransformer2DModel
|
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
alternate_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||||
|
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
|
||||||
|
subfolder = "transformer"
|
||||||
|
pass
|
||||||
|
|
||||||
def prepare_init_args_and_inputs_for_common(self):
|
|
||||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
|
||||||
|
|
||||||
def prepare_dummy_input(self, height, width):
|
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||||
|
return {
|
||||||
|
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||||
|
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||||
|
"pooled_projections": randn_tensor((1, 768)),
|
||||||
|
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||||
|
"img_ids": randn_tensor((4096, 3)),
|
||||||
|
"txt_ids": randn_tensor((512, 3)),
|
||||||
|
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||||
|
}


class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
    gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }
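For reference, the GGUF checkpoint named above is normally consumed through `from_single_file` with a GGUF quantization config. A hedged sketch of that standard diffusers usage follows; it is illustrative and not part of the diff.

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)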


class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
    gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        return {
            "hidden_states": randn_tensor((1, 4096, 64)),
            "encoder_hidden_states": randn_tensor((1, 512, 4096)),
            "pooled_projections": randn_tensor((1, 768)),
            "timestep": torch.tensor([1.0]).to(torch_device),
            "img_ids": randn_tensor((4096, 3)),
            "txt_ids": randn_tensor((512, 3)),
            "guidance": torch.tensor([3.5]).to(torch_device),
        }


class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
    """PyramidAttentionBroadcast cache tests for Flux Transformer."""

    pass


class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
    """FirstBlockCache tests for Flux Transformer."""

    pass


class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
    """FasterCache tests for Flux Transformer."""

    # Flux is guidance distilled, so we can test at model level without CFG batch handling
    FASTER_CACHE_CONFIG = {
        "spatial_attention_block_skip_range": 2,
        "spatial_attention_timestep_skip_range": (-1, 901),
        "tensor_format": "BCHW",
        "is_guidance_distilled": True,
    }
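The FASTER_CACHE_CONFIG above mirrors the fields of diffusers' FasterCacheConfig. Below is a rough sketch of applying the same settings at inference time, assuming the cache hooks exposed by recent diffusers releases; the pipeline setup and callback wiring are illustrative only and not part of the diff.

import torch

from diffusers import FasterCacheConfig, FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),
    tensor_format="BCHW",
    is_guidance_distilled=True,  # Flux is guidance distilled, as noted in the test class above
    current_timestep_callback=lambda: pipe.current_timestep,  # callback wiring assumed; see diffusers cache docs
)
pipe.transformer.enable_cache(config)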

@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import (
    is_gguf_available,
    is_kernels_available,
    is_note_seq_available,
    is_nvidia_modelopt_available,
    is_onnx_available,
    is_opencv_available,
    is_optimum_quanto_available,

@@ -130,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
    return True


def assert_tensors_close(
    actual: "torch.Tensor",
    expected: "torch.Tensor",
    atol: float = 1e-5,
    rtol: float = 1e-5,
    msg: str = "",
) -> None:
    """
    Assert that two tensors are close within tolerance.

    Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
    Provides concise, actionable error messages without dumping full tensors.

    Args:
        actual: The actual tensor from the computation.
        expected: The expected tensor to compare against.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        msg: Optional message prefix for the assertion error.

    Raises:
        AssertionError: If tensors have different shapes or values exceed tolerance.

    Example:
        >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
    """
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")

    if actual.shape != expected.shape:
        raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")

    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        abs_diff = (actual - expected).abs()
        max_diff = abs_diff.max().item()

        flat_idx = abs_diff.argmax().item()
        max_idx = tuple(int(i) for i in torch.unravel_index(torch.tensor(flat_idx), actual.shape))

        threshold = atol + rtol * expected.abs()
        mismatched = (abs_diff > threshold).sum().item()
        total = actual.numel()

        raise AssertionError(
            f"{msg}\n"
            f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
            f"  Max diff: {max_diff:.6e} at index {max_idx}\n"
            f"  Actual: {actual.flatten()[flat_idx].item():.6e}\n"
            f"  Expected: {expected.flatten()[flat_idx].item():.6e}\n"
            f"  atol: {atol:.6e}, rtol: {rtol:.6e}"
        )


def numpy_cosine_similarity_distance(a, b):
    similarity = np.dot(a, b) / (norm(a) * norm(b))
    distance = 1.0 - similarity.mean()

@@ -241,7 +295,6 @@ def parse_flag_from_env(key, default=False):


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
-_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)


def floats_tensor(shape, scale=1.0, rng=None, name=None):

@@ -282,12 +335,155 @@ def nightly(test_case):


def is_torch_compile(test_case):
    """
-    Decorator marking a test that runs compile tests in the diffusers CI.
-
-    Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
+    Decorator marking a test as a torch.compile test. These tests can be filtered using:
+        pytest -m "not compile" to skip
+        pytest -m compile to run only these tests
    """
-    return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
+    return pytest.mark.compile(test_case)


def is_single_file(test_case):
    """
    Decorator marking a test as a single file loading test. These tests can be filtered using:
        pytest -m "not single_file" to skip
        pytest -m single_file to run only these tests
    """
    return pytest.mark.single_file(test_case)


def is_lora(test_case):
    """
    Decorator marking a test as a LoRA test. These tests can be filtered using:
        pytest -m "not lora" to skip
        pytest -m lora to run only these tests
    """
    return pytest.mark.lora(test_case)


def is_ip_adapter(test_case):
    """
    Decorator marking a test as an IP Adapter test. These tests can be filtered using:
        pytest -m "not ip_adapter" to skip
        pytest -m ip_adapter to run only these tests
    """
    return pytest.mark.ip_adapter(test_case)


def is_training(test_case):
    """
    Decorator marking a test as a training test. These tests can be filtered using:
        pytest -m "not training" to skip
        pytest -m training to run only these tests
    """
    return pytest.mark.training(test_case)


def is_attention(test_case):
    """
    Decorator marking a test as an attention test. These tests can be filtered using:
        pytest -m "not attention" to skip
        pytest -m attention to run only these tests
    """
    return pytest.mark.attention(test_case)


def is_memory(test_case):
    """
    Decorator marking a test as a memory optimization test. These tests can be filtered using:
        pytest -m "not memory" to skip
        pytest -m memory to run only these tests
    """
    return pytest.mark.memory(test_case)


def is_cpu_offload(test_case):
    """
    Decorator marking a test as a CPU offload test. These tests can be filtered using:
        pytest -m "not cpu_offload" to skip
        pytest -m cpu_offload to run only these tests
    """
    return pytest.mark.cpu_offload(test_case)


def is_group_offload(test_case):
    """
    Decorator marking a test as a group offload test. These tests can be filtered using:
        pytest -m "not group_offload" to skip
        pytest -m group_offload to run only these tests
    """
    return pytest.mark.group_offload(test_case)


def is_quantization(test_case):
    """
    Decorator marking a test as a quantization test. These tests can be filtered using:
        pytest -m "not quantization" to skip
        pytest -m quantization to run only these tests
    """
    return pytest.mark.quantization(test_case)


def is_bitsandbytes(test_case):
    """
    Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
        pytest -m "not bitsandbytes" to skip
        pytest -m bitsandbytes to run only these tests
    """
    return pytest.mark.bitsandbytes(test_case)


def is_quanto(test_case):
    """
    Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
        pytest -m "not quanto" to skip
        pytest -m quanto to run only these tests
    """
    return pytest.mark.quanto(test_case)


def is_torchao(test_case):
    """
    Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
        pytest -m "not torchao" to skip
        pytest -m torchao to run only these tests
    """
    return pytest.mark.torchao(test_case)


def is_gguf(test_case):
    """
    Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
        pytest -m "not gguf" to skip
        pytest -m gguf to run only these tests
    """
    return pytest.mark.gguf(test_case)


def is_modelopt(test_case):
    """
    Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
        pytest -m "not modelopt" to skip
        pytest -m modelopt to run only these tests
    """
    return pytest.mark.modelopt(test_case)


def is_context_parallel(test_case):
    """
    Decorator marking a test as a context parallel inference test. These tests can be filtered using:
        pytest -m "not context_parallel" to skip
        pytest -m context_parallel to run only these tests
    """
    return pytest.mark.context_parallel(test_case)


def is_cache(test_case):
    """
    Decorator marking a test as a cache test. These tests can be filtered using:
        pytest -m "not cache" to skip
        pytest -m cache to run only these tests
    """
    return pytest.mark.cache(test_case)
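The decorators above are thin wrappers that attach the pytest markers registered in conftest.py, so marked tests can be selected or deselected from the command line. A hedged usage sketch follows; the import path, class name, and test names are made up for illustration and are not part of the diff.

from tests.testing_utils import is_lora, is_training  # import path assumed for illustration


class TestMyModelLoRA:  # hypothetical test class
    @is_lora
    def test_lora_save_load(self):
        ...

    @is_lora
    @is_training
    def test_lora_training_step(self):
        ...


# Typical selection from the command line:
#   pytest tests/models -m lora
#   pytest tests/models -m "lora and not training"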


def require_torch(test_case):

@@ -650,6 +846,19 @@ def require_kernels_version_greater_or_equal(kernels_version):
    return decorator


def require_modelopt_version_greater_or_equal(modelopt_version):
    def decorator(test_case):
        correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
            version.parse(importlib.metadata.version("modelopt")).base_version
        ) >= version.parse(modelopt_version)
        return pytest.mark.skipif(
            not correct_nvidia_modelopt_version,
            reason=f"Test requires modelopt with version greater than {modelopt_version}.",
        )(test_case)

    return decorator
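Like the other require_* helpers, the new decorator gates a test on an installed package version. A brief usage sketch, where the version string and test name are made up for illustration:

@require_modelopt_version_greater_or_equal("0.19.0")  # illustrative version threshold
def test_modelopt_quantized_forward():
    ...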


def deprecate_after_peft_backend(test_case):
    """
    Decorator marking a test that will be skipped after PEFT backend
|
|||||||
509
utils/generate_model_tests.py
Normal file
509
utils/generate_model_tests.py
Normal file
@@ -0,0 +1,509 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility script to generate test suites for diffusers model classes.

Usage:
    python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py

This will analyze the model file and generate a test file with appropriate
test classes based on the model's mixins and attributes.
"""

import argparse
import ast
import sys
from pathlib import Path


MIXIN_TO_TESTER = {
    "ModelMixin": "ModelTesterMixin",
    "PeftAdapterMixin": "LoraTesterMixin",
}

ATTRIBUTE_TO_TESTER = {
    "_cp_plan": "ContextParallelTesterMixin",
    "_supports_gradient_checkpointing": "TrainingTesterMixin",
}

ALWAYS_INCLUDE_TESTERS = [
    "ModelTesterMixin",
    "MemoryTesterMixin",
    "TorchCompileTesterMixin",
]

# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
    "AttentionMixin",
    "AttentionModuleMixin",
}

OPTIONAL_TESTERS = [
    ("BitsAndBytesTesterMixin", "bnb"),
    ("QuantoTesterMixin", "quanto"),
    ("TorchAoTesterMixin", "torchao"),
    ("GGUFTesterMixin", "gguf"),
    ("ModelOptTesterMixin", "modelopt"),
    ("SingleFileTesterMixin", "single_file"),
    ("IPAdapterTesterMixin", "ip_adapter"),
]


class ModelAnalyzer(ast.NodeVisitor):
    def __init__(self):
        self.model_classes = []
        self.current_class = None
        self.imports = set()

    def visit_Import(self, node: ast.Import):
        for alias in node.names:
            self.imports.add(alias.name.split(".")[-1])
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        for alias in node.names:
            self.imports.add(alias.name)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        base_names = []
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_names.append(base.id)
            elif isinstance(base, ast.Attribute):
                base_names.append(base.attr)

        if "ModelMixin" in base_names:
            class_info = {
                "name": node.name,
                "bases": base_names,
                "attributes": {},
                "has_forward": False,
                "init_params": [],
            }

            for item in node.body:
                if isinstance(item, ast.Assign):
                    for target in item.targets:
                        if isinstance(target, ast.Name):
                            attr_name = target.id
                            if attr_name.startswith("_"):
                                class_info["attributes"][attr_name] = self._get_value(item.value)

                elif isinstance(item, ast.FunctionDef):
                    if item.name == "forward":
                        class_info["has_forward"] = True
                        class_info["forward_params"] = self._extract_func_params(item)
                    elif item.name == "__init__":
                        class_info["init_params"] = self._extract_func_params(item)

            self.model_classes.append(class_info)

        self.generic_visit(node)

    def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
        params = []
        args = func_node.args

        num_defaults = len(args.defaults)
        num_args = len(args.args)
        first_default_idx = num_args - num_defaults

        for i, arg in enumerate(args.args):
            if arg.arg == "self":
                continue

            param_info = {"name": arg.arg, "type": None, "default": None}

            if arg.annotation:
                param_info["type"] = self._get_annotation_str(arg.annotation)

            default_idx = i - first_default_idx
            if default_idx >= 0 and default_idx < len(args.defaults):
                param_info["default"] = self._get_value(args.defaults[default_idx])

            params.append(param_info)

        return params

    def _get_annotation_str(self, node) -> str:
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return repr(node.value)
        elif isinstance(node, ast.Subscript):
            base = self._get_annotation_str(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
            else:
                args = self._get_annotation_str(node.slice)
            return f"{base}[{args}]"
        elif isinstance(node, ast.Attribute):
            return f"{self._get_annotation_str(node.value)}.{node.attr}"
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            left = self._get_annotation_str(node.left)
            right = self._get_annotation_str(node.right)
            return f"{left} | {right}"
        elif isinstance(node, ast.Tuple):
            return ", ".join(self._get_annotation_str(el) for el in node.elts)
        return "Any"

    def _get_value(self, node):
        if isinstance(node, ast.Constant):
            return node.value
        elif isinstance(node, ast.Name):
            if node.id == "None":
                return None
            elif node.id == "True":
                return True
            elif node.id == "False":
                return False
            return node.id
        elif isinstance(node, ast.List):
            return [self._get_value(el) for el in node.elts]
        elif isinstance(node, ast.Dict):
            return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
        return "<complex>"


def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
    with open(filepath) as f:
        source = f.read()

    tree = ast.parse(source)
    analyzer = ModelAnalyzer()
    analyzer.visit(tree)

    return analyzer.model_classes, analyzer.imports


def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
    testers = list(ALWAYS_INCLUDE_TESTERS)

    for base in model_info["bases"]:
        if base in MIXIN_TO_TESTER:
            tester = MIXIN_TO_TESTER[base]
            if tester not in testers:
                testers.append(tester)

    for attr, tester in ATTRIBUTE_TO_TESTER.items():
        if attr in model_info["attributes"]:
            value = model_info["attributes"][attr]
            if value is not None and value is not False:
                if tester not in testers:
                    testers.append(tester)

    if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
        if "ContextParallelTesterMixin" not in testers:
            testers.append("ContextParallelTesterMixin")

    # Include AttentionTesterMixin if the model imports attention-related classes
    if imports & ATTENTION_INDICATORS:
        testers.append("AttentionTesterMixin")

    for tester, flag in OPTIONAL_TESTERS:
        if flag in include_optional:
            if tester not in testers:
                testers.append(tester)

    return testers


def generate_config_class(model_info: dict, model_name: str) -> str:
    class_name = f"{model_name}TesterConfig"
    model_class = model_info["name"]
    forward_params = model_info.get("forward_params", [])
    init_params = model_info.get("init_params", [])

    lines = [
        f"class {class_name}:",
        f"    model_class = {model_class}",
        '    pretrained_model_name_or_path = ""',
        '    pretrained_model_kwargs = {"subfolder": "transformer"}',
        "",
        "    @property",
        "    def generator(self):",
        '        return torch.Generator("cpu").manual_seed(0)',
        "",
        "    def get_init_dict(self) -> dict[str, int | list[int]]:",
    ]

    if init_params:
        lines.append("        # __init__ parameters:")
        for param in init_params:
            type_str = f": {param['type']}" if param["type"] else ""
            default_str = f" = {param['default']}" if param["default"] is not None else ""
            lines.append(f"        # {param['name']}{type_str}{default_str}")

    lines.extend(
        [
            "        return {}",
            "",
            "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
        ]
    )

    if forward_params:
        lines.append("        # forward() parameters:")
        for param in forward_params:
            type_str = f": {param['type']}" if param["type"] else ""
            default_str = f" = {param['default']}" if param["default"] is not None else ""
            lines.append(f"        # {param['name']}{type_str}{default_str}")

    lines.extend(
        [
            "        # TODO: Fill in dummy inputs",
            "        return {}",
            "",
            "    @property",
            "    def input_shape(self) -> tuple[int, ...]:",
            "        return (1, 1)",
            "",
            "    @property",
            "    def output_shape(self) -> tuple[int, ...]:",
            "        return (1, 1)",
        ]
    )

    return "\n".join(lines)


def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
    tester_short = tester.replace("TesterMixin", "")
    class_name = f"Test{model_name}{tester_short}"

    lines = [f"class {class_name}({config_class}, {tester}):"]

    if tester == "TorchCompileTesterMixin":
        lines.extend(
            [
                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
                "",
                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
                "        # TODO: Implement dynamic input generation",
                "        return {}",
            ]
        )
    elif tester == "IPAdapterTesterMixin":
        lines.extend(
            [
                "    ip_adapter_processor_cls = None  # TODO: Set processor class",
                "",
                "    def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
                "        # TODO: Add IP adapter image embeds to inputs",
                "        return inputs_dict",
                "",
                "    def create_ip_adapter_state_dict(self, model):",
                "        # TODO: Create IP adapter state dict",
                "        return {}",
            ]
        )
    elif tester == "SingleFileTesterMixin":
        lines.extend(
            [
                '    ckpt_path = ""  # TODO: Set checkpoint path',
                "    alternate_keys_ckpt_paths = []",
                '    pretrained_model_name_or_path = ""',
                '    subfolder = "transformer"',
            ]
        )
    elif tester == "GGUFTesterMixin":
        lines.extend(
            [
                '    gguf_filename = ""  # TODO: Set GGUF filename',
                "",
                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
                "        # TODO: Override with larger inputs for quantization tests",
                "        return {}",
            ]
        )
    elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
        lines.extend(
            [
                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
                "        # TODO: Override with larger inputs for quantization tests",
                "        return {}",
            ]
        )
    elif tester == "LoraHotSwappingForModelTesterMixin":
        lines.extend(
            [
                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
                "",
                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
                "        # TODO: Implement dynamic input generation",
                "        return {}",
            ]
        )
    else:
        lines.append("    pass")

    return "\n".join(lines)


def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
    model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
    testers = determine_testers(model_info, include_optional, imports)
    tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})

    lines = [
        "# coding=utf-8",
        "# Copyright 2025 HuggingFace Inc.",
        "#",
        '# Licensed under the Apache License, Version 2.0 (the "License");',
        "# you may not use this file except in compliance with the License.",
        "# You may obtain a copy of the License at",
        "#",
        "#     http://www.apache.org/licenses/LICENSE-2.0",
        "#",
        "# Unless required by applicable law or agreed to in writing, software",
        '# distributed under the License is distributed on an "AS IS" BASIS,',
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
        "# See the License for the specific language governing permissions and",
        "# limitations under the License.",
        "",
        "import torch",
        "",
        f"from diffusers import {model_info['name']}",
        "from diffusers.utils.torch_utils import randn_tensor",
        "",
        "from ...testing_utils import enable_full_determinism, torch_device",
    ]

    if "LoraTesterMixin" in testers:
        lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")

    lines.extend(
        [
            "from ..testing_utils import (",
            *[f"    {tester}," for tester in sorted(tester_imports)],
            ")",
            "",
            "",
            "enable_full_determinism()",
            "",
            "",
        ]
    )

    config_class = f"{model_name}TesterConfig"
    lines.append(generate_config_class(model_info, model_name))
    lines.append("")
    lines.append("")

    for tester in testers:
        lines.append(generate_test_class(model_name, config_class, tester))
        lines.append("")
        lines.append("")

    if "LoraTesterMixin" in testers:
        lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
        lines.append("")
        lines.append("")

    return "\n".join(lines).rstrip() + "\n"


def get_test_output_path(model_filepath: str) -> str:
    path = Path(model_filepath)
    model_filename = path.stem

    if "transformers" in path.parts:
        return f"tests/models/transformers/test_models_{model_filename}.py"
    elif "unets" in path.parts:
        return f"tests/models/unets/test_models_{model_filename}.py"
    elif "autoencoders" in path.parts:
        return f"tests/models/autoencoders/test_models_{model_filename}.py"
    else:
        return f"tests/models/test_models_{model_filename}.py"


def main():
    parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
    parser.add_argument(
        "model_filepath",
        type=str,
        help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
    )
    parser.add_argument(
        "--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
    )
    parser.add_argument(
        "--include",
        "-i",
        type=str,
        nargs="*",
        default=[],
        choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"],
        help="Optional testers to include",
    )
    parser.add_argument(
        "--class-name",
        "-c",
        type=str,
        default=None,
        help="Specific model class to generate tests for (default: first model class found)",
    )
    parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")

    args = parser.parse_args()

    if not Path(args.model_filepath).exists():
        print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
        sys.exit(1)

    model_classes, imports = analyze_model_file(args.model_filepath)

    if not model_classes:
        print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
        sys.exit(1)

    if args.class_name:
        model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
        if not model_info:
            available = [m["name"] for m in model_classes]
            print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
            sys.exit(1)
    else:
        model_info = model_classes[0]
        if len(model_classes) > 1:
            print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
            print("Use --class-name to specify a different class", file=sys.stderr)

    include_optional = args.include
    if "all" in include_optional:
        include_optional = [flag for _, flag in OPTIONAL_TESTERS]

    generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)

    if args.dry_run:
        print(generated_code)
    else:
        output_path = args.output or get_test_output_path(args.model_filepath)
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w") as f:
            f.write(generated_code)

        print(f"Generated test file: {output_path}")
        print(f"Model class: {model_info['name']}")
        print(f"Detected attributes: {list(model_info['attributes'].keys())}")


if __name__ == "__main__":
    main()
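Typical invocations of the generator, based on the argparse options defined above; the flag combinations below are illustrative only. The analyzer can also be driven programmatically, assuming the functions are imported from this script.

# Shell usage (illustrative):
#   python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py --dry-run
#   python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py \
#       --include bnb gguf single_file --class-name FluxTransformer2DModel

# Programmatic usage (illustrative):
model_classes, imports = analyze_model_file("src/diffusers/models/transformers/transformer_flux.py")
print([info["name"] for info in model_classes])
print(determine_testers(model_classes[0], include_optional=["gguf"], imports=imports))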