Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Nair
be55fa631f update 2024-08-13 14:11:47 +02:00
2 changed files with 123 additions and 0 deletions

View File

@@ -263,6 +263,41 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
def enable_dynamic_upcasting(self, upcast_dtype=None):
upcast_dtype = upcast_dtype or torch.float32
downcast_dtype = self.dtype
def upcast_hook_fn(module):
module = module.to(upcast_dtype)
def downcast_hook_fn(module):
module = module.to(downcast_dtype)
def fn_recursive_upcast(module):
has_children = list(module.children())
if not has_children:
module.register_forward_pre_hook(upcast_hook_fn)
module.register_forward_hook(downcast_hook_fn)
for child in module.children():
fn_recursive_upcast(child)
for module in self.children():
fn_recursive_upcast(module)
def disable_dynamic_upcasting(self):
def fn_recursive_upcast(module):
has_children = list(module.children())
if not has_children:
module._forward_pre_hooks = OrderedDict()
module._forward_hooks = OrderedDict()
for child in module.children():
fn_recursive_upcast(child)
for module in self.children():
fn_recursive_upcast(module)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],

View File

@@ -19,6 +19,7 @@ import inspect
import os
import re
import sys
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
@@ -1172,6 +1173,93 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
component.to("cpu")
self.hf_device_map = None
def enable_dynamic_upcasting(
self,
components: Optional[List[str]] = None,
upcast_dtype: Optional[torch.dtype] = None,
):
r"""
Enable module-wise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype e.g.
torch.float8_e4m3fn, but perform inference using a dtype that is supported on the GPU, by casting the module to
the appropriate dtype right before the foward pass. The module is then moved back to the low memory dtype after
the foward pass.
"""
if components is None:
raise ValueError("Please provide a list of pipeline component names to apply dynamic upcasting")
def fn_recursive_upcast(module, dtype, original_dtype, keep_in_fp32_modules):
has_children = list(module.children())
upcast_dtype = dtype
downcast_dtype = original_dtype
def upcast_hook_fn(module, inputs):
module = module.to(upcast_dtype)
def downcast_hook_fn(module, *args, **kwargs):
module = module.to(downcast_dtype)
if not has_children:
module.register_forward_pre_hook(upcast_hook_fn)
module.register_forward_hook(downcast_hook_fn)
for name, child in module.named_children():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
dtype = torch.float32
else:
dtype = upcast_dtype
fn_recursive_upcast(child, dtype, original_dtype, keep_in_fp32_modules)
for component in components:
if not hasattr(self, component):
raise ValueError(f"Pipeline has no component named: {component}")
component_module = getattr(self, component)
if not isinstance(component_module, torch.nn.Module):
raise ValueError(
f"Pipeline component: {component} is not a torch.nn.Module. Cannot apply dynamic upcasting."
)
use_keep_in_fp32_modules = (
hasattr(component_module, "_keep_in_fp32_modules")
and (component_module._keep_in_fp32_modules is not None)
and (upcast_dtype != torch.float32)
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = component_module._keep_in_fp32_modules
else:
keep_in_fp32_modules = []
original_dtype = component_module.dtype
for name, module in component_module.named_children():
fn_recursive_upcast(module, upcast_dtype, original_dtype, keep_in_fp32_modules)
def disable_dynamic_upcasting(
self,
):
def fn_recursive_upcast(module):
has_children = list(module.children())
if not has_children:
module._forward_pre_hooks = OrderedDict()
module._forward_hooks = OrderedDict()
for child in module.children():
fn_recursive_upcast(child)
for component in self.components:
if not hasattr(self, component):
raise ValueError(f"Pipeline has no component named: {component}")
component_module = getattr(self, component)
if not issubclass(component_module, torch.nn.Module):
raise ValueError(
f"Pipeline component: {component} is not an torch.nn.Module. Cannot apply dynamic upcasting."
)
for module in component_module.children():
fn_recursive_upcast(module)
@classmethod
@validate_hf_hub_args
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: