|
|
|
|
@@ -13,8 +13,8 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import gc
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from diffusers.hooks import HookRegistry, ModelHook
|
|
|
|
|
@@ -134,20 +134,18 @@ class SkipLayerHook(ModelHook):
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HookTests(unittest.TestCase):
|
|
|
|
|
class TestHooks:
|
|
|
|
|
in_features = 4
|
|
|
|
|
hidden_features = 8
|
|
|
|
|
out_features = 4
|
|
|
|
|
num_layers = 2
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
def setup_method(self):
|
|
|
|
|
params = self.get_module_parameters()
|
|
|
|
|
self.model = DummyModel(**params)
|
|
|
|
|
self.model.to(torch_device)
|
|
|
|
|
|
|
|
|
|
def tearDown(self):
|
|
|
|
|
super().tearDown()
|
|
|
|
|
|
|
|
|
|
def teardown_method(self):
|
|
|
|
|
del self.model
|
|
|
|
|
gc.collect()
|
|
|
|
|
free_memory()
|
|
|
|
|
@@ -171,20 +169,20 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
registry_repr = repr(registry)
|
|
|
|
|
expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(registry.hooks), 2)
|
|
|
|
|
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
|
|
|
|
|
self.assertEqual(registry_repr, expected_repr)
|
|
|
|
|
assert len(registry.hooks) == 2
|
|
|
|
|
assert registry._hook_order == ["add_hook", "multiply_hook"]
|
|
|
|
|
assert registry_repr == expected_repr
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("add_hook")
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(registry.hooks), 1)
|
|
|
|
|
self.assertEqual(registry._hook_order, ["multiply_hook"])
|
|
|
|
|
assert len(registry.hooks) == 1
|
|
|
|
|
assert registry._hook_order == ["multiply_hook"]
|
|
|
|
|
|
|
|
|
|
def test_stateful_hook(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
|
|
|
|
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
|
|
|
|
|
|
|
|
|
|
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
|
|
|
|
|
assert registry.hooks["stateful_add_hook"].increment == 0
|
|
|
|
|
|
|
|
|
|
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
|
|
|
|
num_repeats = 3
|
|
|
|
|
@@ -194,13 +192,13 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
if i == 0:
|
|
|
|
|
output1 = result
|
|
|
|
|
|
|
|
|
|
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
|
|
|
|
|
assert registry.get_hook("stateful_add_hook").increment == num_repeats
|
|
|
|
|
|
|
|
|
|
registry.reset_stateful_hooks()
|
|
|
|
|
output2 = self.model(input)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
|
|
|
|
|
self.assertTrue(torch.allclose(output1, output2))
|
|
|
|
|
assert registry.get_hook("stateful_add_hook").increment == 1
|
|
|
|
|
assert torch.allclose(output1, output2)
|
|
|
|
|
|
|
|
|
|
def test_inference(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
|
|
|
|
@@ -218,9 +216,9 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
new_input = input * 2 + 1
|
|
|
|
|
output3 = self.model(new_input).mean().detach().cpu().item()
|
|
|
|
|
|
|
|
|
|
self.assertAlmostEqual(output1, output2, places=5)
|
|
|
|
|
self.assertAlmostEqual(output1, output3, places=5)
|
|
|
|
|
self.assertAlmostEqual(output2, output3, places=5)
|
|
|
|
|
assert output1 == pytest.approx(output2, abs=5e-6)
|
|
|
|
|
assert output1 == pytest.approx(output3, abs=5e-6)
|
|
|
|
|
assert output2 == pytest.approx(output3, abs=5e-6)
|
|
|
|
|
|
|
|
|
|
def test_skip_layer_hook(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
|
|
|
|
@@ -228,30 +226,29 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
input = torch.zeros(1, 4, device=torch_device)
|
|
|
|
|
output = self.model(input).mean().detach().cpu().item()
|
|
|
|
|
self.assertEqual(output, 0.0)
|
|
|
|
|
assert output == 0.0
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("skip_layer_hook")
|
|
|
|
|
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
|
|
|
|
|
output = self.model(input).mean().detach().cpu().item()
|
|
|
|
|
self.assertNotEqual(output, 0.0)
|
|
|
|
|
assert output != 0.0
|
|
|
|
|
|
|
|
|
|
def test_skip_layer_internal_block(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
|
|
|
|
|
input = torch.zeros(1, 4, device=torch_device)
|
|
|
|
|
|
|
|
|
|
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
|
|
|
|
with self.assertRaises(RuntimeError) as cm:
|
|
|
|
|
with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
|
|
|
|
|
self.model(input).mean().detach().cpu().item()
|
|
|
|
|
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("skip_layer_hook")
|
|
|
|
|
output = self.model(input).mean().detach().cpu().item()
|
|
|
|
|
self.assertNotEqual(output, 0.0)
|
|
|
|
|
assert output != 0.0
|
|
|
|
|
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
|
|
|
|
|
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
|
|
|
|
output = self.model(input).mean().detach().cpu().item()
|
|
|
|
|
self.assertNotEqual(output, 0.0)
|
|
|
|
|
assert output != 0.0
|
|
|
|
|
|
|
|
|
|
def test_invocation_order_stateful_first(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
|
|
|
|
@@ -278,7 +275,7 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
.replace(" ", "")
|
|
|
|
|
.replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("add_hook")
|
|
|
|
|
with CaptureLogger(logger) as cap_logger:
|
|
|
|
|
@@ -289,7 +286,7 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
.replace(" ", "")
|
|
|
|
|
.replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|
|
|
|
|
def test_invocation_order_stateful_middle(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
|
|
|
|
@@ -316,7 +313,7 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
.replace(" ", "")
|
|
|
|
|
.replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("add_hook")
|
|
|
|
|
with CaptureLogger(logger) as cap_logger:
|
|
|
|
|
@@ -327,7 +324,7 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
.replace(" ", "")
|
|
|
|
|
.replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("add_hook_2")
|
|
|
|
|
with CaptureLogger(logger) as cap_logger:
|
|
|
|
|
@@ -336,7 +333,7 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
expected_invocation_order_log = (
|
|
|
|
|
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|
|
|
|
|
def test_invocation_order_stateful_last(self):
|
|
|
|
|
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
|
|
|
|
@@ -363,7 +360,7 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
.replace(" ", "")
|
|
|
|
|
.replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|
|
|
|
|
registry.remove_hook("add_hook")
|
|
|
|
|
with CaptureLogger(logger) as cap_logger:
|
|
|
|
|
@@ -374,4 +371,4 @@ class HookTests(unittest.TestCase):
|
|
|
|
|
.replace(" ", "")
|
|
|
|
|
.replace("\n", "")
|
|
|
|
|
)
|
|
|
|
|
self.assertEqual(output, expected_invocation_order_log)
|
|
|
|
|
assert output == expected_invocation_order_log
|
|
|
|
|
|