Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)

Compare commits: sana-tests...layerwise- (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 51a855c8c6 | |
| | c64fa22c08 | |
| | 0d1a1f875a | |
| | f1fa1235e4 | |
| | 9b411e5ff3 | |
| | b366b22191 | |
| | 1fdae85f49 | |
| | 6b9fd0905e | |
| | be55fa631f | |
@@ -449,7 +449,7 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
         elif self.norm_type == "ada_norm_single":
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+                self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
             ).chunk(6, dim=1)
             norm_hidden_states = self.norm1(hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

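The one-line change above is the pattern repeated throughout this compare: tensors owned directly by a parent module (here `scale_shift_table`) are never touched by the leaf-module upcast hooks, so they must be cast to the activation dtype at the point of use. The following standalone sketch is not diffusers code; the module and shapes are invented purely to illustrate why the explicit `.to(...)` is needed when weights are stored in a low-memory dtype and only leaf submodules are upcast by forward hooks.

```python
import torch
import torch.nn as nn


class TinyAdaNormBlock(nn.Module):
    """Illustrative stand-in (not diffusers code) for a block that owns a scale/shift table directly."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, 6 * dim)  # leaf module: a forward pre-hook can upcast its weights
        # Parent-owned parameter: leaf-module hooks never touch it, so under layerwise
        # upcasting it stays in the storage dtype (e.g. torch.float8_e4m3fn).
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim))

    def forward(self, hidden_states, temb):
        # hidden_states: (batch, seq, dim), temb: (batch, dim)
        temb = self.proj(temb)  # runs in the upcast dtype thanks to the hook on self.proj
        # Without the explicit cast, adding a float8-stored table to an fp16/fp32 activation
        # can fail or mix dtypes; casting to the activation dtype keeps the arithmetic safe.
        shift, scale, *_ = (
            self.scale_shift_table[None].to(temb.dtype) + temb.reshape(temb.size(0), 6, -1)
        ).chunk(6, dim=1)
        return hidden_states * (1 + scale) + shift
```
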
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

+    _always_upcast_modules = ["MaskConditionDecoder"]
+
     @register_to_config
     def __init__(
         self,

@@ -70,6 +70,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):

     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+    _always_upcast_modules = ["Decoder"]

     @register_to_config
     def __init__(

@@ -192,6 +192,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["TemporalDecoder"]

     @register_to_config
     def __init__(

@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = False
+    _always_upcast_modules = ["OobleckEncoder", "OobleckDecoder"]

     @register_to_config
     def __init__(

@@ -330,7 +330,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
                 Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.

         """
-        z = (z * self.config.scaling_factor - self.means) / self.stds
+        z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)

        scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
        z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)

@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
             Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """

+    _always_upcast_modules = ["Decoder", "VectorQuantizer"]
+
     @register_to_config
     def __init__(
         self,

@@ -263,6 +263,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def enable_layerwise_upcasting(self, upcast_dtype=None):
+        r"""
+        Enable layerwise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype, e.g.
+        torch.float8_e4m3fn, while performing inference in a dtype that is supported by the GPU, by upcasting the
+        individual modules in the model to the appropriate dtype right before the forward pass.
+
+        Each module is moved back to the low memory dtype after its forward pass.
+        """
+
+        upcast_dtype = upcast_dtype or torch.float32
+        original_dtype = self.dtype
+
+        def upcast_dtype_hook_fn(module, *args, **kwargs):
+            module = module.to(upcast_dtype)
+
+        def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
+            module = module.to(original_dtype)
+
+        def fn_recursive_upcast(module):
+            """In certain cases modules will apply casting internally or reference the dtype of internal blocks.
+
+            e.g.
+
+            ```
+            class MyModel(nn.Module):
+                def forward(self, x):
+                    dtype = next(iter(self.blocks.parameters())).dtype
+                    x = self.blocks(x) + torch.ones(x.size()).to(dtype)
+            ```
+            Layerwise upcasting will not work here, since the internal blocks remain in the low memory dtype until
+            their `forward` method is called. We need to add the upcast hook on the entire module in order for the
+            operation to work.
+
+            The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
+            entirely, rather than layerwise.
+
+            """
+            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+                # Upcast the entire module and exit the recursion
+                module.register_forward_pre_hook(upcast_dtype_hook_fn)
+                module.register_forward_hook(cast_to_original_dtype_hook_fn)
+
+                return
+
+            has_children = list(module.children())
+            if not has_children:
+                module.register_forward_pre_hook(upcast_dtype_hook_fn)
+                module.register_forward_hook(cast_to_original_dtype_hook_fn)
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
+    def disable_layerwise_upcasting(self):
+        def fn_recursive_upcast(module):
+            if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+                return
+
+            has_children = list(module.children())
+            if not has_children:
+                module._forward_pre_hooks = OrderedDict()
+                module._forward_hooks = OrderedDict()
+
+            for child in module.children():
+                fn_recursive_upcast(child)
+
+        for module in self.children():
+            fn_recursive_upcast(module)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],

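For context, a minimal usage sketch of the two methods added above. The checkpoint id and dtypes are only examples; float8 storage requires a PyTorch build that ships `torch.float8_e4m3fn`, and the model class is expected to define `_always_upcast_modules` where its forward relies on internal dtypes (as the other hunks in this compare do).

```python
import torch
from diffusers import AutoencoderKL

# Example checkpoint; any AutoencoderKL checkpoint would do.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Store the weights in a low-memory dtype on the GPU...
vae = vae.to(torch.float8_e4m3fn).to("cuda")

# ...but run each layer in float16 (the default upcast dtype is float32).
vae.enable_layerwise_upcasting(upcast_dtype=torch.float16)

with torch.no_grad():
    latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
    image = vae.decode(latents).sample

# Removing the hooks returns the model to normal single-dtype execution.
vae.disable_layerwise_upcasting()
```
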
@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):

     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["AuraFlowPatchEmbed"]

     @register_to_config
     def __init__(
@@ -457,11 +458,15 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):

         # Apply patch embedding, timestep embedding, and project the caption embeddings.
         hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
-        temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
+        temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
         temb = self.time_step_proj(temb)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
         encoder_hidden_states = torch.cat(
-            [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
+            [
+                self.register_tokens.to(encoder_hidden_states.dtype).repeat(encoder_hidden_states.size(0), 1, 1),
+                encoder_hidden_states,
+            ],
+            dim=1,
         )

         # MMDiT blocks.

@@ -65,6 +65,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["PatchEmbed"]

     @register_to_config
     def __init__(

@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
     """

+    _always_upcast_modules = ["HunyuanDiTAttentionPool"]
+
     @register_to_config
     def __init__(
         self,
@@ -484,7 +486,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
         text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

-        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+        encoder_hidden_states = torch.where(
+            text_embedding_mask, encoder_hidden_states, self.text_embedding_padding.to(encoder_hidden_states.dtype)
+        )

         skips = []
         for layer, block in enumerate(self.blocks):

@@ -64,6 +64,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         video_length (`int`, *optional*):
             The number of frames in the video-like data.
     """
+    _always_upcast_modules = ["PatchEmbed"]

     @register_to_config
     def __init__(
@@ -301,7 +302,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

         embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
-        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(embedded_timestep.dtype) + embedded_timestep[:, None]).chunk(
+            2, dim=1
+        )
         hidden_states = self.norm_out(hidden_states)
         # Modulation
         hidden_states = hidden_states * (1 + scale) + shift

@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):

     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+    _always_upcast_modules = ["PatchEmbed"]

     @register_to_config
     def __init__(
@@ -422,7 +423,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):

         # 3. Output
         shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+            self.scale_shift_table[None].to(embedded_timestep.dtype)
+            + embedded_timestep[:, None].to(self.scale_shift_table.device)
         ).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation

@@ -289,7 +289,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef

         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might be fp16, so we need to cast here.
-        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+        timesteps_projected = timesteps_projected.to(dtype=hidden_states.dtype)
         time_embeddings = self.time_embedding(timesteps_projected)

         if self.embedding_proj_norm is not None:

@@ -54,6 +54,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     """

     _supports_gradient_checkpointing = True
+    _always_upcast_modules = ["PatchEmbed"]

     @register_to_config
     def __init__(

@@ -283,7 +283,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb)

         if self.class_embedding is not None:

@@ -641,7 +641,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)

@@ -590,7 +590,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)
         t_emb = self.time_embedding(t_emb, timestep_cond)

         # 2. FPS

@@ -2152,7 +2152,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.
         # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=self.dtype)
+        t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
         aug_emb = None

@@ -43,6 +43,8 @@ from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_
 from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    disable_full_determinism,
+    enable_full_determinism,
     get_python_version,
     is_torch_compile,
     require_torch_2,
@@ -984,6 +986,49 @@ class ModelTesterMixin:
         new_output = new_model(**inputs_dict)
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

+    @require_torch_gpu
+    def test_layerwise_upcasting(self):
+        disable_full_determinism()
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_cached()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model.to(torch_device)
+
+        model(**inputs_dict)
+        base_max_memory = torch.cuda.max_memory_allocated()
+
+        # Remove model
+        model.to("cpu")
+        del model
+
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_cached()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        low_memory_dtype = torch.float8_e4m3fn
+        upcast_dtype = torch.float32
+
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        low_mem_model = self.model_class(**config).eval()
+        low_mem_model.to(low_memory_dtype)
+        low_mem_model.to(torch_device)
+        layerwise_max_memory = torch.cuda.max_memory_allocated()
+        low_mem_model.enable_layerwise_upcasting(upcast_dtype)
+        low_mem_model(**inputs_dict)
+
+        assert layerwise_max_memory < base_max_memory
+
+        enable_full_determinism()
+

 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):

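A possible follow-up to the memory assertion above, not part of this compare: a hedged sketch of an output-closeness check that could live next to `test_layerwise_upcasting` in `ModelTesterMixin`, reusing the same helpers. The tolerance is deliberately loose because float8 storage quantizes the weights.

```python
    @require_torch_gpu
    def test_layerwise_upcasting_matches_full_precision(self):
        # Sketch only: assumes the same ModelTesterMixin helpers used above.
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Reference run in the model's default dtype.
        torch.manual_seed(0)
        ref_model = self.model_class(**config).eval().to(torch_device)
        with torch.no_grad():
            ref_output = ref_model(**inputs_dict)[0]

        # Same weights (same seed), stored in float8 and upcast layerwise to float32.
        torch.manual_seed(0)
        low_mem_model = self.model_class(**config).eval()
        low_mem_model.to(torch.float8_e4m3fn).to(torch_device)
        low_mem_model.enable_layerwise_upcasting(torch.float32)
        with torch.no_grad():
            upcast_output = low_mem_model(**inputs_dict)[0]

        # float8 storage loses precision, so only a loose tolerance is meaningful here.
        self.assertTrue(torch.allclose(ref_output.float(), upcast_output.float(), atol=1e-1))
```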