mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00

Compare commits: shared-var...dynamic-up (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | be55fa631f |  |
```
@@ -263,6 +263,41 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def enable_dynamic_upcasting(self, upcast_dtype=None):
        upcast_dtype = upcast_dtype or torch.float32
        downcast_dtype = self.dtype

        # Forward pre-hooks receive (module, inputs); nn.Module.to() casts the
        # module in place, and returning None leaves the inputs untouched.
        def upcast_hook_fn(module, inputs):
            module.to(upcast_dtype)

        # Forward hooks receive (module, inputs, output); returning None leaves
        # the output untouched.
        def downcast_hook_fn(module, inputs, output):
            module.to(downcast_dtype)

        def fn_recursive_upcast(module):
            has_children = list(module.children())
            # Register hooks on leaf modules only, so each leaf is upcast right
            # before its forward pass and downcast right after it.
            if not has_children:
                module.register_forward_pre_hook(upcast_hook_fn)
                module.register_forward_hook(downcast_hook_fn)

            for child in module.children():
                fn_recursive_upcast(child)

        for module in self.children():
            fn_recursive_upcast(module)

    def disable_dynamic_upcasting(self):
        def fn_recursive_remove(module):
            has_children = list(module.children())
            if not has_children:
                # Clearing the hook dicts removes the up/downcast hooks (note:
                # this clears any other registered forward hooks as well).
                module._forward_pre_hooks = OrderedDict()
                module._forward_hooks = OrderedDict()

            for child in module.children():
                fn_recursive_remove(child)

        for module in self.children():
            fn_recursive_remove(module)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
```
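For reference, here is a minimal, self-contained sketch of the same hook mechanism on a plain `torch.nn` model, independent of the diff above. The storage/compute dtype pair is an illustrative assumption (float8 storage requires a PyTorch build with `torch.float8_e4m3fn` support):

```python
import torch
from torch import nn

# Illustrative dtype pair: store weights in float8, compute in float32.
storage_dtype, compute_dtype = torch.float8_e4m3fn, torch.float32

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
model.to(storage_dtype)


def upcast(module, inputs):
    module.to(compute_dtype)  # in place; return None so the inputs pass through


def downcast(module, inputs, output):
    module.to(storage_dtype)  # in place; return None so the output passes through


# Register the cast hooks on leaf modules only, as in the diff.
for leaf in (m for m in model.modules() if not list(m.children())):
    leaf.register_forward_pre_hook(upcast)
    leaf.register_forward_hook(downcast)

out = model(torch.randn(2, 16, dtype=compute_dtype))
print(out.dtype)              # torch.float32: the matmuls ran in compute_dtype
print(model[0].weight.dtype)  # torch.float8_e4m3fn: weights back in storage dtype
```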
```
@@ -19,6 +19,7 @@ import inspect
import os
import re
import sys
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
```
```
@@ -1172,6 +1173,93 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            component.to("cpu")
        self.hf_device_map = None

    def enable_dynamic_upcasting(
        self,
        components: Optional[List[str]] = None,
        upcast_dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Enable module-wise dynamic upcasting. This allows models to be loaded onto the GPU in a low-memory dtype,
        e.g. torch.float8_e4m3fn, while performing inference in a dtype that is supported on the GPU, by casting
        each module to the appropriate dtype right before its forward pass. The module is then moved back to the
        low-memory dtype after the forward pass.
        """
        if components is None:
            raise ValueError("Please provide a list of pipeline component names to apply dynamic upcasting")

        upcast_dtype = upcast_dtype or torch.float32

        def fn_recursive_upcast(module, dtype, original_dtype, keep_in_fp32_modules):
            has_children = list(module.children())
            upcast_dtype = dtype
            downcast_dtype = original_dtype

            # Forward pre-hooks receive (module, inputs); nn.Module.to() casts
            # the module in place, and returning None leaves the inputs untouched.
            def upcast_hook_fn(module, inputs):
                module.to(upcast_dtype)

            def downcast_hook_fn(module, *args, **kwargs):
                module.to(downcast_dtype)

            # Hooks are registered on leaf modules only, so each leaf is upcast
            # right before its forward pass and downcast right after it.
            if not has_children:
                module.register_forward_pre_hook(upcast_hook_fn)
                module.register_forward_hook(downcast_hook_fn)

            for name, child in module.named_children():
                # Submodules listed in `_keep_in_fp32_modules` are always upcast
                # to float32, regardless of `upcast_dtype`.
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                    dtype = torch.float32
                else:
                    dtype = upcast_dtype

                fn_recursive_upcast(child, dtype, original_dtype, keep_in_fp32_modules)

        for component in components:
            if not hasattr(self, component):
                raise ValueError(f"Pipeline has no component named: {component}")

            component_module = getattr(self, component)
            if not isinstance(component_module, torch.nn.Module):
                raise ValueError(
                    f"Pipeline component: {component} is not a torch.nn.Module. Cannot apply dynamic upcasting."
                )

            use_keep_in_fp32_modules = (
                hasattr(component_module, "_keep_in_fp32_modules")
                and (component_module._keep_in_fp32_modules is not None)
                and (upcast_dtype != torch.float32)
            )
            if use_keep_in_fp32_modules:
                keep_in_fp32_modules = component_module._keep_in_fp32_modules
            else:
                keep_in_fp32_modules = []

            original_dtype = component_module.dtype
            for name, module in component_module.named_children():
                fn_recursive_upcast(module, upcast_dtype, original_dtype, keep_in_fp32_modules)

    def disable_dynamic_upcasting(self):
        def fn_recursive_remove(module):
            has_children = list(module.children())
            if not has_children:
                # Clearing the hook dicts removes the up/downcast hooks.
                module._forward_pre_hooks = OrderedDict()
                module._forward_hooks = OrderedDict()

            for child in module.children():
                fn_recursive_remove(child)

        for component in self.components:
            component_module = getattr(self, component, None)
            # Skip non-module components, e.g. tokenizers and schedulers.
            if not isinstance(component_module, torch.nn.Module):
                continue

            for module in component_module.children():
                fn_recursive_remove(module)

    @classmethod
    @validate_hf_hub_args
    def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
```
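As a hypothetical usage sketch with this diff applied (the checkpoint id and the float8-storage / float16-compute pairing are illustrative assumptions, not part of the commit):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.unet.to(torch.float8_e4m3fn)  # keep the UNet weights in a low-memory dtype
pipe.to("cuda")

# Each leaf module of the UNet is upcast to float16 just before its forward
# pass and cast back to float8 right after, trading extra casts for memory.
pipe.enable_dynamic_upcasting(components=["unet"], upcast_dtype=torch.float16)

image = pipe("an astronaut riding a horse on the moon").images[0]

pipe.disable_dynamic_upcasting()  # clear the cast hooks on all components
```

Since `disable_dynamic_upcasting` resets the leaf hook dicts wholesale, it is best called before registering any other forward hooks (e.g. CPU offload) on the same components.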