Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-12 07:24:32 +08:00
Compare commits: single-fil...dynamic-up (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | be55fa631f | |
```diff
@@ -263,6 +263,41 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
+    def enable_dynamic_upcasting(self, upcast_dtype=None):
+        upcast_dtype = upcast_dtype or torch.float32
+        downcast_dtype = self.dtype
+
+        # Hook signatures follow torch.nn.Module.register_forward_pre_hook /
+        # register_forward_hook: (module, inputs) and (module, inputs, outputs).
+        def upcast_hook_fn(module, inputs):
+            module.to(upcast_dtype)
+
+        def downcast_hook_fn(module, inputs, outputs):
+            module.to(downcast_dtype)
+
+        def fn_recursive_upcast(module):
+            # Only leaf modules get the casting hooks, so each leaf is upcast
+            # right before its forward pass and downcast right after it.
+            has_children = list(module.children())
+            if not has_children:
+                module.register_forward_pre_hook(upcast_hook_fn)
+                module.register_forward_hook(downcast_hook_fn)
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
+    def disable_dynamic_upcasting(self):
+        def fn_recursive_upcast(module):
+            has_children = list(module.children())
+            if not has_children:
+                # Clearing the hook dicts removes the casting hooks registered above.
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
```
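For reference, the casting mechanism above is plain PyTorch forward hooks. Here is a minimal self-contained sketch of the same cast-before/cast-after pattern; the layer, dtypes, and function names are illustrative, not part of the commit:

```python
import torch
import torch.nn as nn

layer = nn.Linear(16, 16).to(torch.float16)  # low-memory "storage" dtype

def upcast(module, inputs):
    module.to(torch.float32)  # upcast just before the forward pass

def downcast(module, inputs, outputs):
    module.to(torch.float16)  # downcast right after the forward pass

layer.register_forward_pre_hook(upcast)
layer.register_forward_hook(downcast)

out = layer(torch.randn(2, 16))             # the matmul runs in float32
assert layer.weight.dtype == torch.float16  # weights are back in float16
```

The trade-off is the usual one for layerwise casting: peak memory only ever holds one leaf module in the compute dtype, at the cost of a cast on every forward pass.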
```diff
@@ -19,6 +19,7 @@ import inspect
 import os
 import re
 import sys
+from collections import OrderedDict
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
```
```diff
@@ -1172,6 +1173,93 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             component.to("cpu")
         self.hf_device_map = None
 
+    def enable_dynamic_upcasting(
+        self,
+        components: Optional[List[str]] = None,
+        upcast_dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Enable module-wise dynamic upcasting. This allows models to be loaded onto the GPU in a low-memory dtype,
+        e.g. torch.float8_e4m3fn, but perform inference using a dtype that is supported on the GPU, by casting each
+        module to the appropriate dtype right before its forward pass. The module is moved back to the low-memory
+        dtype after the forward pass.
+        """
+        if components is None:
+            raise ValueError("Please provide a list of pipeline component names to apply dynamic upcasting")
+        upcast_dtype = upcast_dtype or torch.float32
+
+        def fn_recursive_upcast(module, dtype, original_dtype, keep_in_fp32_modules):
+            has_children = list(module.children())
+            upcast_dtype = dtype
+            downcast_dtype = original_dtype
+
+            def upcast_hook_fn(module, inputs):
+                module.to(upcast_dtype)
+
+            def downcast_hook_fn(module, inputs, outputs):
+                module.to(downcast_dtype)
+
+            if not has_children:
+                module.register_forward_pre_hook(upcast_hook_fn)
+                module.register_forward_hook(downcast_hook_fn)
+
+            for name, child in module.named_children():
+                # Children listed in `_keep_in_fp32_modules` are always upcast to float32.
+                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                    dtype = torch.float32
+                else:
+                    dtype = upcast_dtype
+
+                fn_recursive_upcast(child, dtype, original_dtype, keep_in_fp32_modules)
+
+        for component in components:
+            if not hasattr(self, component):
+                raise ValueError(f"Pipeline has no component named: {component}")
+
+            component_module = getattr(self, component)
+            if not isinstance(component_module, torch.nn.Module):
+                raise ValueError(
+                    f"Pipeline component: {component} is not a torch.nn.Module. Cannot apply dynamic upcasting."
+                )
+
+            use_keep_in_fp32_modules = (
+                hasattr(component_module, "_keep_in_fp32_modules")
+                and (component_module._keep_in_fp32_modules is not None)
+                and (upcast_dtype != torch.float32)
+            )
+            if use_keep_in_fp32_modules:
+                keep_in_fp32_modules = component_module._keep_in_fp32_modules
+            else:
+                keep_in_fp32_modules = []
+
+            original_dtype = component_module.dtype
+            for name, module in component_module.named_children():
+                fn_recursive_upcast(module, upcast_dtype, original_dtype, keep_in_fp32_modules)
+
+    def disable_dynamic_upcasting(self):
+        def fn_recursive_upcast(module):
+            has_children = list(module.children())
+            if not has_children:
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for component in self.components:
+            component_module = getattr(self, component, None)
+            # Components such as tokenizers and schedulers are not torch modules; skip them.
+            if not isinstance(component_module, torch.nn.Module):
+                continue
+
+            for module in component_module.children():
+                fn_recursive_upcast(module)
+
     @classmethod
     @validate_hf_hub_args
     def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
```
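A hypothetical end-to-end usage sketch of the API this commit adds, assuming the branch above is installed; the checkpoint id, component name, and dtypes are illustrative assumptions, not part of the diff:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Store the UNet weights in a low-memory dtype...
pipe.unet.to(torch.float8_e4m3fn)
pipe.to("cuda")

# ...and upcast each leaf module to float16 only for the duration of its
# forward pass. The current component dtype (fp8) is recorded as the
# downcast target, so the cast to fp8 must happen before this call.
pipe.enable_dynamic_upcasting(components=["unet"], upcast_dtype=torch.float16)

image = pipe("an astronaut riding a horse").images[0]

# Remove the casting hooks when done.
pipe.disable_dynamic_upcasting()
```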