Compare commits


7 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Dhruv Nair | 211ddd7f9e | update | 2025-03-18 13:14:21 +01:00 |
| Dhruv Nair | dd2135769e | update | 2025-03-18 12:56:07 +01:00 |
| Dhruv Nair | 60bcc74f28 | update | 2025-03-18 11:56:11 +01:00 |
| Dhruv Nair | 60f468a926 | Merge branch 'main' into pinned-context | 2025-03-18 11:52:43 +01:00 |
| Dhruv Nair | aceff93e28 | update | 2025-03-18 11:34:45 +01:00 |
| Dhruv Nair | d0d81fbdeb | update | 2025-03-17 20:37:35 +01:00 |
| Dhruv Nair | e793adc465 | update | 2025-03-17 15:22:49 +01:00 |
3 changed files with 439 additions and 63 deletions

View File

@@ -22,18 +22,357 @@
<!-- TODO(aryan): update abstract once paper is out -->
## Generating Videos with Wan 2.1
We will first need to install some additional dependencies.
```shell
pip install -U ftfy imageio-ffmpeg imageio
```
### Text to Video Generation
The following example requires 11GB of VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out for the larger `Wan-AI/Wan2.1-T2V-14B-Diffusers` model if you have at least 35GB of VRAM available.

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

# Available models: Wan-AI/Wan2.1-T2V-1.3B-Diffusers, Wan-AI/Wan2.1-T2V-14B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>

You can improve the quality of the generated video by running the decoding step in full precision:

```python
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
### Image to Video Generation
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
35GB of VRAM to run.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
## Memory Optimizations for Wan 2.1
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
We'll use the `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
### Group Offloading the Transformer and UMT5 Text Encoder
Find more information about group offloading [here](../optimization/memory.md).
#### Block Level Group Offloading
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline: the `WanTransformer3DModel` and the `UMT5EncoderModel`. Group offloading breaks up the individual modules of a model and offloads/onloads them to and from your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which groups the modules of a model into blocks of size `num_blocks_per_group` and offloads/onloads them as a unit. Moving weights between the CPU and GPU adds latency to the inference process; you can trade off latency against memory savings by increasing or decreasing `num_blocks_per_group`.
The following example only requires about 14GB of VRAM to run, but takes approximately 30 minutes to generate a video.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4,
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've offloaded the larger models already, we can move the rest of the model components to the GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
#### Block Level Group Offloading with CUDA Streams
We can speed up group offloading inference by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by PyTorch under the hood and can result in a significant spike in CPU RAM usage. Please consider this option only if your CPU RAM is at least 2x the size of the model you are group offloading.
In the following example, we use CUDA streams when group offloading the `WanTransformer3DModel`. When tested on an A100, this example requires 14GB of VRAM and 52GB of CPU RAM, but generates a video in approximately 9 minutes.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've offloaded the larger models already, we can move the rest of the model components to the GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Applying Layerwise Casting to the Transformer
Find more information about layerwise casting [here](../optimization/memory.md).
In this example, we will combine model CPU offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast them to `torch.bfloat16` during the forward pass of the layer, then revert them to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
This example will require 20GB of VRAM.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Using a Custom Scheduler
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
@@ -49,11 +388,10 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
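For instance, here is a minimal sketch of swapping in `FlowMatchEulerDiscreteScheduler`; the `shift=5.0` value is an illustrative choice, not a requirement:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, WanPipeline

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# Pass the replacement scheduler when loading the pipeline
scheduler = FlowMatchEulerDiscreteScheduler(shift=5.0)
pipe = WanPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.bfloat16)

# Or swap it on an already-loaded pipeline
pipe.scheduler = scheduler
```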
## Using Single File Loading with Wan 2.1

The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.

```python
import torch
@@ -65,6 +403,11 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
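Because the hunk above is truncated, here is a self-contained sketch of the same idea; the checkpoint path is a placeholder, not a real URL:

```python
import torch
from diffusers import WanPipeline, WanTransformer3DModel

# Placeholder: point this at an original-format Wan 2.1 transformer checkpoint (.safetensors)
ckpt_path = "<path/to/original/wan2.1_t2v_1.3B_checkpoint.safetensors>"

transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer, torch_dtype=torch.bfloat16
)
```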
## Recommendations for Inference
- Keep `AutoencoderKLWan` in `torch.float32` for better decoding quality.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan; see the sketch below for how to adjust it.
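A minimal sketch of the `shift` recommendation above, keeping the default `UniPCMultistepScheduler` and only adjusting its `flow_shift` (the value `8.0` is an illustrative pick for a higher-resolution run):

```python
import torch
from diffusers import UniPCMultistepScheduler, WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# Smaller resolutions tend to work better with lower shift values (2.0-5.0),
# larger resolutions with higher values (7.0-12.0).
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
```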
## WanPipeline

[[autodoc]] WanPipeline

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,7 +56,7 @@ class ModuleGroup:
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -64,15 +64,42 @@ class ModuleGroup:
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+        self.cpu_param_dict = {}
+        for module in self.modules:
+            for param in module.parameters():
+                self.cpu_param_dict[param] = (
+                    param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                )
+        for param in self.parameters:
+            self.cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+        for buffer in self.buffers:
+            self.cpu_param_dict[buffer] = (
+                buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            )
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,17 +109,32 @@ class ModuleGroup:
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    if self.parameters is not None:
+                        for param in self.parameters:
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    if self.buffers is not None:
+                        for buffer in self.buffers:
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                if self.parameters is not None:
+                    for param in self.parameters:
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                if self.buffers is not None:
+                    for buffer in self.buffers:
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
@@ -108,12 +150,12 @@ class ModuleGroup:
                 for buffer in self.buffers:
                     buffer.data = self.cpu_param_dict[buffer]
         else:
-            for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
+            for module in self.modules:
+                module.to(self.offload_device, non_blocking=self.non_blocking)
+            if self.parameters:
                 for param in self.parameters:
                     param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
+            if self.buffers:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
@@ -284,6 +326,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage=False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +408,12 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level( _apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
) )
elif offload_type == "leaf_level": elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream) _apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
)
else: else:
raise ValueError(f"Unsupported offload_type: {offload_type}") raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -380,6 +425,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +446,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +466,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
@@ -462,7 +503,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +515,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +538,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +551,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +596,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +615,7 @@ def _apply_group_offloading_leaf_level(
         buffers=None,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
+        low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +652,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:

View File

@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        low_cpu_mem_usage=False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -584,7 +585,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"open an issue at https://github.com/huggingface/diffusers/issues." f"open an issue at https://github.com/huggingface/diffusers/issues."
) )
apply_group_offloading( apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
) )
def save_pretrained( def save_pretrained(
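Taken together, the changes above thread a new `low_cpu_mem_usage` option through `enable_group_offload` and `apply_group_offloading`. The sketch below shows how a user could opt into it once these changes land; the model checkpoint and offload settings here are illustrative, not prescribed by the diff:

```python
import torch
from diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)

# With use_stream=True, CPU copies of the weights are normally pinned up front.
# low_cpu_mem_usage=True keeps them unpinned and pins them on the fly during onload,
# trading some transfer speed for a lower CPU RAM spike.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,
)
```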