Compare commits

..

7 Commits

Author SHA1 Message Date
Steven Liu
511056a4e3 [docs] Pipeline group offloading (#12286)
init

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-09-10 08:26:36 +05:30
sayakpaul
8e2d0383e1 up 2025-09-10 08:12:16 +05:30
Sayak Paul
aa0cafbc5e Merge branch 'main' into support-group-offloading-pipeline-level 2025-09-10 08:01:14 +05:30
Sayak Paul
1a8ebf6513 Merge branch 'main' into support-group-offloading-pipeline-level 2025-09-08 12:58:42 +05:30
Sayak Paul
506424c09b Merge branch 'main' into support-group-offloading-pipeline-level 2025-09-05 06:49:14 +05:30
sayakpaul
e141f5cfd0 add tests 2025-09-04 11:49:54 +05:30
sayakpaul
25d9c70d1c feat: support group offloading at the pipeline level. 2025-09-04 11:11:11 +05:30
10 changed files with 247 additions and 1356 deletions

View File

@@ -254,8 +254,8 @@ export_to_video(video, "output.mp4", fps=24)
pipeline.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % pipeline.vae_spatial_compression_ratio)
width = width - (width % pipeline.vae_spatial_compression_ratio)
height = height - (height % pipeline.vae_temporal_compression_ratio)
width = width - (width % pipeline.vae_temporal_compression_ratio)
return height, width
prompt = """
@@ -325,95 +325,6 @@ export_to_video(video, "output.mp4", fps=24)
</details>
- LTX-Video 0.9.8 distilled model is similar to the 0.9.7 variant. It is guidance and timestep-distilled, and similar inference code can be used as above. An improvement of this version is that it supports generating very long videos. Additionally, it supports using tone mapping to improve the quality of the generated video using the `tone_map_compression_ratio` parameter. The default value of `0.6` is recommended.
<details>
<summary>Show example code</summary>
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
from diffusers.utils import export_to_video, load_video
pipeline = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.8-13B-distilled", torch_dtype=torch.bfloat16)
# TODO: Update the checkpoint here once updated in LTX org
upsampler = LTXLatentUpsamplerModel.from_pretrained("a-r-r-o-w/LTX-0.9.8-Latent-Upsampler", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline(vae=pipeline.vae, latent_upsampler=upsampler).to(torch.bfloat16)
pipeline.to("cuda")
pipe_upsample.to("cuda")
pipeline.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % pipeline.vae_spatial_compression_ratio)
width = width - (width % pipeline.vae_spatial_compression_ratio)
return height, width
prompt = """The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature."""
# prompt = """A woman walks away from a white Jeep parked on a city street at night, then ascends a staircase and knocks on a door. The woman, wearing a dark jacket and jeans, walks away from the Jeep parked on the left side of the street, her back to the camera; she walks at a steady pace, her arms swinging slightly by her sides; the street is dimly lit, with streetlights casting pools of light on the wet pavement; a man in a dark jacket and jeans walks past the Jeep in the opposite direction; the camera follows the woman from behind as she walks up a set of stairs towards a building with a green door; she reaches the top of the stairs and turns left, continuing to walk towards the building; she reaches the door and knocks on it with her right hand; the camera remains stationary, focused on the doorway; the scene is captured in real-life footage."""
negative_prompt = "bright colors, symbols, graffiti, watermarks, worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 480, 832
downscale_factor = 2 / 3
# num_frames = 161
num_frames = 361
# 1. Generate video at smaller resolution
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
latents = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
decode_timestep=0.05,
decode_noise_scale=0.025,
image_cond_noise_scale=0.0,
guidance_scale=1.0,
guidance_rescale=0.7,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
# 2. Upscale generated video using latent upsampler with fewer inference steps
# The available latent upsampler upscales the height/width by 2x
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
upscaled_latents = pipe_upsample(
latents=latents,
adain_factor=1.0,
tone_map_compression_ratio=0.6,
output_type="latent"
).frames
# 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
video = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=upscaled_width,
height=upscaled_height,
num_frames=num_frames,
denoise_strength=0.999, # Effectively, 4 inference steps out of 5
timesteps=[1000, 909, 725, 421, 0],
latents=upscaled_latents,
decode_timestep=0.05,
decode_noise_scale=0.025,
image_cond_noise_scale=0.0,
guidance_scale=1.0,
guidance_rescale=0.7,
generator=torch.Generator().manual_seed(0),
output_type="pil",
).frames[0]
# 4. Downscale the video to the expected resolution
video = [frame.resize((expected_width, expected_height)) for frame in video]
export_to_video(video, "output.mp4", fps=24)
```
</details>
- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].
<details>

View File

@@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
> [!WARNING]
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.
Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
The `offload_type` parameter can be set to `block_level` or `leaf_level`.
Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`.
- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.
- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed.
Group offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models.
<hfoptions id="group-offloading">
<hfoption id="pipeline">
Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline.
```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True
)
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```
</hfoption>
<hfoption id="model">
Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
```py
import torch
from diffusers import CogVideoXPipeline
@@ -328,6 +368,9 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
export_to_video(video, "output.mp4", fps=8)
```
</hfoption>
</hfoptions>
#### CUDA stream
The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.

View File

@@ -369,15 +369,6 @@ def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
"spatial_upsample": True,
"temporal_upsample": False,
}
elif version == "0.9.8":
config = {
"in_channels": 128,
"mid_channels": 512,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
@@ -411,7 +402,7 @@ def get_args():
"--version",
type=str,
default="0.9.0",
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7", "0.9.8"],
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
help="Version of the LTX model",
)
return parser.parse_args()

View File

@@ -491,7 +491,6 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionInfinitePipeline",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
@@ -1146,7 +1145,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionInfinitePipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,

View File

@@ -281,7 +281,6 @@ else:
"LTXPipeline",
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXConditionInfinitePipeline",
"LTXLatentUpsamplePipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
@@ -682,13 +681,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import (
LTXConditionInfinitePipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (

View File

@@ -25,7 +25,6 @@ else:
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_condition_infinite"] = ["LTXConditionInfinitePipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
@@ -40,7 +39,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_condition_infinite import LTXConditionInfinitePipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline

File diff suppressed because it is too large Load Diff

View File

@@ -93,8 +93,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
@staticmethod
def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
"""
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
tensor.
@@ -122,39 +121,6 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
result = torch.lerp(latents, result, factor)
return result
@staticmethod
def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
"""
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
smooth way using a sigmoid-based compression.
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
when controlling dynamic behavior with a `compression` factor.
Args:
latents : torch.Tensor
Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
compression : float
Compression strength in the range [0, 1].
- 0.0: No tone-mapping (identity transform)
- 1.0: Full compression effect
Returns:
torch.Tensor
The tone-mapped latent tensor of the same shape as input.
"""
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
scale_factor = compression * 0.75
abs_latents = torch.abs(latents)
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
filtered = latents * scales
return filtered
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
@@ -206,7 +172,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
"""
self.vae.disable_tiling()
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
def check_inputs(self, video, height, width, latents):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -215,9 +181,6 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
if not (0 <= tone_map_compression_ratio <= 1):
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
@torch.no_grad()
def __call__(
self,
@@ -228,7 +191,6 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
adain_factor: float = 0.0,
tone_map_compression_ratio: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -238,7 +200,6 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
height=height,
width=width,
latents=latents,
tone_map_compression_ratio=tone_map_compression_ratio,
)
if video is not None:
@@ -281,9 +242,6 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
else:
latents = latents_upsampled
if tone_map_compression_ratio > 0.0:
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor

View File

@@ -1334,6 +1334,133 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
def enable_group_offload(
self,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
exclude_modules: Optional[Union[str, List[str]]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
and where it is beneficial, we need to first provide some context on how other supported offloading methods
work.
Typically, offloading is done at two levels:
- Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator
device when needed for computation. This method is more memory-efficient than keeping all components on the
accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
equivalent to size of the model in runtime dtype + size of largest intermediate activation tensors to be able
to complete the forward pass.
- Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method.
It
works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
memory, but can be slower due to the excessive number of device synchronizations.
Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
is reduced.
Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
to overlap data transfer and computation to reduce the overall execution time compared to sequential
offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
starts onloading to the accelerator device while the current layer is being executed - this increases the
memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made
much faster when using streams.
Args:
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is
CPU.
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in
limited RAM environment settings where a reasonable speed-memory trade-off is desired.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
non_blocking (`bool`, defaults to `False`):
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
more details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
Example:
```python
>>> from diffusers import DiffusionPipeline
>>> import torch
>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
>>> pipe.enable_group_offload(
... onload_device=torch.device("cuda"),
... offload_device=torch.device("cpu"),
... offload_type="leaf_level",
... use_stream=True,
... )
>>> image = pipe("a beautiful sunset").images[0]
```
"""
from ..hooks import apply_group_offloading
if isinstance(exclude_modules, str):
exclude_modules = [exclude_modules]
elif exclude_modules is None:
exclude_modules = []
unknown = set(exclude_modules) - self.components.keys()
if unknown:
logger.info(
f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
)
group_offload_kwargs = {
"onload_device": onload_device,
"offload_device": offload_device,
"offload_type": offload_type,
"num_blocks_per_group": num_blocks_per_group,
"non_blocking": non_blocking,
"use_stream": use_stream,
"record_stream": record_stream,
"low_cpu_mem_usage": low_cpu_mem_usage,
"offload_to_disk_path": offload_to_disk_path,
}
for name, component in self.components.items():
if name not in exclude_modules and isinstance(component, torch.nn.Module):
if hasattr(component, "enable_group_offload"):
component.enable_group_offload(**group_offload_kwargs)
else:
apply_group_offloading(module=component, **group_offload_kwargs)
if exclude_modules:
for module_name in exclude_modules:
module = getattr(self, module_name, None)
if module is not None and isinstance(module, torch.nn.Module):
module.to(onload_device)
logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
def reset_device_map(self):
r"""
Resets the device maps (if any) to None.

View File

@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Union
import numpy as np
import PIL.Image
import pytest
import torch
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
@@ -2362,6 +2363,73 @@ class PipelineTesterMixin:
max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
self.assertLess(max_diff, expected_max_difference)
@require_torch_accelerator
def test_pipeline_level_group_offloading_sanity_checks(self):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)
for name, component in pipe.components.items():
if hasattr(component, "_supports_group_offloading"):
if not component._supports_group_offloading:
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
module_names = sorted(
[name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
)
exclude_module_name = module_names[0]
offload_device = "cpu"
pipe.enable_group_offload(
onload_device=torch_device,
offload_device=offload_device,
offload_type="leaf_level",
exclude_modules=exclude_module_name,
)
excluded_module = getattr(pipe, exclude_module_name)
self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
for name, component in pipe.components.items():
if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
# `component.device` prints the `onload_device` type. We should probably override the
# `device` property in `ModelMixin`.
component_device = next(component.parameters())[0].device
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
@require_torch_accelerator
def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)
for name, component in pipe.components.items():
if hasattr(component, "_supports_group_offloading"):
if not component._supports_group_offloading:
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
# Regular inference.
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = torch.manual_seed(0)
out = pipe(**inputs)[0]
pipe.to("cpu")
del pipe
# Inference with offloading
pipe: DiffusionPipeline = self.pipeline_class(**components)
offload_device = "cpu"
pipe.enable_group_offload(
onload_device=torch_device,
offload_device=offload_device,
offload_type="leaf_level",
)
pipe.set_progress_bar_config(disable=None)
inputs["generator"] = torch.manual_seed(0)
out_offload = pipe(**inputs)[0]
max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
self.assertLess(max_diff, expected_max_difference)
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):