mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
4 Commits
cnet-union
...
hooks/qol-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bff2d5eb1 | ||
|
|
d080379e94 | ||
|
|
8546c9ed29 | ||
|
|
e08285ef9c |
@@ -124,6 +124,7 @@ class GroupOffloadingHook(ModelHook):
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.group = group
|
||||
self.next_group = next_group
|
||||
|
||||
@@ -168,6 +169,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
||||
self._layer_execution_tracker_module_names = set()
|
||||
|
||||
@@ -253,6 +255,7 @@ class LayerExecutionTrackerHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, execution_order_update_callback):
|
||||
super().__init__()
|
||||
self.execution_order_update_callback = execution_order_update_callback
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
|
||||
@@ -32,6 +32,7 @@ class ModelHook:
|
||||
|
||||
def __init__(self):
|
||||
self.fn_ref: "HookFunctionReference" = None
|
||||
self._is_enabled = True
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
@@ -142,8 +143,10 @@ class HookRegistry:
|
||||
|
||||
self._module_ref = hook.initialize_hook(self._module_ref)
|
||||
|
||||
def create_new_forward(function_reference: HookFunctionReference):
|
||||
def create_new_forward(hook: ModelHook, function_reference: HookFunctionReference):
|
||||
def new_forward(module, *args, **kwargs):
|
||||
if not hook._is_enabled:
|
||||
return function_reference.original_forward(*args, **kwargs)
|
||||
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
|
||||
output = function_reference.forward(*args, **kwargs)
|
||||
return function_reference.post_forward(module, output)
|
||||
@@ -155,6 +158,7 @@ class HookRegistry:
|
||||
fn_ref = HookFunctionReference()
|
||||
fn_ref.pre_forward = hook.pre_forward
|
||||
fn_ref.post_forward = hook.post_forward
|
||||
fn_ref.original_forward = forward
|
||||
fn_ref.forward = forward
|
||||
|
||||
if hasattr(hook, "new_forward"):
|
||||
@@ -163,7 +167,7 @@ class HookRegistry:
|
||||
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
|
||||
)
|
||||
|
||||
rewritten_forward = create_new_forward(fn_ref)
|
||||
rewritten_forward = create_new_forward(hook, fn_ref)
|
||||
self._module_ref.forward = functools.update_wrapper(
|
||||
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
|
||||
)
|
||||
@@ -234,3 +238,19 @@ class HookRegistry:
|
||||
if i < len(self._hook_order) - 1:
|
||||
registry_repr += "\n"
|
||||
return f"HookRegistry(\n{registry_repr}\n)"
|
||||
|
||||
|
||||
def _set_hook_state(module: torch.nn.Module, name: str, value: bool) -> None:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook"):
|
||||
hook = submodule._diffusers_hook.get_hook(name)
|
||||
if hook is not None:
|
||||
hook._is_enabled = value
|
||||
|
||||
|
||||
def _remove_all_hooks(module: torch.nn.Module):
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook"):
|
||||
for hook_name in list(submodule._diffusers_hook.hooks.keys()):
|
||||
submodule._diffusers_hook.remove_hook(hook_name, recurse=False)
|
||||
del submodule._diffusers_hook
|
||||
|
||||
@@ -52,6 +52,8 @@ class LayerwiseCastingHook(ModelHook):
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.storage_dtype = storage_dtype
|
||||
self.compute_dtype = compute_dtype
|
||||
self.non_blocking = non_blocking
|
||||
|
||||
@@ -1755,6 +1755,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
||||
return state_dict
|
||||
|
||||
def _enable_hook(self, name: str) -> None:
|
||||
r"""
|
||||
This method enables the hook with the given name on the model and all its submodules.
|
||||
|
||||
Args:
|
||||
name (`str`):
|
||||
The name of the hook to enable.
|
||||
|
||||
This method is not backwards compatible and may be subject to change in future versions.
|
||||
"""
|
||||
from ..hooks.hooks import _set_hook_state
|
||||
|
||||
_set_hook_state(self, name, True)
|
||||
|
||||
def _disable_hook(self, name: str) -> None:
|
||||
r"""
|
||||
This method disables the hook with the given name on the model and all its submodules.
|
||||
|
||||
Args:
|
||||
name (`str`):
|
||||
The name of the hook to disable.
|
||||
|
||||
This method is not backwards compatible and may be subject to change in future versions.
|
||||
"""
|
||||
from ..hooks.hooks import _set_hook_state
|
||||
|
||||
_set_hook_state(self, name, False)
|
||||
|
||||
def _remove_all_hooks(self) -> None:
|
||||
r"""
|
||||
This method removes all hooks from the model and all its submodules.
|
||||
|
||||
This method is not backwards compatible and may be subject to change in future versions.
|
||||
"""
|
||||
from ..hooks.hooks import _remove_all_hooks
|
||||
|
||||
_remove_all_hooks(self)
|
||||
|
||||
|
||||
class LegacyModelMixin(ModelMixin):
|
||||
r"""
|
||||
|
||||
@@ -17,7 +17,9 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.hooks import HookRegistry, ModelHook
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.training_utils import free_memory
|
||||
from diffusers.utils.logging import get_logger
|
||||
from diffusers.utils.testing_utils import CaptureLogger, torch_device
|
||||
@@ -61,6 +63,26 @@ class DummyModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DummyModelWithMixin(ModelMixin, ConfigMixin):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
|
||||
)
|
||||
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = self.activation(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class AddHook(ModelHook):
|
||||
def __init__(self, value: int):
|
||||
super().__init__()
|
||||
@@ -380,3 +402,96 @@ class HookTests(unittest.TestCase):
|
||||
.replace("\n", "")
|
||||
)
|
||||
self.assertEqual(output, expected_invocation_order_log)
|
||||
|
||||
|
||||
class ModelMixinHookTests(unittest.TestCase):
|
||||
in_features = 4
|
||||
hidden_features = 8
|
||||
out_features = 4
|
||||
num_layers = 2
|
||||
|
||||
def setUp(self):
|
||||
params = self.get_module_parameters()
|
||||
self.model = DummyModelWithMixin(**params)
|
||||
self.model.to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.model
|
||||
gc.collect()
|
||||
free_memory()
|
||||
|
||||
def get_module_parameters(self):
|
||||
return {
|
||||
"in_features": self.in_features,
|
||||
"hidden_features": self.hidden_features,
|
||||
"out_features": self.out_features,
|
||||
"num_layers": self.num_layers,
|
||||
}
|
||||
|
||||
def get_generator(self):
|
||||
return torch.manual_seed(0)
|
||||
|
||||
def test_enable_disable_hook(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(1), "add_hook")
|
||||
registry.register_hook(MultiplyHook(2), "multiply_hook")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
output1 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._disable_hook("multiply_hook")
|
||||
output2 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._enable_hook("multiply_hook")
|
||||
output3 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.assertNotEqual(output1, output2)
|
||||
self.assertEqual(output1, output3)
|
||||
|
||||
def test_enable_disable_hook_containing_new_forward(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(1), "add_hook")
|
||||
for block in self.model.blocks:
|
||||
block_registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
block_registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
|
||||
registry.register_hook(MultiplyHook(2), "multiply_hook")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
output1 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._disable_hook("skip_layer_hook")
|
||||
output2 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._disable_hook("add_hook")
|
||||
output3 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._enable_hook("skip_layer_hook")
|
||||
self.model._enable_hook("add_hook")
|
||||
output4 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.assertNotEqual(output1, output2)
|
||||
self.assertNotEqual(output2, output3)
|
||||
self.assertEqual(output1, output4)
|
||||
|
||||
def test_remove_all_hooks(self):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry.register_hook(AddHook(1), "add_hook")
|
||||
registry.register_hook(MultiplyHook(2), "multiply_hook")
|
||||
|
||||
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
|
||||
output1 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._disable_hook("add_hook")
|
||||
self.model._disable_hook("multiply_hook")
|
||||
output2 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
self.model._remove_all_hooks()
|
||||
output3 = self.model(input).mean().detach().cpu().item()
|
||||
|
||||
for module in self.model.modules():
|
||||
self.assertFalse(hasattr(module, "_diffusers_hook"))
|
||||
|
||||
self.assertNotEqual(output1, output3)
|
||||
self.assertEqual(output2, output3)
|
||||
|
||||
Reference in New Issue
Block a user