Compare commits


1 Commit

Author      SHA1          Message                         Date
sayakpaul   5b0c7456f3    move test_hooks.py to pytest    2026-03-10 12:03:03 +05:30
4 changed files with 31 additions and 44 deletions

View File

@@ -2538,12 +2538,8 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     def get_alpha_scales(down_weight, alpha_key):
         rank = down_weight.shape[0]
-        alpha_tensor = state_dict.pop(alpha_key, None)
-        if alpha_tensor is None:
-            return 1.0, 1.0
-        scale = (
-            alpha_tensor.item() / rank
-        )  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
         scale_down = scale
         scale_up = 1.0
         while scale_down * 2 < scale_up:
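The incoming version leans on the identity stated in the comment: the LoRA forward pass multiplies the update by alpha / rank, so conversion divides that factor back out and then splits it between the down and up weights. Only the `while` header of that split survives in this hunk, so the following is a hedged, self-contained sketch, assuming the loop body shifts powers of two from `scale_up` to `scale_down` while preserving `scale_down * scale_up == scale`:

    # Hedged sketch of the alpha-rescaling split. Only the loop header appears
    # in the hunk above; the body here assumes it moves factors of two from
    # scale_up to scale_down while keeping scale_down * scale_up == scale.
    def get_alpha_scales(alpha: float, rank: int) -> tuple[float, float]:
        scale = alpha / rank  # undo the 'alpha / rank' applied in the forward pass
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    # e.g. alpha=1, rank=64 -> scale = 0.015625, split as (0.125, 0.125)
    print(get_alpha_scales(1.0, 64))

Splitting the factor this way keeps the two weight tensors at comparable magnitudes rather than loading the whole scale onto one side, which is presumably gentler on low-precision dtypes.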

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
 from packaging import version
 
-from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
+from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
 
 if is_torch_available():
@@ -844,8 +844,6 @@ class QuantoConfig(QuantizationConfigMixin):
         modules_to_not_convert: list[str] | None = None,
         **kwargs,
     ):
-        deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
-        deprecate("QuantoConfig", "1.0.0", deprecation_message)
         self.quant_method = QuantizationMethod.QUANTO
         self.weights_dtype = weights_dtype
         self.modules_to_not_convert = modules_to_not_convert
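With the deprecation gone, the constructor simply records the quantization method and the two parameters visible above. A minimal construction sketch based only on what the hunk shows; the import path and the "int8" dtype value are assumptions:

    from diffusers import QuantoConfig  # import path assumed

    # Parameters taken from the signature in the diff; "int8" is an
    # assumed-valid weights_dtype value.
    config = QuantoConfig(weights_dtype="int8", modules_to_not_convert=["proj_out"])
    print(config.quant_method, config.weights_dtype, config.modules_to_not_convert)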

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any
 from diffusers.utils.import_utils import is_optimum_quanto_version
 
 from ...utils import (
-    deprecate,
     get_module_from_name,
     is_accelerate_available,
     is_accelerate_version,
@@ -43,9 +42,6 @@ class QuantoQuantizer(DiffusersQuantizer):
         super().__init__(quantization_config, **kwargs)
 
     def validate_environment(self, *args, **kwargs):
-        deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
-        deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
-
         if not is_optimum_quanto_available():
             raise ImportError(
                 "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import gc
-import unittest
 
+import pytest
 import torch
 
 from diffusers.hooks import HookRegistry, ModelHook
@@ -134,20 +134,18 @@ class SkipLayerHook(ModelHook):
         return output
 
 
-class HookTests(unittest.TestCase):
+class TestHooks:
     in_features = 4
     hidden_features = 8
     out_features = 4
     num_layers = 2
 
-    def setUp(self):
+    def setup_method(self):
         params = self.get_module_parameters()
         self.model = DummyModel(**params)
         self.model.to(torch_device)
 
-    def tearDown(self):
-        super().tearDown()
-
+    def teardown_method(self):
        del self.model
        gc.collect()
        free_memory()
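The class-level changes follow pytest's xunit-style conventions: pytest collects plain classes whose names start with `Test`, no `unittest.TestCase` base required, and `setup_method`/`teardown_method` take over the per-test lifecycle from `setUp`/`tearDown`, with no superclass call needed. A minimal standalone sketch of the same pattern, using hypothetical names:

    class TestExample:  # collected by pytest's default Test* naming rule; no TestCase base
        def setup_method(self):  # runs before each test, replacing unittest's setUp
            self.items = [1, 2, 3]

        def teardown_method(self):  # runs after each test, replacing unittest's tearDown
            del self.items

        def test_length(self):
            assert len(self.items) == 3  # bare assert replaces self.assertEqual

One upshot is that pytest's assertion rewriting prints the compared values on failure, so the `self.assert*` wrappers are no longer needed for readable error messages.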
@@ -171,20 +169,20 @@ class HookTests(unittest.TestCase):
         registry_repr = repr(registry)
         expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
 
-        self.assertEqual(len(registry.hooks), 2)
-        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
-        self.assertEqual(registry_repr, expected_repr)
+        assert len(registry.hooks) == 2
+        assert registry._hook_order == ["add_hook", "multiply_hook"]
+        assert registry_repr == expected_repr
 
         registry.remove_hook("add_hook")
-        self.assertEqual(len(registry.hooks), 1)
-        self.assertEqual(registry._hook_order, ["multiply_hook"])
+        assert len(registry.hooks) == 1
+        assert registry._hook_order == ["multiply_hook"]
 
     def test_stateful_hook(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
         registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
-        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
+        assert registry.hooks["stateful_add_hook"].increment == 0
 
         input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
         num_repeats = 3
@@ -194,13 +192,13 @@ class HookTests(unittest.TestCase):
             if i == 0:
                 output1 = result
 
-        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
+        assert registry.get_hook("stateful_add_hook").increment == num_repeats
 
         registry.reset_stateful_hooks()
         output2 = self.model(input)
 
-        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
-        self.assertTrue(torch.allclose(output1, output2))
+        assert registry.get_hook("stateful_add_hook").increment == 1
+        assert torch.allclose(output1, output2)
 
     def test_inference(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -218,9 +216,9 @@ class HookTests(unittest.TestCase):
         new_input = input * 2 + 1
         output3 = self.model(new_input).mean().detach().cpu().item()
 
-        self.assertAlmostEqual(output1, output2, places=5)
-        self.assertAlmostEqual(output1, output3, places=5)
-        self.assertAlmostEqual(output2, output3, places=5)
+        assert output1 == pytest.approx(output2, abs=5e-6)
+        assert output1 == pytest.approx(output3, abs=5e-6)
+        assert output2 == pytest.approx(output3, abs=5e-6)
 
     def test_skip_layer_hook(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
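The tolerance translation here is faithful: `assertAlmostEqual(a, b, places=5)` passes when `round(a - b, 5) == 0`, i.e. when `|a - b| <= 5e-6`, which is exactly the absolute tolerance `pytest.approx` is given. A tiny check with made-up values:

    import pytest

    a, b = 1.000001, 1.0000015  # differ by 5e-7

    # unittest's assertAlmostEqual(a, b, places=5) checks round(a - b, 5) == 0,
    # i.e. |a - b| <= 5e-6; the absolute tolerance below reproduces that.
    assert a == pytest.approx(b, abs=5e-6)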
@@ -228,30 +226,29 @@ class HookTests(unittest.TestCase):
 
         input = torch.zeros(1, 4, device=torch_device)
         output = self.model(input).mean().detach().cpu().item()
-        self.assertEqual(output, 0.0)
+        assert output == 0.0
 
         registry.remove_hook("skip_layer_hook")
         registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
         output = self.model(input).mean().detach().cpu().item()
-        self.assertNotEqual(output, 0.0)
+        assert output != 0.0
 
     def test_skip_layer_internal_block(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
         input = torch.zeros(1, 4, device=torch_device)
 
         registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
-        with self.assertRaises(RuntimeError) as cm:
+        with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
             self.model(input).mean().detach().cpu().item()
-        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
 
         registry.remove_hook("skip_layer_hook")
         output = self.model(input).mean().detach().cpu().item()
-        self.assertNotEqual(output, 0.0)
+        assert output != 0.0
 
         registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
         registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
         output = self.model(input).mean().detach().cpu().item()
-        self.assertNotEqual(output, 0.0)
+        assert output != 0.0
 
     def test_invocation_order_stateful_first(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
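Folding `assertRaises` plus `assertIn` into a single `pytest.raises(..., match=...)` works because `match` applies `re.search` to the string form of the raised exception; the pattern here contains no regex metacharacters, so it behaves like the old substring check. A small illustration with a hypothetical helper:

    import pytest

    def failing_matmul():
        # stand-in for the shape mismatch the hook test provokes
        raise RuntimeError("mat1 and mat2 shapes cannot be multiplied (1x4 and 8x4)")

    # `match` is applied with re.search, so a plain substring works here
    with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
        failing_matmul()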
@@ -278,7 +275,7 @@ class HookTests(unittest.TestCase):
             .replace(" ", "")
             .replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log
 
         registry.remove_hook("add_hook")
         with CaptureLogger(logger) as cap_logger:
@@ -289,7 +286,7 @@ class HookTests(unittest.TestCase):
             .replace(" ", "")
             .replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log
 
     def test_invocation_order_stateful_middle(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -316,7 +313,7 @@ class HookTests(unittest.TestCase):
             .replace(" ", "")
             .replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log
 
         registry.remove_hook("add_hook")
         with CaptureLogger(logger) as cap_logger:
@@ -327,7 +324,7 @@ class HookTests(unittest.TestCase):
             .replace(" ", "")
             .replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log
 
         registry.remove_hook("add_hook_2")
         with CaptureLogger(logger) as cap_logger:
@@ -336,7 +333,7 @@ class HookTests(unittest.TestCase):
         expected_invocation_order_log = (
             ("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log
 
     def test_invocation_order_stateful_last(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -363,7 +360,7 @@ class HookTests(unittest.TestCase):
             .replace(" ", "")
             .replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log
 
         registry.remove_hook("add_hook")
         with CaptureLogger(logger) as cap_logger:
@@ -374,4 +371,4 @@ class HookTests(unittest.TestCase):
             .replace(" ", "")
             .replace("\n", "")
         )
-        self.assertEqual(output, expected_invocation_order_log)
+        assert output == expected_invocation_order_log