mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00

Compare commits: 5 commits (8d415a6f48 ... 256e010674)

- 256e010674
- 8430ac2a2f
- bb9e713d02
- c98c157a9e
- f12d161d67
@@ -401,6 +401,8 @@
title: WanAnimateTransformer3DModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- local: api/models/z_image_transformer2d
title: ZImageTransformer2DModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
@@ -551,6 +553,8 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5_image
title: Kandinsky 5.0 Image
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
@@ -646,6 +650,8 @@
title: VisualCloze
- local: api/pipelines/wuerstchen
title: Wuerstchen
- local: api/pipelines/z_image
title: Z-Image
title: Image
- sections:
- local: api/pipelines/allegro
@@ -664,8 +670,6 @@
title: HunyuanVideo1.5
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/kandinsky5_image
title: Kandinsky 5.0 Image
- local: api/pipelines/kandinsky5_video
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
docs/source/en/api/models/z_image_transformer2d.md (new file, 19 lines)
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ZImageTransformer2DModel

A Transformer model for image-like data from [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo).

## ZImageTransformer2DModel

[[autodoc]] ZImageTransformer2DModel
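A minimal loading sketch (not part of the diff): it assumes the Z-Image-Turbo checkpoint stores the transformer in a `transformer` subfolder, as is conventional for Diffusers-format pipelines.

```python
# Hedged sketch: load the transformer on its own. The subfolder name and dtype
# are assumptions based on the usual Diffusers checkpoint layout, not values
# taken from the diff above.
import torch
from diffusers import ZImageTransformer2DModel

transformer = ZImageTransformer2DModel.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(sum(p.numel() for p in transformer.parameters()) / 1e9, "B parameters")
```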
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License.

[Kandinsky 5.0](https://arxiv.org/abs/2511.14993) is a family of diffusion models for Video & Image generation.

Kandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters)
Kandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters).

The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
@@ -21,10 +21,14 @@ The model introduces several key innovations:

The original codebase can be found at [kandinskylab/Kandinsky-5](https://github.com/kandinskylab/Kandinsky-5).

> [!TIP]
> Check out the [Kandinsky Lab](https://huggingface.co/kandinskylab) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.


## Available Models

Kandinsky 5.0 Image Lite:

| model_id | Description | Use Cases |
|------------|-------------|-----------|
| [**kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers) | 6B image Supervised Fine-Tuned model | Highest generation quality |

@@ -30,6 +30,7 @@ The original codebase can be found at [kandinskylab/Kandinsky-5](https://github.
## Available Models

Kandinsky 5.0 T2V Pro:

| model_id | Description | Use Cases |
|------------|-------------|-----------|
| **kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers** | 5 second Text-to-Video Pro model | High-quality text-to-video generation |
docs/source/en/api/pipelines/z_image.md (new file, 33 lines)
@@ -0,0 +1,33 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Z-Image

<div class="flex flex-wrap space-x-1">
  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

[Z-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently only one model is available, with two more to be released:

|Model|Hugging Face|
|---|---|
|Z-Image-Turbo|https://huggingface.co/Tongyi-MAI/Z-Image-Turbo|

## Z-Image-Turbo

Z-Image-Turbo is a distilled version of Z-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16 GB of VRAM on consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.

## ZImagePipeline

[[autodoc]] ZImagePipeline
- all
- __call__
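A minimal usage sketch (not part of the diff): it assumes `ZImagePipeline` follows the standard Diffusers text-to-image calling convention; the prompt, dtype, and step count are illustrative.

```python
# Hedged sketch: standard Diffusers text-to-image usage. The step count reflects
# the ~8 NFEs mentioned above; the prompt and dtype are illustrative assumptions.
import torch
from diffusers import ZImagePipeline

pipe = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    "a photorealistic red fox in fresh snow, golden hour",
    num_inference_steps=8,
).images[0]
image.save("z_image_turbo.png")
```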
@@ -15,7 +15,7 @@
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union

@@ -59,6 +59,9 @@ class GroupOffloadingConfig:
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None
exclude_kwargs: Optional[List[str]] = None
module_prefix: Optional[str] = ""


class ModuleGroup:
@@ -77,7 +80,7 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -322,7 +325,21 @@ class GroupOffloadingHook(ModelHook):
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

# Some Autoencoder models use a feature cache that is passed through submodules
# and modified in place. The `send_to_device` call returns a copy of this feature cache object,
# which breaks the in-place updates. Use `exclude_kwargs` to mark these cache entries so they are left untouched.
exclude_kwargs = self.config.exclude_kwargs or []
if exclude_kwargs:
moved_kwargs = send_to_device(
{k: v for k, v in kwargs.items() if k not in exclude_kwargs},
self.group.onload_device,
non_blocking=self.group.non_blocking,
)
kwargs.update(moved_kwargs)
else:
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)

return args, kwargs

def post_forward(self, module: torch.nn.Module, output):
@@ -455,6 +472,8 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -512,6 +531,13 @@
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. If provided, only these modules will
be considered for block-level offloading. If not provided, the default block detection logic will be used.
exclude_kwargs (`List[str]`, *optional*):
List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
caching lists that need to maintain their object identity across forward passes. If not provided, will be
inferred from the module's `_skip_keys` attribute if it exists.

Example:
```python
@@ -553,6 +579,12 @@

_raise_error_if_accelerate_model_or_sequential_hook_present(module)

if block_modules is None:
block_modules = getattr(module, "_group_offload_block_modules", None)

if exclude_kwargs is None:
exclude_kwargs = getattr(module, "_skip_keys", None)

config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
@@ -563,6 +595,8 @@
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)
_apply_group_offloading(module, config)
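For orientation, a minimal sketch (not part of the diff) of how a caller could pass the two new arguments explicitly; models that define `_group_offload_block_modules` / `_skip_keys` get them inferred automatically by the `getattr` fallbacks above. The module names and kwarg keys below are illustrative.

```python
# Hedged sketch: illustrative call with the new block_modules / exclude_kwargs
# arguments. The tiny AutoencoderKL below uses default config values purely for
# demonstration; the names passed in are examples, not values mandated by the diff.
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKL()  # toy, randomly initialized instance
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["encoder", "decoder"],        # treat these submodules as offload blocks
    exclude_kwargs=["feat_cache", "feat_idx"],   # keep mutable caches out of send_to_device
)
```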
@@ -578,46 +612,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf

def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is
done at the top-level blocks and modules specified in block_modules.

When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
module, recursively apply block offloading to it.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1

# Create module groups for ModuleList and Sequential blocks
block_modules = set(config.block_modules) if config.block_modules is not None else set()

# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)
continue

for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
for name, submodule in module.named_children():
# Check if this is an explicitly defined block module
if name in block_modules:
# Track submodule using a prefix to avoid filename collisions during disk offload.
# Without this, submodules sharing the same model class would be assigned identical
# filenames (derived from the class name).
prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
submodule_config = replace(config, module_prefix=prefix)

_apply_group_offloading_block_level(submodule, submodule_config)
modules_with_group_offloading.add(name)

elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
@@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -74,6 +74,7 @@ class AutoencoderKL(

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]

@register_to_config
def __init__(
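With `_group_offload_block_modules` declared on the class, block-level group offloading can treat the encoder/decoder as blocks without the caller passing `block_modules`; a hedged sketch (checkpoint id and dtype are illustrative):

```python
# Hedged sketch: the class-level _group_offload_block_modules above is picked up
# automatically by enable_group_offload; checkpoint name and dtype are illustrative.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
)
```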
@@ -619,6 +619,7 @@ class WanEncoder3d(nn.Module):
feat_idx[0] += 1
else:
x = self.conv_out(x)

return x


@@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""

_supports_gradient_checkpointing = False
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
# these are shared mutable state modified in-place
_skip_keys = ["feat_cache", "feat_idx"]
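Because the Wan VAE now declares both `_group_offload_block_modules` and `_skip_keys`, group offloading infers `block_modules` and `exclude_kwargs` for it automatically; a hedged sketch (the repo id is illustrative):

```python
# Hedged sketch: exclude_kwargs (["feat_cache", "feat_idx"]) and the block modules
# are inferred from the class attributes above; the repo id below is illustrative.
import torch
from diffusers import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
```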
@@ -1414,6 +1416,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
"""
x = sample
posterior = self.encode(x).latent_dist

if sample_posterior:
z = posterior.sample(generator=generator)
else:
@@ -531,6 +531,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[str] = None,
exclude_kwargs: Optional[str] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -570,6 +572,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)

apply_group_offloading(
module=self,
onload_device=onload_device,
@@ -581,6 +584,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
exclude_kwargs=exclude_kwargs,
)

def set_attention_backend(self, backend: str) -> None:
@@ -84,33 +84,35 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.

Args:
num_train_timesteps (`int`, defaults to 1000):
num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
beta_end (`float`, defaults to `0.02`):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, defaults to 2):
solver_order (`int`, defaults to `2`):
The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`):
prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
`sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
dynamic_thresholding_ratio (`float`, defaults to `0.995`):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
sample_max_value (`float`, defaults to `1.0`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
algorithm_type (`str`, defaults to `deis`):
algorithm_type (`"deis"`, defaults to `"deis"`):
The algorithm type for the solver.
solver_type (`"logrho"`, defaults to `"logrho"`):
Solver type for DEIS.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -121,11 +123,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
timestep_spacing (`str`, defaults to `"linspace"`):
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to `1.0`):
The flow shift parameter for flow-based models.
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
use_dynamic_shifting (`bool`, defaults to `False`):
Whether to use dynamic shifting for the noise schedule.
time_shift_type (`"exponential"`, defaults to `"exponential"`):
The type of time shifting to apply.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -137,29 +147,38 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "deis",
solver_type: str = "logrho",
algorithm_type: Literal["deis"] = "deis",
solver_type: Literal["logrho"] = "logrho",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace",
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
time_shift_type: Literal["exponential"] = "exponential",
) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -169,7 +188,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -211,21 +238,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

@property
def step_index(self):
def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index

@property
def begin_index(self):
def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

@@ -236,8 +263,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index

def set_timesteps(
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
):
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -246,6 +276,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
mu (`float`, *optional*):
The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and
`time_shift_type="exponential"`.
"""
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
@@ -363,7 +396,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return sample

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.

@@ -400,7 +433,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return t

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert sigma values to alpha_t and sigma_t values.

@@ -422,7 +455,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -648,7 +681,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)

sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -714,7 +750,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

m0, m1 = model_output_list[-1], model_output_list[-2]

rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
rho_t, rho_s0, rho_s1 = (
sigma_t / alpha_t,
sigma_s0 / alpha_s0,
sigma_s1 / alpha_s1,
)

if self.config.algorithm_type == "deis":
@@ -854,7 +894,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.

@@ -884,18 +924,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.

"""
if self.num_inference_steps is None:
raise ValueError(
@@ -1000,5 +1039,5 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples

def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
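A hedged construction sketch (not part of the diff) exercising the newly documented flow-matching options; the concrete values are illustrative defaults, not taken from a specific pipeline.

```python
# Hedged sketch: instantiate the scheduler with the flow-sigma options that the
# updated docstring documents; the values here are illustrative.
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(
    num_train_timesteps=1000,
    beta_schedule="linear",
    solver_order=2,
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    flow_shift=3.0,
    timestep_spacing="linspace",
)
scheduler.set_timesteps(num_inference_steps=20)
print(scheduler.timesteps[:5])
```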
@@ -19,6 +19,7 @@ import unittest
import torch
from parameterized import parameterized

from diffusers import AutoencoderKL
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
@@ -149,6 +150,74 @@ class LayerOutputTrackerHook(ModelHook):
return output


# Model with only standalone computational layers at top level
class DummyModelWithStandaloneLayers(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()

self.layer1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.layer2 = torch.nn.Linear(hidden_features, hidden_features)
self.layer3 = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
x = self.layer3(x)
return x


# Model with deeply nested structure
class DummyModelWithDeeplyNestedBlocks(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()

self.input_layer = torch.nn.Linear(in_features, hidden_features)
self.container = ContainerWithNestedModuleList(hidden_features)
self.output_layer = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.input_layer(x)
x = self.container(x)
x = self.output_layer(x)
return x
class ContainerWithNestedModuleList(torch.nn.Module):
def __init__(self, features: int) -> None:
super().__init__()

# Top-level computational layer
self.proj_in = torch.nn.Linear(features, features)

# Nested container with ModuleList
self.nested_container = NestedContainer(features)

# Another top-level computational layer
self.proj_out = torch.nn.Linear(features, features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj_in(x)
x = self.nested_container(x)
x = self.proj_out(x)
return x


class NestedContainer(torch.nn.Module):
def __init__(self, features: int) -> None:
super().__init__()

self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)])
self.norm = torch.nn.LayerNorm(features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
for block in self.blocks:
x = block(x)
x = self.norm(x)
return x


@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -340,7 +409,7 @@ class GroupOffloadTests(unittest.TestCase):
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")

num_repeats = 4
num_repeats = 2
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
@@ -362,3 +431,138 @@ class GroupOffloadTests(unittest.TestCase):
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)

def test_vae_like_model_without_streams(self):
"""Test VAE-like model with block-level offloading but without streams."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

config = self.get_autoencoder_kl_config()
model = AutoencoderKL(**config)

model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)

x = torch.randn(2, 3, 32, 32).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x).sample
out = model(x).sample

self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
)

def test_model_with_only_standalone_layers(self):
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for model with standalone layers.",
)
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

config = self.get_autoencoder_kl_config()
model = AutoencoderKL(**config)

model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 3, 32, 32).to(torch_device)

with torch.no_grad():
out_ref = model_ref(x).sample
out = model(x).sample

self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match for standalone Conv layers with {offload_type}.",
)

def test_multiple_invocations_with_vae_like_model(self):
"""Test that multiple forward passes work correctly with VAE-like model."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

config = self.get_autoencoder_kl_config()
model = AutoencoderKL(**config)

model_ref = AutoencoderKL(**config)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 3, 32, 32).to(torch_device)

with torch.no_grad():
for i in range(2):
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
def test_nested_container_parameters_offloading(self):
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)

model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
model_ref.load_state_dict(model.state_dict(), strict=True)
model_ref.to(torch_device)

model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)

x = torch.randn(2, 64).to(torch_device)

with torch.no_grad():
for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for nested parameters.",
)

def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
block_out_channels = block_out_channels or [2, 4]
norm_num_groups = norm_num_groups or 2
init_dict = {
"block_out_channels": block_out_channels,
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
"up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
"latent_channels": 4,
"norm_num_groups": norm_num_groups,
"layers_per_block": 1,
}
return init_dict
@@ -1791,7 +1791,6 @@ class ModelTesterMixin:
return model(**inputs_dict)[0]

model = self.model_class(**init_dict)

model.to(torch_device)
output_without_group_offloading = run_forward(model)
output_without_group_offloading = normalize_output(output_without_group_offloading)
@@ -1916,6 +1915,9 @@ class ModelTesterMixin:
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
block_modules=model._group_offload_block_modules
if hasattr(model, "_group_offload_block_modules")
else None,
)
if not is_correct:
if extra_files:
@@ -1424,6 +1424,8 @@ if is_torch_available():
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
block_modules: Optional[List[str]] = None,
module_prefix: str = "",
) -> Set[str]:
expected_files = set()

@@ -1435,23 +1437,36 @@ if is_torch_available():
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")

# Handle groups of ModuleList and Sequential blocks
block_modules_set = set(block_modules) if block_modules is not None else set()

modules_with_group_offloading = set()
unmatched_modules = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append(module)
continue
if name in block_modules_set:
new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
submodule_files = _get_expected_safetensors_files(
submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
)
expected_files.update(submodule_files)
modules_with_group_offloading.add(name)

for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
expected_files.add(get_hashed_filename(group_id))
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
expected_files.add(get_hashed_filename(group_id))
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
unmatched_modules.append(submodule)

# Handle the group for unmatched top-level modules and parameters
for module in unmatched_modules:
expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))

elif offload_type == "leaf_level":
# Handle leaf-level module groups
@@ -1492,12 +1507,13 @@ if is_torch_available():
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
block_modules: Optional[List[str]] = None,
) -> bool:
if not os.path.isdir(offload_to_disk_path):
return False, None, None

expected_files = _get_expected_safetensors_files(
module, offload_to_disk_path, offload_type, num_blocks_per_group
module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules
)
actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
missing_files = expected_files - actual_files