mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-31 17:01:08 +08:00

Compare commits: chroma-doc ... chen/fix-c (8 commits)

| SHA1 |
|---|
| fff9e31eb7 |
| 72e22e86c3 |
| 0874dd04dc |
| 6184d8a433 |
| 5a6e386464 |
| 42077e6c73 |
| 3d8d8485fc |
| 195926bbdc |
@@ -180,6 +180,8 @@
    title: Caching
  - local: optimization/memory
    title: Reduce memory usage
  - local: optimization/speed-memory-optims
    title: Compile and offloading quantized models
  - local: optimization/pruna
    title: Pruna
  - local: optimization/xformers
@@ -27,9 +27,36 @@ Chroma can use all the same optimizations as Flux.

</Tip>

## Inference (Single File)
## Inference

The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).

```python
import torch
from diffusers import ChromaPipeline

pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = [
    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=torch.Generator("cpu").manual_seed(433),
    num_inference_steps=40,
    guidance_scale=3.0,
    num_images_per_prompt=1,
).images[0]
image.save("chroma.png")
```

## Loading from a single file

To use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.

The following example demonstrates how to run Chroma from a single file.
@@ -38,30 +38,29 @@ Then run the following example

```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
model_id = "lodestones/Chroma"
dtype = torch.bfloat16

transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)

text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)

pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)

pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
prompt = [
    "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
]
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]

image = pipe(
    prompt,
    guidance_scale=4.0,
    output_type="pil",
    num_inference_steps=26,
    generator=torch.Generator("cpu").manual_seed(0)
    prompt=prompt,
    negative_prompt=negative_prompt,
    generator=torch.Generator("cpu").manual_seed(433),
    num_inference_steps=40,
    guidance_scale=3.0,
).images[0]

image.save("image.png")
image.save("chroma-single-file.png")
```

## ChromaPipeline
@@ -69,3 +95,9 @@ image.save("image.png")

[[autodoc]] ChromaPipeline
  - all
  - __call__

## ChromaImg2ImgPipeline

[[autodoc]] ChromaImg2ImgPipeline
  - all
  - __call__
@@ -17,7 +17,7 @@ Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipe

This guide will show you how to reduce your memory usage.

> [!TIP]
> Keep in mind these techniques may need to be adjusted depending on the model! For example, a transformer-based diffusion model may not benefit equally from these inference speed optimizations as a UNet-based model.
> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit equally from these memory optimizations as a UNet-based model.

## Multiple GPUs
@@ -63,7 +63,12 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(

> [!WARNING]
> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment. We plan to support additional mapping strategies in the future.

The `device_map` parameter controls how the model components in a pipeline are distributed across devices. The `balanced` device placement strategy evenly splits the pipeline across all available devices.
The `device_map` parameter controls how the model components in a pipeline or the layers in an individual model are distributed across devices.

<hfoptions id="device-map">
<hfoption id="pipeline level">

The `balanced` device placement strategy evenly splits the pipeline across all available devices.

```py
import torch
@@ -83,7 +88,10 @@ print(pipeline.hf_device_map)
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
```

The `device_map` parameter also works on the model-level. This is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. Instead of `balanced`, set it to `"auto"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.
</hfoption>
<hfoption id="model level">

The `device_map` is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. Set it to `"auto"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.

```py
import torch
@@ -97,7 +105,43 @@ transformer = AutoModel.from_pretrained(
)
```

For more fine-grained control, pass a dictionary to enforce the maximum GPU memory to use on each device. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.
You can inspect a model's device map with `hf_device_map`.

```py
print(transformer.hf_device_map)
```

</hfoption>
</hfoptions>

When designing your own `device_map`, it should be a dictionary of a model's specific module name or layer and a device identifier (an integer for GPUs, `cpu` for CPUs, and `disk` for disk).

Call `hf_device_map` on a model to see how model layers are distributed and then design your own.

```py
print(transformer.hf_device_map)
{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 'cpu', 'single_transformer_blocks.11': 'cpu', 'single_transformer_blocks.12': 'cpu', 'single_transformer_blocks.13': 'cpu', 'single_transformer_blocks.14': 'cpu', 'single_transformer_blocks.15': 'cpu', 'single_transformer_blocks.16': 'cpu', 'single_transformer_blocks.17': 'cpu', 'single_transformer_blocks.18': 'cpu', 'single_transformer_blocks.19': 'cpu', 'single_transformer_blocks.20': 'cpu', 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'}
```

For example, the `device_map` below places `single_transformer_blocks.10` through `single_transformer_blocks.20` on a second GPU (`1`).

```py
import torch
from diffusers import AutoModel

device_map = {
    'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'
}

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map=device_map,
    torch_dtype=torch.bfloat16
)
```

Pass a dictionary mapping maximum memory usage to each device to enforce a limit. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.

```py
import torch
@@ -145,7 +189,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
```

> [!WARNING]
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.
> The [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] classes don't support slicing.

## VAE tiling
@@ -172,7 +216,13 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G

> [!WARNING]
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.

## CPU offloading
## Offloading

Offloading strategies move layers or models that aren't currently active to the CPU to avoid increasing GPU memory. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.

Refer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.

### CPU offloading

CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
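The full example for this section falls outside the hunks shown here. As a minimal sketch, CPU offloading is enabled with [`~DiffusionPipeline.enable_sequential_cpu_offload`]; the FLUX.1-schnell checkpoint below is only an illustrative choice.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
# submodules are moved to the GPU only while they are needed, then returned to the CPU
pipeline.enable_sequential_cpu_offload()

pipeline("An astronaut riding a horse on Mars").images[0]
```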
@@ -203,7 +253,7 @@ pipeline(
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```

## Model offloading
### Model offloading

Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is memory savings won't be as large.
@@ -219,7 +269,7 @@ from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipline.enable_model_cpu_offload()
pipeline.enable_model_cpu_offload()

pipeline(
    prompt="An astronaut riding a horse on Mars",
@@ -234,7 +284,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G

[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden state.

## Group offloading
### Group offloading

Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.
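The full example is elided from this hunk. A minimal sketch using [`~ModelMixin.enable_group_offload`] and [`~hooks.apply_group_offloading`], assuming a Flux checkpoint and block-level grouping (both illustrative choices, not the guide's exact example), might look like this.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)

# offload groups of layers in each compute-heavy component, onloading them two blocks at a time
pipeline.transformer.enable_group_offload(
    onload_device=onload_device, offload_device=offload_device, offload_type="block_level", num_blocks_per_group=2
)
apply_group_offloading(
    pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2
)
apply_group_offloading(
    pipeline.text_encoder_2, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2
)
apply_group_offloading(
    pipeline.vae, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2
)

pipeline("An astronaut riding a horse on Mars").images[0]
```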
@@ -278,7 +328,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
export_to_video(video, "output.mp4", fps=8)
```

### CUDA stream
#### CUDA stream

The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
@@ -295,22 +345,25 @@ pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_d

The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
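As a minimal sketch (reusing the `onload_device` and `offload_device` variables from the stream example above, and assuming the flag is accepted by the group offloading API as the text describes), the flag is passed alongside `use_stream`:

```py
pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    # pin tensors on the fly instead of pre-pinning them to keep CPU memory low
    low_cpu_mem_usage=True,
)
```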
<Tip>
#### Offloading to disk

The offloading strategies can be combined with [quantization](../quantization/overview.md) to enable further memory savings. For image generation, combining [quantization and model offloading](#model-offloading) can often give the best trade-off between quality, speed, and memory. However, for video generation, as the models are more
compute-bound, [group-offloading](#group-offloading) tends to be better. Group offloading provides considerable benefits when weight transfers can be overlapped with computation (must use streams). When applying group offloading with quantization on image generation models at typical resolutions (1024x1024, for example), it is usually not possible to *fully* overlap weight transfers if the compute kernel finishes faster, making it communication bound between CPU/GPU (due to device synchronizations).
Group offloading can consume significant system memory depending on the model size. On systems with limited memory, try group offloading onto the disk as a secondary memory.

</Tip>
Set the `offload_to_disk_path` argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`] to offload the model to the disk.

### Offloading to disk
```py
pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", offload_to_disk_path="path/to/disk")

Group offloading can consume significant system RAM depending on the model size. In limited RAM environments,
it can be useful to offload to the second memory, instead. You can do this by setting the `offload_to_disk_path`
argument in either of [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer [here](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
[here](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2, offload_to_disk_path="path/to/disk")
```

Refer to these [two](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) [tables](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) to compare the speed and memory trade-offs.
## Layerwise casting

> [!TIP]
> Combine layerwise casting with [group offloading](#group-offloading) for even more memory savings.

Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
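The code for this section is not included in the diff. A minimal sketch using [`~ModelMixin.enable_layerwise_casting`] on a single model (the Flux transformer here is just an illustrative choice) could look like the following.

```py
import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# store weights in fp8 and upcast them to bfloat16 on the fly for computation
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```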
> [!WARNING]
@@ -500,7 +553,7 @@ with torch.inference_mode():

## Memory-efficient attention

> [!TIP]
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention!
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!

The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
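The rest of this section is not shown in the diff. As one example of a memory-efficient attention backend, xFormers can be enabled as sketched below (requires `pip install xformers`; the SDXL checkpoint is an illustrative choice).

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# dispatch attention to the memory-efficient xFormers kernels
pipeline.enable_xformers_memory_efficient_attention()
```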
199 docs/source/en/optimization/speed-memory-optims.md Normal file
@@ -0,0 +1,199 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Compile and offloading quantized models

Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).

For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.

For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.

The table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage for Flux.

| combination | latency (s) | memory-usage (GB) |
|---|---|---|
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |

<small>These results are benchmarked on Flux with an RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d">benchmarking script</a> if you're interested in evaluating your own model.</small>

This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.

```bash
pip install -U bitsandbytes
```
## Quantization and torch.compile

Start by [quantizing](../quantization/overview) a model to reduce the memory required for storage and [compiling](./fp16#torchcompile) it to accelerate inference.

Set the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) flag `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

torch._dynamo.config.capture_dynamic_output_shape_ops = True

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# compile
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
pipeline("""
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
).images[0]
```
## Quantization, torch.compile, and offloading

In addition to quantization and torch.compile, try offloading if you need to reduce memory-usage further. Offloading moves various layers or model components from the CPU to the GPU as needed for computations.

Configure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` during offloading to avoid excessive recompilation and set `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.

<hfoptions id="offloading">
<hfoption id="model CPU offloading">

[Model CPU offloading](./memory#model-offloading) moves an individual pipeline component, like the transformer model, to the GPU when it is needed for computation. Otherwise, it is offloaded to the CPU.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# model CPU offloading
pipeline.enable_model_cpu_offload()

# compile
pipeline.transformer.compile()
pipeline(
    "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
).images[0]
```
</hfoption>
<hfoption id="group offloading">

[Group offloading](./memory#group-offloading) moves the internal layers of an individual pipeline component, like the transformer model, to the GPU for computation and offloads them when they're not required. At the same time, it uses the [CUDA stream](./memory#cuda-stream) feature to prefetch the next layer for execution.

By overlapping computation and data transfer, it is faster than model CPU offloading while also saving memory.

```py
# pip install ftfy
import torch
from diffusers import AutoModel, DiffusionPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import UMT5EncoderModel

torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# quantize
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer", "text_encoder"],
)

text_encoder = UMT5EncoderModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# group offloading
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True
)
pipeline.vae.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True
)
apply_group_offloading(
    pipeline.text_encoder,
    onload_device=onload_device,
    offload_type="leaf_level",
    use_stream=True,
    non_blocking=True
)

# compile
pipeline.transformer.compile()

prompt = """
The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""

output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```

</hfoption>
</hfoptions>
@@ -203,6 +203,46 @@ pipeline("bears, pizza bites").images[0]
</hfoption>
</hfoptions>

### Scale scheduling

Dynamically adjusting the LoRA scale during sampling gives you better control over the overall composition and layout because certain steps may benefit more from an increased or reduced scale.

The [character LoRA](https://huggingface.co/alvarobartt/ghibli-characters-flux-lora) in the example below starts with a higher scale that gradually decays over the first 20 steps to establish the character generation. In the later steps, only a scale of 0.2 is applied to avoid adding too much of the LoRA features to other parts of the image the LoRA wasn't trained on.

```py
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

pipeline.load_lora_weights("alvarobartt/ghibli-characters-flux-lora", "lora")

num_inference_steps = 30
lora_steps = 20
lora_scales = torch.linspace(1.5, 0.7, lora_steps).tolist()
lora_scales += [0.2] * (num_inference_steps - lora_steps + 1)

pipeline.set_adapters("lora", lora_scales[0])

def callback(pipeline: FluxPipeline, step: int, timestep: torch.LongTensor, callback_kwargs: dict):
    pipeline.set_adapters("lora", lora_scales[step + 1])
    return callback_kwargs

prompt = """
Ghibli style The Grinch, a mischievous green creature with a sly grin, peeking out from behind a snow-covered tree while plotting his antics,
in a quaint snowy village decorated for the holidays, warm light glowing from cozy homes, with playful snowflakes dancing in the air
"""
pipeline(
    prompt=prompt,
    guidance_scale=3.0,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator().manual_seed(42),
    callback_on_step_end=callback,
).images[0]
```

## Hotswapping

Hotswapping LoRAs is an efficient way to work with multiple LoRAs while avoiding accumulating memory from multiple calls to [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and in some cases, recompilation, if a model is compiled. This workflow requires a loaded LoRA because the new LoRA weights are swapped in place for the existing loaded LoRA.
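The rest of the section is elided here. A minimal sketch of the workflow, assuming the `hotswap` argument of [`~loaders.FluxLoraLoaderMixin.load_lora_weights`] and placeholder LoRA repository ids, might look like this.

```py
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# load the first LoRA normally (placeholder path)
pipeline.load_lora_weights("path/to/first_lora", adapter_name="lora")

# ...generate with the first LoRA...

# swap the second LoRA's weights into the already-loaded adapter instead of loading another one
pipeline.load_lora_weights("path/to/second_lora", adapter_name="lora", hotswap=True)
```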
@@ -186,9 +186,15 @@ class CosmosAttnProcessor2_0:
        key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)

        # 4. Prepare for GQA
        query_idx = torch.tensor(query.size(3), device=query.device)
        key_idx = torch.tensor(key.size(3), device=key.device)
        value_idx = torch.tensor(value.size(3), device=value.device)
        if torch.onnx.is_in_onnx_export():
            query_idx = torch.tensor(query.size(3), device=query.device)
            key_idx = torch.tensor(key.size(3), device=key.device)
            value_idx = torch.tensor(value.size(3), device=value.device)
        else:
            query_idx = query.size(3)
            key_idx = key.size(3)
            value_idx = value.size(3)
        key = key.repeat_interleave(query_idx // key_idx, dim=3)
        value = value.repeat_interleave(query_idx // value_idx, dim=3)
@@ -52,20 +52,21 @@ EXAMPLE_DOC_STRING = """
        >>> import torch
        >>> from diffusers import ChromaPipeline

        >>> model_id = "lodestones/Chroma"
        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
        >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
        >>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
        >>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
        >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
        ...     "black-forest-labs/FLUX.1-schnell",
        >>> pipe = ChromaPipeline.from_pretrained(
        ...     model_id,
        ...     transformer=transformer,
        ...     text_encoder=text_encoder,
        ...     tokenizer=tokenizer,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.enable_model_cpu_offload()
        >>> prompt = "A cat holding a sign that says hello world"
        >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        >>> prompt = [
        ...     "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
        ... ]
        >>> negative_prompt = [
        ...     "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        ... ]
        >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
        >>> image.save("chroma.png")
        ```
@@ -51,26 +51,21 @@ EXAMPLE_DOC_STRING = """
        ```py
        >>> import torch
        >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
        >>> from transformers import AutoModel, AutoTokenizer

        >>> model_id = "lodestones/Chroma"
        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
        >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
        >>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
        >>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
        >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
        ...     "black-forest-labs/FLUX.1-schnell",
        ...     model_id,
        ...     transformer=transformer,
        ...     text_encoder=text_encoder,
        ...     tokenizer=tokenizer,
        ...     torch_dtype=torch.bfloat16,
        ... )
        >>> pipe.enable_model_cpu_offload()
        >>> image = load_image(
        >>> init_image = load_image(
        ...     "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        ... )
        >>> prompt = "a scenic fantasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
        >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        >>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
        >>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
        >>> image.save("chroma-img2img.png")
        ```
"""
@@ -44,6 +44,8 @@ def retrieve_latents(


class LTXLatentUpsamplePipeline(DiffusionPipeline):
    model_cpu_offload_seq = ""

    def __init__(
        self,
        vae: AutoencoderKLLTXVideo,
@@ -1131,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
            break
    if has_transformers_component and not is_transformers_version(">", "4.47.1"):
        raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")


def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
    if quant_config is None:
        return

    actual_pipe_components = set(pipe_init_dict.keys())
    missing = ""
    quant_components = None
    if getattr(quant_config, "components_to_quantize", None) is not None:
        quant_components = set(quant_config.components_to_quantize)
    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
        quant_components = set(quant_config.quant_mapping.keys())

    if quant_components and not quant_components.issubset(actual_pipe_components):
        missing = quant_components - actual_pipe_components

    if missing:
        logger.warning(
            f"The following components in the quantization config {missing} will be ignored "
            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
            f"components are: {', '.join(actual_pipe_components)}."
        )
@@ -88,6 +88,7 @@ from .pipeline_loading_utils import (
    _identify_model_variants,
    _maybe_raise_error_for_incorrect_transformers,
    _maybe_raise_warning_for_inpainting,
    _maybe_warn_for_wrong_component_in_quant_config,
    _resolve_custom_pipeline_and_cls,
    _unwrap_model,
    _update_init_kwargs_with_connected_pipeline,
@@ -984,6 +985,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

        # 7. Load each module in the pipeline
        current_device_map = None
        _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 7.1 device_map shenanigans
            if final_device_map is not None and len(final_device_map) > 0:
@@ -16,10 +16,13 @@ import tempfile
import unittest

import torch
from parameterized import parameterized

from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
    is_transformers_available,
    require_accelerate,
    require_bitsandbytes_version_greater,
@@ -188,3 +191,55 @@ class PipelineQuantizationTests(unittest.TestCase):
        output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        self.assertTrue(torch.allclose(output_1, output_2))

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_warn_invalid_component(self, method):
        invalid_component = "foo"
        if method == "quant_kwargs":
            components_to_quantize = ["transformer", invalid_component]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": QuantoConfig("int8"),
                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
                }
            )

        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            _ = DiffusionPipeline.from_pretrained(
                self.model_name,
                quantization_config=quant_config,
                torch_dtype=torch.bfloat16,
            )
        self.assertTrue(invalid_component in cap_logger.out)

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_no_quantization_for_all_invalid_components(self, method):
        invalid_component = "foo"
        if method == "quant_kwargs":
            components_to_quantize = [invalid_component]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
            )

        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        )
        for name, component in pipe.components.items():
            if isinstance(component, torch.nn.Module):
                self.assertTrue(not hasattr(component.config, "quantization_config"))