Compare commits

...

9 Commits

Author SHA1 Message Date
Sayak Paul  51a855c8c6  Merge branch 'main' into layerwise-upcasting  2024-08-19 14:22:12 +05:30
Sayak Paul  c64fa22c08  Merge branch 'main' into layerwise-upcasting  2024-08-19 09:13:18 +05:30
Sayak Paul  0d1a1f875a  Merge branch 'main' into layerwise-upcasting  2024-08-16 14:15:15 +05:30
Sayak Paul  f1fa1235e4  Merge branch 'main' into layerwise-upcasting  2024-08-16 09:48:53 +05:30
Sayak Paul  9b411e5ff3  Merge branch 'main' into layerwise-upcasting  2024-08-15 10:34:40 +05:30
Dhruv Nair  b366b22191  update  2024-08-14 14:50:18 +02:00
Dhruv Nair  1fdae85f49  update  2024-08-14 14:19:20 +02:00
Dhruv Nair  6b9fd0905e  update  2024-08-14 08:22:43 +02:00
Dhruv Nair  be55fa631f  update  2024-08-13 14:11:47 +02:00
20 changed files with 154 additions and 12 deletions

View File

@@ -449,7 +449,7 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
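This hunk is the first instance of a pattern repeated throughout this compare: parameters and buffers that live directly on a parent module (here `scale_shift_table`) are not covered by the per-leaf upcasting hooks added to `ModelMixin` below, so they are cast to the activation dtype at the point of use. A minimal sketch of the situation, using a hypothetical `Toy` module that is not part of this PR:

```
import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)  # leaf module: upcast by a forward pre-hook
        self.table = nn.Parameter(torch.zeros(8))  # parent-level parameter: never hooked

    def forward(self, x):
        # Under layerwise upcasting, `x` arrives in the upcast dtype while `self.table`
        # stays in the storage dtype (e.g. torch.float8_e4m3fn), so cast it at the point of use.
        return self.proj(x) + self.table.to(x.dtype)
```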

View File

@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
_always_upcast_modules = ["MaskConditionDecoder"]
@register_to_config
def __init__(
self,

View File

@@ -70,6 +70,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_always_upcast_modules = ["Decoder"]
@register_to_config
def __init__(

View File

@@ -192,6 +192,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["TemporalDecoder"]
@register_to_config
def __init__(

View File

@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = False
_always_upcast_modules = ["OobleckEncoder", "OobleckDecoder"]
@register_to_config
def __init__(

View File

@@ -330,7 +330,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
"""
-z = (z * self.config.scaling_factor - self.means) / self.stds
+z = (z * self.config.scaling_factor - self.means.to(z.dtype)) / self.stds.to(z.dtype)
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)

View File

@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
_always_upcast_modules = ["Decoder", "VectorQuantizer"]
@register_to_config
def __init__(
self,

View File

@@ -263,6 +263,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
+def enable_layerwise_upcasting(self, upcast_dtype=None):
+    r"""
+    Enable layerwise dynamic upcasting. This allows models to be loaded onto the GPU in a low-memory dtype,
+    e.g. `torch.float8_e4m3fn`, while running inference in a dtype supported by the GPU, by upcasting the
+    individual modules in the model to the appropriate dtype right before the forward pass. The module is
+    then moved back to the low-memory dtype after the forward pass.
+    """
+    upcast_dtype = upcast_dtype or torch.float32
+    original_dtype = self.dtype
+
+    def upcast_dtype_hook_fn(module, *args, **kwargs):
+        module = module.to(upcast_dtype)
+
+    def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
+        module = module.to(original_dtype)
+
+    def fn_recursive_upcast(module):
+        """In certain cases modules will apply casting internally or reference the dtype of internal blocks,
+        e.g.
+
+        ```
+        class MyModel(nn.Module):
+            def forward(self, x):
+                dtype = next(iter(self.blocks.parameters())).dtype
+                x = self.blocks(x) + torch.ones(x.size()).to(dtype)
+        ```
+
+        Layerwise upcasting will not work here, since the internal blocks remain in the low-memory dtype until
+        their `forward` method is called. We need to add the upcast hook on the entire module in order for the
+        operation to work.
+
+        The `_always_upcast_modules` class attribute is a list of modules within the model that we must upcast
+        entirely, rather than layerwise.
+        """
+        if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+            # Upcast the entire module and exit the recursion
+            module.register_forward_pre_hook(upcast_dtype_hook_fn)
+            module.register_forward_hook(cast_to_original_dtype_hook_fn)
+            return
+
+        has_children = list(module.children())
+        if not has_children:
+            module.register_forward_pre_hook(upcast_dtype_hook_fn)
+            module.register_forward_hook(cast_to_original_dtype_hook_fn)
+
+        for child in module.children():
+            fn_recursive_upcast(child)
+
+    for module in self.children():
+        fn_recursive_upcast(module)
+
+def disable_layerwise_upcasting(self):
+    def fn_recursive_upcast(module):
+        if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:
+            module._forward_pre_hooks = OrderedDict()
+            module._forward_hooks = OrderedDict()
+            return
+
+        has_children = list(module.children())
+        if not has_children:
+            module._forward_pre_hooks = OrderedDict()
+            module._forward_hooks = OrderedDict()
+
+        for child in module.children():
+            fn_recursive_upcast(child)
+
+    for module in self.children():
+        fn_recursive_upcast(module)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
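The `ModelMixin` hunk above is the core of this PR: forward pre-hooks upcast each leaf module (or each module named in `_always_upcast_modules`) right before its forward pass, and forward hooks cast it back to the storage dtype afterwards. A minimal usage sketch; the model class, checkpoint id, and dtypes below are illustrative assumptions, not part of this diff:

```
import torch
from diffusers import UNet2DConditionModel

# Keep the weights resident on the GPU in a low-memory storage dtype.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet = unet.to(torch.float8_e4m3fn).to("cuda")

# Leaf modules (and any module listed in `_always_upcast_modules`) are upcast to
# `upcast_dtype` just before their forward pass and cast back to float8 afterwards.
unet.enable_layerwise_upcasting(upcast_dtype=torch.float16)

# ... run inference ...

# Remove the hooks to restore normal execution.
unet.disable_layerwise_upcasting()
```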

View File

@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
_supports_gradient_checkpointing = True
_always_upcast_modules = ["AuraFlowPatchEmbed"]
@register_to_config
def __init__(
@@ -457,11 +458,15 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
# Apply patch embedding, timestep embedding, and project the caption embeddings.
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
-temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
+temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
temb = self.time_step_proj(temb)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
encoder_hidden_states = torch.cat(
-[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
+[
+    self.register_tokens.to(encoder_hidden_states.dtype).repeat(encoder_hidden_states.size(0), 1, 1),
+    encoder_hidden_states,
+],
+dim=1,
)
# MMDiT blocks.

View File

@@ -65,6 +65,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(

View File

@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""
_always_upcast_modules = ["HunyuanDiTAttentionPool"]
@register_to_config
def __init__(
self,
@@ -484,7 +486,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
-encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+encoder_hidden_states = torch.where(
+    text_embedding_mask, encoder_hidden_states, self.text_embedding_padding.to(encoder_hidden_states.dtype)
+)
skips = []
for layer, block in enumerate(self.blocks):

View File

@@ -64,6 +64,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(
@@ -301,7 +302,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
-shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+shift, scale = (self.scale_shift_table[None].to(embedded_timestep.dtype) + embedded_timestep[:, None]).chunk(
+    2, dim=1
+)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift

View File

@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(
@@ -422,7 +423,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
shift, scale = (
-self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+self.scale_shift_table[None].to(embedded_timestep.dtype)
+    + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation

View File

@@ -289,7 +289,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
-timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+timesteps_projected = timesteps_projected.to(dtype=hidden_states.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
if self.embedding_proj_norm is not None:

View File

@@ -54,6 +54,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["PatchEmbed"]
@register_to_config
def __init__(

View File

@@ -283,7 +283,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
-t_emb = t_emb.to(dtype=self.dtype)
+t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:

View File

@@ -641,7 +641,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
-t_emb = t_emb.to(dtype=self.dtype)
+t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)

View File

@@ -590,7 +590,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
-t_emb = t_emb.to(dtype=self.dtype)
+t_emb = t_emb.to(dtype=sample.dtype)
t_emb = self.time_embedding(t_emb, timestep_cond)
# 2. FPS

View File

@@ -2152,7 +2152,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
-t_emb = t_emb.to(dtype=self.dtype)
+t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

View File

@@ -43,6 +43,8 @@ from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_
from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import (
CaptureLogger,
+disable_full_determinism,
+enable_full_determinism,
get_python_version,
is_torch_compile,
require_torch_2,
@@ -984,6 +986,49 @@ class ModelTesterMixin:
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+@require_torch_gpu
+def test_layerwise_upcasting(self):
+    disable_full_determinism()
+
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_cached()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.reset_peak_memory_stats()
+
+    torch.manual_seed(0)
+    config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+    model = self.model_class(**config).eval()
+    model.to(torch_device)
+    model(**inputs_dict)
+    base_max_memory = torch.cuda.max_memory_allocated()
+
+    # Remove model
+    model.to("cpu")
+    del model
+
+    torch.cuda.empty_cache()
+    torch.cuda.reset_max_memory_cached()
+    torch.cuda.reset_max_memory_allocated()
+    torch.cuda.reset_peak_memory_stats()
+
+    low_memory_dtype = torch.float8_e4m3fn
+    upcast_dtype = torch.float32
+
+    config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+    torch.manual_seed(0)
+    low_mem_model = self.model_class(**config).eval()
+    low_mem_model.to(low_memory_dtype)
+    low_mem_model.to(torch_device)
+    layerwise_max_memory = torch.cuda.max_memory_allocated()
+
+    low_mem_model.enable_layerwise_upcasting(upcast_dtype)
+    low_mem_model(**inputs_dict)
+
+    assert layerwise_max_memory < base_max_memory
+
+    enable_full_determinism()
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):