mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-11 11:12:04 +08:00
Compare commits
5 Commits
use-pytest
...
tests-load
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1f63a398c | ||
|
|
bf846f722c | ||
|
|
78a86e85cf | ||
|
|
7673ab1757 | ||
|
|
b7648557d4 |
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.hooks import HookRegistry, ModelHook
|
||||
@@ -134,18 +134,20 @@ class SkipLayerHook(ModelHook):
|
||||
return output
|
||||
|
||||
|
||||
class TestHooks:
|
||||
class HookTests(unittest.TestCase):
|
||||
in_features = 4
|
||||
hidden_features = 8
|
||||
out_features = 4
|
||||
num_layers = 2
|
||||
|
||||
def setup_method(self):
|
||||
def setUp(self):
|
||||
params = self.get_module_parameters()
|
||||
self.model = DummyModel(**params)
|
||||
self.model.to(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.model
|
||||
gc.collect()
|
||||
free_memory()
|
||||
@@ -169,20 +171,20 @@ class TestHooks:
|
||||
registry_repr = repr(registry)
|
||||
expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
|
||||
|
||||
assert len(registry.hooks) == 2
|
||||
assert registry._hook_order == ["add_hook", "multiply_hook"]
|
||||
assert registry_repr == expected_repr
|
||||
self.assertEqual(len(registry.hooks), 2)
|
||||
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
|
||||
self.assertEqual(registry_repr, expected_repr)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
|
||||
assert len(registry.hooks) == 1
|
||||
assert registry._hook_order == ["multiply_hook"]
|
||||
self.assertEqual(len(registry.hooks), 1)
|
||||
self.assertEqual(registry._hook_order, ["multiply_hook"])
|
||||
|
||||
def test_stateful_hook(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
|
||||
|
||||
assert registry.hooks["stateful_add_hook"].increment == 0
|
||||
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
num_repeats = 3
|
||||
@@ -192,13 +194,13 @@ class TestHooks:
|
||||
if i == 0:
|
||||
output1 = result
|
||||
|
||||
assert registry.get_hook("stateful_add_hook").increment == num_repeats
|
||||
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
|
||||
|
||||
registry.reset_stateful_hooks()
|
||||
output2 = self.model(input)
|
||||
|
||||
assert registry.get_hook("stateful_add_hook").increment == 1
|
||||
assert torch.allclose(output1, output2)
|
||||
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
|
||||
self.assertTrue(torch.allclose(output1, output2))
|
||||
|
||||
def test_inference(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
@@ -216,9 +218,9 @@ class TestHooks:
|
||||
new_input = input * 2 + 1
|
||||
output3 = self.model(new_input).mean().detach().cpu().item()
|
||||
|
||||
assert output1 == pytest.approx(output2, abs=5e-6)
|
||||
assert output1 == pytest.approx(output3, abs=5e-6)
|
||||
assert output2 == pytest.approx(output3, abs=5e-6)
|
||||
self.assertAlmostEqual(output1, output2, places=5)
|
||||
self.assertAlmostEqual(output1, output3, places=5)
|
||||
self.assertAlmostEqual(output2, output3, places=5)
|
||||
|
||||
def test_skip_layer_hook(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
@@ -226,29 +228,30 @@ class TestHooks:
|
||||
|
||||
input = torch.zeros(1, 4, device=torch_device)
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
assert output == 0.0
|
||||
self.assertEqual(output, 0.0)
|
||||
|
||||
registry.remove_hook("skip_layer_hook")
|
||||
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
assert output != 0.0
|
||||
self.assertNotEqual(output, 0.0)
|
||||
|
||||
def test_skip_layer_internal_block(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
|
||||
input = torch.zeros(1, 4, device=torch_device)
|
||||
|
||||
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
||||
with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
self.model(input).mean().detach().cpu().item()
|
||||
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
|
||||
|
||||
registry.remove_hook("skip_layer_hook")
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
assert output != 0.0
|
||||
self.assertNotEqual(output, 0.0)
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
|
||||
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
||||
output = self.model(input).mean().detach().cpu().item()
|
||||
assert output != 0.0
|
||||
self.assertNotEqual(output, 0.0)
|
||||
|
||||
def test_invocation_order_stateful_first(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
@@ -275,7 +278,7 @@ class TestHooks:
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
@@ -286,7 +289,7 @@ class TestHooks:
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
def test_invocation_order_stateful_middle(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
@@ -313,7 +316,7 @@ class TestHooks:
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
@@ -324,7 +327,7 @@ class TestHooks:
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook_2")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
@@ -333,7 +336,7 @@ class TestHooks:
|
||||
expected_invocation_order_log = (
|
||||
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
def test_invocation_order_stateful_last(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
@@ -360,7 +363,7 @@ class TestHooks:
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
registry.remove_hook("add_hook")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
@@ -371,4 +374,4 @@ class TestHooks:
|
||||
.replace(" ", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
assert output == expected_invocation_order_log
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -32,6 +33,33 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
Reference in New Issue
Block a user