mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-11 06:54:32 +08:00
Compare commits
4 Commits
v0.33.1
...
hooks/qol-
| Author | SHA1 | Date | |
|---|---|---|---|
| | 2bff2d5eb1 | | |
| | d080379e94 | | |
| | 8546c9ed29 | | |
| | e08285ef9c | | |
@@ -124,6 +124,7 @@ class GroupOffloadingHook(ModelHook):
|
|||||||
group: ModuleGroup,
|
group: ModuleGroup,
|
||||||
next_group: Optional[ModuleGroup] = None,
|
next_group: Optional[ModuleGroup] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
self.group = group
|
self.group = group
|
||||||
self.next_group = next_group
|
self.next_group = next_group
|
||||||
|
|
||||||
@@ -168,6 +169,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
|||||||
_is_stateful = False
|
_is_stateful = False
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
self.execution_order: List[Tuple[str, torch.nn.Module]] = []
|
||||||
self._layer_execution_tracker_module_names = set()
|
self._layer_execution_tracker_module_names = set()
|
||||||
|
|
||||||
@@ -253,6 +255,7 @@ class LayerExecutionTrackerHook(ModelHook):
|
|||||||
_is_stateful = False
|
_is_stateful = False
|
||||||
|
|
||||||
def __init__(self, execution_order_update_callback):
    r"""
    Store the callback kept by this hook.

    Args:
        execution_order_update_callback:
            Object saved on the hook as `execution_order_update_callback`. It is only stored here; how and when it
            is invoked is decided by the rest of the class.
    """
    # Run the base-class initializer before setting subclass attributes.
    super().__init__()
    self.execution_order_update_callback = execution_order_update_callback
|
||||||
|
|
||||||
def pre_forward(self, module, *args, **kwargs):
|
def pre_forward(self, module, *args, **kwargs):
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ class ModelHook:
|
|||||||
|
|
||||||
def __init__(self):
    r"""
    Initialize the state shared by every hook: an unset function reference and an enabled flag.
    """
    # No function reference is attached at construction time.
    self.fn_ref: "HookFunctionReference" = None
    # Hooks start out enabled.
    self._is_enabled = True
|
||||||
|
|
||||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||||
r"""
|
r"""
|
||||||
@@ -142,8 +143,10 @@ class HookRegistry:
|
|||||||
|
|
||||||
self._module_ref = hook.initialize_hook(self._module_ref)
|
self._module_ref = hook.initialize_hook(self._module_ref)
|
||||||
|
|
||||||
def create_new_forward(function_reference: HookFunctionReference):
|
def create_new_forward(hook: ModelHook, function_reference: HookFunctionReference):
|
||||||
def new_forward(module, *args, **kwargs):
|
def new_forward(module, *args, **kwargs):
|
||||||
|
if not hook._is_enabled:
|
||||||
|
return function_reference.original_forward(*args, **kwargs)
|
||||||
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
|
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
|
||||||
output = function_reference.forward(*args, **kwargs)
|
output = function_reference.forward(*args, **kwargs)
|
||||||
return function_reference.post_forward(module, output)
|
return function_reference.post_forward(module, output)
|
||||||
@@ -155,6 +158,7 @@ class HookRegistry:
|
|||||||
fn_ref = HookFunctionReference()
|
fn_ref = HookFunctionReference()
|
||||||
fn_ref.pre_forward = hook.pre_forward
|
fn_ref.pre_forward = hook.pre_forward
|
||||||
fn_ref.post_forward = hook.post_forward
|
fn_ref.post_forward = hook.post_forward
|
||||||
|
fn_ref.original_forward = forward
|
||||||
fn_ref.forward = forward
|
fn_ref.forward = forward
|
||||||
|
|
||||||
if hasattr(hook, "new_forward"):
|
if hasattr(hook, "new_forward"):
|
||||||
@@ -163,7 +167,7 @@ class HookRegistry:
|
|||||||
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
|
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
|
||||||
)
|
)
|
||||||
|
|
||||||
rewritten_forward = create_new_forward(fn_ref)
|
rewritten_forward = create_new_forward(hook, fn_ref)
|
||||||
self._module_ref.forward = functools.update_wrapper(
|
self._module_ref.forward = functools.update_wrapper(
|
||||||
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
|
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
|
||||||
)
|
)
|
||||||
@@ -234,3 +238,19 @@ class HookRegistry:
|
|||||||
if i < len(self._hook_order) - 1:
|
if i < len(self._hook_order) - 1:
|
||||||
registry_repr += "\n"
|
registry_repr += "\n"
|
||||||
return f"HookRegistry(\n{registry_repr}\n)"
|
return f"HookRegistry(\n{registry_repr}\n)"
|
||||||
|
|
||||||
|
|
||||||
|
def _set_hook_state(module: torch.nn.Module, name: str, value: bool) -> None:
    r"""
    Set the `_is_enabled` flag of the hook called `name` on `module` and on all of its submodules.

    Modules that carry no `_diffusers_hook` attribute, or whose registry returns `None` for `name`, are left
    untouched.

    Args:
        module (`torch.nn.Module`):
            Root of the module tree to walk. It is traversed with `module.modules()`.
        name (`str`):
            Name passed to the registry's `get_hook` to look the hook up.
        value (`bool`):
            Value assigned to the hook's `_is_enabled` attribute.
    """
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook"):
            hook = submodule._diffusers_hook.get_hook(name)
            # `get_hook` yields None when this registry has no hook under `name`.
            if hook is not None:
                hook._is_enabled = value
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_all_hooks(module: torch.nn.Module) -> None:
    r"""
    Remove every registered hook from `module` and all of its submodules, then drop each hook registry.

    Args:
        module (`torch.nn.Module`):
            Root of the module tree to walk. It is traversed with `module.modules()`; modules without a
            `_diffusers_hook` attribute are skipped.
    """
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook"):
            registry = submodule._diffusers_hook
            # Iterate over a snapshot of the names: `remove_hook` presumably mutates `registry.hooks`.
            for hook_name in list(registry.hooks):
                # `recurse=False` because this loop already visits every submodule itself.
                registry.remove_hook(hook_name, recurse=False)
            del submodule._diffusers_hook
|
||||||
|
|||||||
@@ -52,6 +52,8 @@ class LayerwiseCastingHook(ModelHook):
|
|||||||
_is_stateful = False
|
_is_stateful = False
|
||||||
|
|
||||||
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.storage_dtype = storage_dtype
|
self.storage_dtype = storage_dtype
|
||||||
self.compute_dtype = compute_dtype
|
self.compute_dtype = compute_dtype
|
||||||
self.non_blocking = non_blocking
|
self.non_blocking = non_blocking
|
||||||
|
|||||||
@@ -1755,6 +1755,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
def _enable_hook(self, name: str) -> None:
    r"""
    Enable the hook with the given name on the model and all its submodules.

    Args:
        name (`str`):
            The name of the hook to enable.

    Backwards compatibility is not guaranteed for this method; it may change in future versions.
    """
    # Imported inside the method rather than at module level, as in the original code.
    from ..hooks.hooks import _set_hook_state

    _set_hook_state(self, name, True)
|
||||||
|
|
||||||
|
def _disable_hook(self, name: str) -> None:
    r"""
    Disable the hook with the given name on the model and all its submodules.

    Args:
        name (`str`):
            The name of the hook to disable.

    Backwards compatibility is not guaranteed for this method; it may change in future versions.
    """
    # Imported inside the method rather than at module level, as in the original code.
    from ..hooks.hooks import _set_hook_state

    _set_hook_state(self, name, False)
|
||||||
|
|
||||||
|
def _remove_all_hooks(self) -> None:
    r"""
    Remove all hooks from the model and all its submodules.

    Backwards compatibility is not guaranteed for this method; it may change in future versions.
    """
    # Imported inside the method rather than at module level, as in the original code.
    from ..hooks.hooks import _remove_all_hooks

    _remove_all_hooks(self)
|
||||||
|
|
||||||
|
|
||||||
class LegacyModelMixin(ModelMixin):
|
class LegacyModelMixin(ModelMixin):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -17,7 +17,9 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.hooks import HookRegistry, ModelHook
|
from diffusers.hooks import HookRegistry, ModelHook
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
from diffusers.training_utils import free_memory
|
from diffusers.training_utils import free_memory
|
||||||
from diffusers.utils.logging import get_logger
|
from diffusers.utils.logging import get_logger
|
||||||
from diffusers.utils.testing_utils import CaptureLogger, torch_device
|
from diffusers.utils.testing_utils import CaptureLogger, torch_device
|
||||||
@@ -61,6 +63,26 @@ class DummyModel(torch.nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModelWithMixin(ModelMixin, ConfigMixin):
    r"""
    Small test model built on `ModelMixin`: a linear layer, a ReLU, `num_layers` `DummyBlock`s and a final
    linear layer, applied in that order.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        r"""
        Args:
            in_features (`int`): Input size of the first linear layer.
            hidden_features (`int`): Width used for the first layer's output and for every block.
            out_features (`int`): Output size of the last linear layer.
            num_layers (`int`): Number of `DummyBlock`s placed between the two linear layers.
        """
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Run `x` through the first linear layer, the activation, each block in order, and the last linear layer."""
        x = self.linear_1(x)
        x = self.activation(x)
        for block in self.blocks:
            x = block(x)
        x = self.linear_2(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
class AddHook(ModelHook):
|
class AddHook(ModelHook):
|
||||||
def __init__(self, value: int):
|
def __init__(self, value: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -380,3 +402,96 @@ class HookTests(unittest.TestCase):
|
|||||||
.replace("\n", "")
|
.replace("\n", "")
|
||||||
)
|
)
|
||||||
self.assertEqual(output, expected_invocation_order_log)
|
self.assertEqual(output, expected_invocation_order_log)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMixinHookTests(unittest.TestCase):
    r"""
    Tests for the hook helpers exposed on the model (`_enable_hook`, `_disable_hook`, `_remove_all_hooks`),
    run against a `DummyModelWithMixin`.
    """

    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        params = self.get_module_parameters()
        self.model = DummyModelWithMixin(**params)
        self.model.to(torch_device)

    def tearDown(self):
        super().tearDown()

        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        r"""Return the constructor keyword arguments for the model under test."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        r"""Seed the default generator with 0 and return it, so every test draws the same input."""
        return torch.manual_seed(0)

    def _get_dummy_input(self) -> torch.Tensor:
        r"""Return a deterministic `(1, in_features)` input placed on `torch_device`."""
        # Sample on CPU and move afterwards: `get_generator()` returns a CPU generator, and `torch.randn`
        # rejects a generator whose device differs from the requested `device`.
        return torch.randn(1, self.in_features, generator=self.get_generator()).to(torch_device)

    def _run_model(self, inputs: torch.Tensor) -> float:
        r"""Run the model on `inputs` and reduce the output to its mean as a Python float."""
        return self.model(inputs).mean().detach().cpu().item()

    def test_enable_disable_hook(self):
        r"""Disabling one hook changes the output; enabling it again restores the original output."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        inputs = self._get_dummy_input()
        output1 = self._run_model(inputs)

        self.model._disable_hook("multiply_hook")
        output2 = self._run_model(inputs)

        self.model._enable_hook("multiply_hook")
        output3 = self._run_model(inputs)

        self.assertNotEqual(output1, output2)
        self.assertEqual(output1, output3)

    def test_enable_disable_hook_containing_new_forward(self):
        r"""
        Same enable/disable round trip, with an additional `SkipLayerHook` registered on every block
        (presumably a hook that defines `new_forward`, going by the test name).
        """
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        for block in self.model.blocks:
            block_registry = HookRegistry.check_if_exists_or_initialize(block)
            block_registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        inputs = self._get_dummy_input()
        output1 = self._run_model(inputs)

        self.model._disable_hook("skip_layer_hook")
        output2 = self._run_model(inputs)

        self.model._disable_hook("add_hook")
        output3 = self._run_model(inputs)

        self.model._enable_hook("skip_layer_hook")
        self.model._enable_hook("add_hook")
        output4 = self._run_model(inputs)

        self.assertNotEqual(output1, output2)
        self.assertNotEqual(output2, output3)
        self.assertEqual(output1, output4)

    def test_remove_all_hooks(self):
        r"""
        After `_remove_all_hooks`, no module keeps a `_diffusers_hook` attribute and the output matches the
        output obtained with every hook disabled.
        """
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        inputs = self._get_dummy_input()
        output1 = self._run_model(inputs)

        self.model._disable_hook("add_hook")
        self.model._disable_hook("multiply_hook")
        output2 = self._run_model(inputs)

        self.model._remove_all_hooks()
        output3 = self._run_model(inputs)

        for module in self.model.modules():
            self.assertFalse(hasattr(module, "_diffusers_hook"))

        self.assertNotEqual(output1, output3)
        self.assertEqual(output2, output3)
|
||||||
|
|||||||
Reference in New Issue
Block a user