Compare commits

...

4 Commits

Author  SHA1        Message              Date
Aryan   2bff2d5eb1  try fix for tests    2025-02-21 09:49:42 +01:00
Aryan   d080379e94  try fix for tests    2025-02-21 08:41:11 +01:00
Aryan   8546c9ed29  new_forward support  2025-02-21 06:18:13 +01:00
Aryan   e08285ef9c  update               2025-02-21 06:00:30 +01:00
5 changed files with 180 additions and 2 deletions

View File

@@ -124,6 +124,7 @@ class GroupOffloadingHook(ModelHook):
         group: ModuleGroup,
         next_group: Optional[ModuleGroup] = None,
     ) -> None:
+        super().__init__()
         self.group = group
         self.next_group = next_group
@@ -168,6 +169,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
     _is_stateful = False
 
     def __init__(self):
+        super().__init__()
         self.execution_order: List[Tuple[str, torch.nn.Module]] = []
         self._layer_execution_tracker_module_names = set()
@@ -253,6 +255,7 @@ class LayerExecutionTrackerHook(ModelHook):
     _is_stateful = False
 
     def __init__(self, execution_order_update_callback):
+        super().__init__()
        self.execution_order_update_callback = execution_order_update_callback
 
     def pre_forward(self, module, *args, **kwargs):
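Note: the only change to these three offloading hooks is the added super().__init__() call. ModelHook.__init__ now sets per-instance state (fn_ref, and the new _is_enabled flag introduced in the next file), so every subclass constructor has to chain to it. A minimal sketch of that pattern with a made-up hook class (MyOffloadHook is illustrative, not part of the PR):

from diffusers.hooks import ModelHook

class MyOffloadHook(ModelHook):
    def __init__(self, record_order: bool = False):
        # Chaining to ModelHook.__init__ is what initializes fn_ref and _is_enabled.
        super().__init__()
        self.record_order = record_order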

View File

@@ -32,6 +32,7 @@ class ModelHook:
     def __init__(self):
         self.fn_ref: "HookFunctionReference" = None
+        self._is_enabled = True
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
@@ -142,8 +143,10 @@ class HookRegistry:
         self._module_ref = hook.initialize_hook(self._module_ref)
 
-        def create_new_forward(function_reference: HookFunctionReference):
+        def create_new_forward(hook: ModelHook, function_reference: HookFunctionReference):
             def new_forward(module, *args, **kwargs):
+                if not hook._is_enabled:
+                    return function_reference.original_forward(*args, **kwargs)
                 args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
                 output = function_reference.forward(*args, **kwargs)
                 return function_reference.post_forward(module, output)
@@ -155,6 +158,7 @@ class HookRegistry:
         fn_ref = HookFunctionReference()
         fn_ref.pre_forward = hook.pre_forward
         fn_ref.post_forward = hook.post_forward
+        fn_ref.original_forward = forward
         fn_ref.forward = forward
 
         if hasattr(hook, "new_forward"):
@@ -163,7 +167,7 @@ class HookRegistry:
                 functools.partial(hook.new_forward, self._module_ref), hook.new_forward
             )
 
-        rewritten_forward = create_new_forward(fn_ref)
+        rewritten_forward = create_new_forward(hook, fn_ref)
         self._module_ref.forward = functools.update_wrapper(
             functools.partial(rewritten_forward, self._module_ref), rewritten_forward
         )
@@ -234,3 +238,19 @@ class HookRegistry:
             if i < len(self._hook_order) - 1:
                 registry_repr += "\n"
         return f"HookRegistry(\n{registry_repr}\n)"
+
+
+def _set_hook_state(module: torch.nn.Module, name: str, value: bool) -> None:
+    for submodule in module.modules():
+        if hasattr(submodule, "_diffusers_hook"):
+            hook = submodule._diffusers_hook.get_hook(name)
+            if hook is not None:
+                hook._is_enabled = value
+
+
+def _remove_all_hooks(module: torch.nn.Module):
+    for submodule in module.modules():
+        if hasattr(submodule, "_diffusers_hook"):
+            for hook_name in list(submodule._diffusers_hook.hooks.keys()):
+                submodule._diffusers_hook.remove_hook(hook_name, recurse=False)
+            del submodule._diffusers_hook
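This file carries the core of the change: each ModelHook instance now owns an _is_enabled flag, create_new_forward receives the hook so the rewritten forward can short-circuit to the untouched original_forward while the hook is disabled, and the module-level helpers _set_hook_state / _remove_all_hooks walk every submodule's _diffusers_hook registry. A self-contained toy sketch of the wrapping pattern (not the diffusers implementation; ToyHook and attach are illustrative names):

import torch

class ToyHook:
    def __init__(self):
        self._is_enabled = True

    def post_forward(self, module, output):
        return output * 2  # arbitrary effect, for illustration

def attach(module: torch.nn.Module, hook: ToyHook) -> None:
    original_forward = module.forward

    def new_forward(*args, **kwargs):
        # Disabled hook: behave exactly like the unhooked module.
        if not hook._is_enabled:
            return original_forward(*args, **kwargs)
        output = original_forward(*args, **kwargs)
        return hook.post_forward(module, output)

    module.forward = new_forward

linear = torch.nn.Linear(4, 4)
hook = ToyHook()
attach(linear, hook)

x = torch.randn(1, 4)
hooked = linear(x)        # doubled output
hook._is_enabled = False
plain = linear(x)         # original behaviour again
assert torch.allclose(hooked, plain * 2)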

View File

@@ -52,6 +52,8 @@ class LayerwiseCastingHook(ModelHook):
     _is_stateful = False
 
     def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
+        super().__init__()
         self.storage_dtype = storage_dtype
         self.compute_dtype = compute_dtype
         self.non_blocking = non_blocking

View File

@@ -1755,6 +1755,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
 
         return state_dict
 
+    def _enable_hook(self, name: str) -> None:
+        r"""
+        This method enables the hook with the given name on the model and all its submodules.
+
+        Args:
+            name (`str`):
+                The name of the hook to enable.
+
+        This method is not backwards compatible and may be subject to change in future versions.
+        """
+        from ..hooks.hooks import _set_hook_state
+
+        _set_hook_state(self, name, True)
+
+    def _disable_hook(self, name: str) -> None:
+        r"""
+        This method disables the hook with the given name on the model and all its submodules.
+
+        Args:
+            name (`str`):
+                The name of the hook to disable.
+
+        This method is not backwards compatible and may be subject to change in future versions.
+        """
+        from ..hooks.hooks import _set_hook_state
+
+        _set_hook_state(self, name, False)
+
+    def _remove_all_hooks(self) -> None:
+        r"""
+        This method removes all hooks from the model and all its submodules.
+
+        This method is not backwards compatible and may be subject to change in future versions.
+        """
+        from ..hooks.hooks import _remove_all_hooks
+
+        _remove_all_hooks(self)
+
 
 class LegacyModelMixin(ModelMixin):
     r"""
View File

@@ -17,7 +17,9 @@ import unittest
 import torch
 
+from diffusers.configuration_utils import ConfigMixin
 from diffusers.hooks import HookRegistry, ModelHook
+from diffusers.models.modeling_utils import ModelMixin
 from diffusers.training_utils import free_memory
 from diffusers.utils.logging import get_logger
 from diffusers.utils.testing_utils import CaptureLogger, torch_device
@@ -61,6 +63,26 @@ class DummyModel(torch.nn.Module):
         return x
 
 
+class DummyModelWithMixin(ModelMixin, ConfigMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.linear_2(x)
+        return x
+
+
 class AddHook(ModelHook):
     def __init__(self, value: int):
         super().__init__()
@@ -380,3 +402,96 @@ class HookTests(unittest.TestCase):
             .replace("\n", "")
         )
         self.assertEqual(output, expected_invocation_order_log)
+
+
+class ModelMixinHookTests(unittest.TestCase):
+    in_features = 4
+    hidden_features = 8
+    out_features = 4
+    num_layers = 2
+
+    def setUp(self):
+        params = self.get_module_parameters()
+        self.model = DummyModelWithMixin(**params)
+        self.model.to(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        del self.model
+        gc.collect()
+        free_memory()
+
+    def get_module_parameters(self):
+        return {
+            "in_features": self.in_features,
+            "hidden_features": self.hidden_features,
+            "out_features": self.out_features,
+            "num_layers": self.num_layers,
+        }
+
+    def get_generator(self):
+        return torch.manual_seed(0)
+
+    def test_enable_disable_hook(self):
+        registry = HookRegistry.check_if_exists_or_initialize(self.model)
+        registry.register_hook(AddHook(1), "add_hook")
+        registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+        output1 = self.model(input).mean().detach().cpu().item()
+
+        self.model._disable_hook("multiply_hook")
+        output2 = self.model(input).mean().detach().cpu().item()
+
+        self.model._enable_hook("multiply_hook")
+        output3 = self.model(input).mean().detach().cpu().item()
+
+        self.assertNotEqual(output1, output2)
+        self.assertEqual(output1, output3)
+
+    def test_enable_disable_hook_containing_new_forward(self):
+        registry = HookRegistry.check_if_exists_or_initialize(self.model)
+        registry.register_hook(AddHook(1), "add_hook")
+
+        for block in self.model.blocks:
+            block_registry = HookRegistry.check_if_exists_or_initialize(block)
+            block_registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
+
+        registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+        output1 = self.model(input).mean().detach().cpu().item()
+
+        self.model._disable_hook("skip_layer_hook")
+        output2 = self.model(input).mean().detach().cpu().item()
+
+        self.model._disable_hook("add_hook")
+        output3 = self.model(input).mean().detach().cpu().item()
+
+        self.model._enable_hook("skip_layer_hook")
+        self.model._enable_hook("add_hook")
+        output4 = self.model(input).mean().detach().cpu().item()
+
+        self.assertNotEqual(output1, output2)
+        self.assertNotEqual(output2, output3)
+        self.assertEqual(output1, output4)
+
+    def test_remove_all_hooks(self):
+        registry = HookRegistry.check_if_exists_or_initialize(self.model)
+        registry.register_hook(AddHook(1), "add_hook")
+        registry.register_hook(MultiplyHook(2), "multiply_hook")
+
+        input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
+        output1 = self.model(input).mean().detach().cpu().item()
+
+        self.model._disable_hook("add_hook")
+        self.model._disable_hook("multiply_hook")
+        output2 = self.model(input).mean().detach().cpu().item()
+
+        self.model._remove_all_hooks()
+        output3 = self.model(input).mean().detach().cpu().item()
+
+        for module in self.model.modules():
+            self.assertFalse(hasattr(module, "_diffusers_hook"))
+
+        self.assertNotEqual(output1, output3)
+        self.assertEqual(output2, output3)
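The tests above lean on hook classes defined earlier in the test file (AddHook, MultiplyHook, SkipLayerHook, and the DummyBlock module) that are not part of this diff. Plausible sketches of the two hooks exercised most by the new tests, so the assertions can be read in isolation; these are reconstructions, not the file's actual definitions:

class MultiplyHook(ModelHook):
    def __init__(self, value: float):
        super().__init__()
        self.value = value

    def post_forward(self, module, output):
        # Scales the module output, so disabling the hook visibly changes the result.
        return output * self.value

class SkipLayerHook(ModelHook):
    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def new_forward(self, module, x, *args, **kwargs):
        # A hook that defines new_forward replaces the wrapped forward entirely;
        # when it does not skip, it delegates to fn_ref.original_forward, which is
        # also the path taken automatically while the hook is disabled.
        if self.skip_layer:
            return x
        return self.fn_ref.original_forward(x, *args, **kwargs)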