mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-24 17:01:30 +08:00
Compare commits (6 commits, `autoencode` ... `optimizati`):

- b04f3d7105
- 6a339ce637
- 26bb7fa0cb
- 5063aa5566
- 62b1071609
- 1dd2004954
**.ai/skills/optimizations/SKILL.md** (new file, 113 lines)

---
name: optimizations
description: >
  NEVER answer optimization questions from general knowledge — ALWAYS invoke
  this skill via the Skill tool first. Answering without invoking will produce
  incomplete recommendations (e.g. missing group offloading, wrong API calls).
  IMPORTANT: When ANY tool output (especially Bash) contains
  "torch.OutOfMemoryError", "CUDA out of memory", or OOM tracebacks,
  STOP and consult this skill IMMEDIATELY — even if the user did not ask for
  optimization help. Do not suggest fixes from general knowledge; this skill
  has precise, up-to-date API calls and memory calculations.
  Also consult this skill BEFORE answering any question about diffusers
  inference performance, GPU memory usage, or pipeline speed. Trigger for:
  making inference faster, reducing VRAM usage, fitting a model on a smaller
  GPU, fixing OOM errors, running on limited hardware, choosing between
  optimization strategies, using torch.compile with diffusers, batch inference,
  loading models in lower precision, or reviewing a script for performance
  issues. Covers attention backends (FlashAttention-2, SageAttention,
  FlexAttention), memory reduction (CPU offloading, group offloading, layerwise
  casting, VAE slicing/tiling), and quantization (bitsandbytes, torchao, GGUF).
  Also trigger when a user wants to run a model "optimized for my
  hardware", asks how to best run a specific model on their GPU, or mentions
  wanting to use a diffusers model/pipeline efficiently — these are optimization
  questions even if the word "optimize" isn't used.
---

## Goal

Help users apply and debug optimizations for diffusers pipelines. There are five main areas:

1. **Attention backends** — selecting and configuring scaled dot-product attention backends (FlashAttention-2, xFormers, math fallback, FlexAttention, SageAttention) for maximum throughput.
2. **Memory reduction** — techniques to reduce peak GPU memory: model CPU offloading, group offloading, layerwise casting, VAE slicing/tiling, and attention slicing.
3. **Quantization** — reducing model precision with bitsandbytes, torchao, or GGUF to fit larger models on smaller GPUs.
4. **torch.compile** — compiling the transformer (and optionally the VAE) for a 20-50% inference speedup on repeated runs.
5. **Combining techniques** — layerwise casting + group offloading, quantization + offloading, etc.

## Workflow: When a user hits OOM or asks to fit a model on their GPU

When a user asks how to make a pipeline run on their hardware, or hits an OOM error, follow these steps **in order** before proposing any changes:

### Step 1: Detect hardware

Run these commands to understand the user's system:

```bash
# GPU VRAM
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits

# System RAM
free -g | head -2
```

Record the GPU name, total VRAM (in GB), and total system RAM (in GB). These numbers drive the recommendation.
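If the numbers need to be consumed programmatically, the CSV output can be parsed with a few lines of stdlib Python. This is a sketch; the sample line below is illustrative, not a real measurement, and note that `nounits` reports memory in MiB:

```python
def parse_gpu_csv(line: str) -> dict:
    """Parse one line of `nvidia-smi --query-gpu=name,memory.total,memory.free
    --format=csv,noheader,nounits`. Memory values are in MiB."""
    name, total_mib, free_mib = [field.strip() for field in line.split(",")]
    return {
        "name": name,
        "vram_total_gb": int(total_mib) / 1024,  # MiB -> GiB
        "vram_free_gb": int(free_mib) / 1024,
    }

# Illustrative output line (hypothetical GPU)
gpu = parse_gpu_csv("NVIDIA GeForce RTX 4090, 24564, 23010")
print(gpu["name"], round(gpu["vram_total_gb"], 1))
```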
### Step 2: Measure model memory and calculate strategies

Read the user's script to identify the pipeline class, model ID, `torch_dtype`, and generation params (resolution, frames).

Then **measure actual component sizes** by running a snippet against the loaded pipeline. Do NOT guess sizes from parameter counts or model cards — always measure. See [memory-calculator.md](memory-calculator.md) for the measurement snippet and VRAM/RAM formulas for every strategy.

Steps:
1. Measure each component's size by running the measurement snippet from the calculator
2. Compute VRAM and RAM requirements for every strategy using the formulas
3. Filter out strategies that don't fit the user's hardware

This is the critical step — the calculator contains exact formulas for every strategy, including the RAM cost of CUDA streams (which require ~2x the model size in pinned memory). Don't skip it: recommending `use_stream=True` to a user with limited RAM will cause swapping or OOM on the CPU side.

### Step 3: Ask the user their preference

Present the user with a clear summary of what fits. **Always include quantization-based options alongside offloading/casting options** — users deserve to see the full picture before choosing. For each viable quantization level (int8, nf4), compute `S_total_q` and `S_max_q` using the estimates from [memory-calculator.md](memory-calculator.md) (int4/nf4 ≈ 0.25x, int8 ≈ 0.5x component size), then check fit just like the other strategies.

Present options grouped by approach so the user can compare:

> Based on your hardware (**X GB VRAM**, **Y GB RAM**) and the model requirements (~**Z GB** total, largest component ~**W GB**), here are the strategies that fit your system:
>
> **Offloading / casting strategies:**
> 1. **Quality** — [specific strategy]. Full precision, no quality loss. [estimated VRAM / RAM / speed tradeoff].
> 2. **Speed** — [specific strategy]. [quality tradeoff]. [estimated VRAM / RAM].
> 3. **Memory saving** — [specific strategy]. Minimizes VRAM. [tradeoffs].
>
> **Quantization strategies:**
> 4. **int8 [components]** — [with offloading if needed]. [estimated VRAM / RAM]. Less quality loss than int4.
> 5. **nf4 [components]** — [with offloading if needed]. [estimated VRAM / RAM]. Maximum memory savings, some quality degradation.
>
> Which would you prefer?

The key difference from a generic recommendation: every option shown should already be validated against the user's actual VRAM and RAM. Don't show options that won't fit. Read [quantization.md](quantization.md) for correct API usage when applying quantization strategies.

### Step 4: Apply the strategy

Propose **specific code changes** to the user's script. Always show the exact code diff. Read [reduce-memory.md](reduce-memory.md) and [layerwise-casting.md](layerwise-casting.md) for correct API usage before writing code.

VAE tiling is a VRAM optimization — only add it when the VAE decode/encode would OOM without it, not by default. See [reduce-memory.md](reduce-memory.md) for thresholds, the correct API (`pipe.vae.enable_tiling()` — the pipeline-level call is deprecated since v0.40.0), and which VAEs don't support it.

## Reference guides

Read these for correct API usage and detailed technique descriptions:
- [memory-calculator.md](memory-calculator.md) — **Read this first when recommending strategies.** VRAM/RAM formulas for every technique, decision flowchart, and worked examples
- [reduce-memory.md](reduce-memory.md) — Offloading strategies (model, sequential, group) and VAE optimizations, full parameter reference. **Authoritative source for compatibility rules.**
- [layerwise-casting.md](layerwise-casting.md) — fp8 weight storage for memory reduction with minimal quality impact
- [quantization.md](quantization.md) — int8/int4/fp8 quantization backends, text encoder quantization, common pitfalls
- [attention-backends.md](attention-backends.md) — Attention backend selection for speed
- [torch-compile.md](torch-compile.md) — torch.compile for inference speedup

## Important compatibility rules

See [reduce-memory.md](reduce-memory.md) for the full compatibility reference. Key constraints:

- **`enable_model_cpu_offload()` and group offloading cannot coexist** on the same pipeline — use pipeline-level `enable_group_offload()` instead.
- **`torch.compile` + offloading**: compatible, but prefer `compile_repeated_blocks()` over a full-model compile for better performance. See [torch-compile.md](torch-compile.md).
- **`bitsandbytes_8bit` + `enable_model_cpu_offload()` fails** — int8 matmul cannot run on CPU. See [quantization.md](quantization.md) for the fix.
- **Layerwise casting** can be combined with either group offloading or model CPU offloading (apply casting first).
- **`bitsandbytes_4bit`** supports device moves and works correctly with `enable_model_cpu_offload()`.
---

**.ai/skills/optimizations/attention-backends.md** (new file, 40 lines)

# Attention Backends

## Overview

Diffusers supports multiple attention backends through `dispatch_attention_fn`. The backend affects both speed and memory usage. The right choice depends on hardware, sequence length, and whether you need features like sliding windows or custom masks.

## Available backends

| Backend | Key requirement | Best for |
|---|---|---|
| `torch_sdpa` (default) | PyTorch >= 2.0 | General use; auto-selects FlashAttention or memory-efficient kernels |
| `flash_attention_2` | `flash-attn` package, Ampere+ GPU | Long sequences, training, best raw throughput |
| `xformers` | `xformers` package | Older GPUs, memory-efficient attention |
| `flex_attention` | PyTorch >= 2.5 | Custom attention masks, block-sparse patterns |
| `sage_attention` | `sageattention` package | INT8 quantized attention for inference speed |

## How to set the backend

```python
# Global default
from diffusers import set_attention_backend

set_attention_backend("flash_attention_2")

# Per-model (requires: from diffusers.models.attention_processor import AttnProcessor2_0)
pipe.transformer.set_attn_processor(AttnProcessor2_0())  # torch_sdpa

# Via environment variable (set before importing diffusers)
# DIFFUSERS_ATTENTION_BACKEND=flash_attention_2
```

## Debugging attention issues

- **NaN outputs**: Check that your attention mask dtype matches what the backend expects. Some backends require `bool` masks, others require float masks with `-inf` for masked positions.
- **Speed regression**: Profile with `torch.profiler` to verify the expected kernel is actually being dispatched. SDPA can silently fall back to the math kernel.
- **Memory spike**: FlashAttention-2 is memory-efficient for long sequences but has overhead for very short ones. For short sequences, `torch_sdpa` with the math fallback may use less memory.

## Implementation notes

- Models integrated into diffusers should use `dispatch_attention_fn` (not `F.scaled_dot_product_attention` directly) so that backend switching works automatically.
- See the attention pattern in the `model-integration` skill for how to implement this in new models.
---

**.ai/skills/optimizations/layerwise-casting.md** (new file, 68 lines)

# Layerwise Casting

## Overview

Layerwise casting stores model weights in a smaller data format (e.g., `torch.float8_e4m3fn`) to use less memory, and upcasts them to a higher precision (e.g., `torch.bfloat16`) on the fly during computation. This cuts weight memory roughly in half (bf16 → fp8) with minimal quality impact because normalization and modulation layers are automatically skipped.

This is one of the most effective techniques for fitting a large model on a GPU that's just slightly too small — it doesn't require any special quantization libraries, just PyTorch.

## When to use

- The model **almost** fits in VRAM (e.g., a 28 GB model on a 32 GB GPU)
- You want memory savings with **less speed penalty** than offloading
- You want to **combine with group offloading** for even more savings

## Basic usage

Call `enable_layerwise_casting` on any Diffusers model component:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16)

# Store weights in fp8, compute in bf16
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)

pipe.to("cuda")
```

The `storage_dtype` controls how weights are stored in memory. The `compute_dtype` controls the precision used during the actual forward pass. Normalization and modulation layers are automatically kept at full precision.

### Supported storage dtypes

| Storage dtype | Memory per param | Quality impact |
|---|---|---|
| `torch.float8_e4m3fn` | 1 byte (vs 2 for bf16) | Minimal for most models |
| `torch.float8_e5m2` | 1 byte | Slightly more range, less precision than e4m3fn |
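As a rough illustration of the expected savings (a sketch with made-up numbers; since norm and modulation layers are skipped, the real reduction is less than half and must be measured, as memory-calculator.md stresses):

```python
def layerwise_cast_size_gb(total_gb: float, castable_frac: float) -> float:
    """Estimate a component's size after fp8 casting, given the fraction of
    weight bytes in castable (non-norm, non-modulation) layers.
    Castable bytes halve (bf16 2 B/param -> fp8 1 B/param); the rest stay full size."""
    return total_gb * (castable_frac * 0.5 + (1.0 - castable_frac))

# Hypothetical 28 GB transformer where 96% of weight bytes are castable
print(round(layerwise_cast_size_gb(28.0, 0.96), 2))  # about half, not exactly half
```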
## Functional API

For more control, use `apply_layerwise_casting` directly. This lets you target specific submodules or customize which layers to skip:

```python
from diffusers.hooks import apply_layerwise_casting

apply_layerwise_casting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["norm"],  # skip modules whose names match "norm"
    non_blocking=True,
)
```

## Combining with other techniques

Layerwise casting is compatible with both group offloading and model CPU offloading. Always apply layerwise casting **before** enabling offloading. See [reduce-memory.md](reduce-memory.md) for code examples and the memory-savings formulas for each combination.

## Known limitations

- May not work with all models if the forward implementation contains internal typecasting of weights (the hook assumes the forward pass is independent of weight precision)
- May fail with PEFT layers (LoRA). There are some checks, but they're not guaranteed to cover all cases
- Not suitable for training — inference only
- The `compute_dtype` should match what the model expects (usually bf16 or fp16)
---

**.ai/skills/optimizations/memory-calculator.md** (new file, 298 lines)

# Memory Calculator

Use this guide to measure VRAM and RAM requirements for each optimization strategy, then recommend the best fit for the user's hardware.

## Step 1: Measure model sizes

**Do NOT guess sizes from parameter counts or model cards.** Pipelines often contain components that are not obvious from the model name (e.g., a pipeline marketed as having a "28B transformer" may also include a 24 GB text encoder, a 6 GB connectors module, etc.). Always measure by running this snippet after loading the pipeline:

```python
import torch
from diffusers import DiffusionPipeline  # or the specific pipeline class

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16)

for name, component in pipe.components.items():
    if hasattr(component, "parameters"):
        size_gb = sum(p.numel() * p.element_size() for p in component.parameters()) / 1e9
        print(f"{name}: {size_gb:.2f} GB")
```

For the transformer, also measure block-level and leaf-level sizes:

```python
# S_block: size of one transformer block
transformer = pipe.transformer
block_attr = None
for attr in ["transformer_blocks", "blocks", "layers"]:
    if hasattr(transformer, attr):
        block_attr = attr
        break
if block_attr:
    blocks = getattr(transformer, block_attr)
    block_size = sum(p.numel() * p.element_size() for p in blocks[0].parameters()) / 1e9
    print(f"S_block: {block_size:.2f} GB ({len(blocks)} blocks)")

# S_leaf: largest leaf module
max_leaf = max(
    (sum(p.numel() * p.element_size() for p in m.parameters(recurse=False))
     for m in transformer.modules() if list(m.parameters(recurse=False))),
    default=0,
) / 1e9
print(f"S_leaf: {max_leaf:.4f} GB")
```

To measure the effect of layerwise casting on a component, apply it and re-measure:

```python
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
size_after = sum(p.numel() * p.element_size() for p in pipe.transformer.parameters()) / 1e9
print(f"Transformer after layerwise casting: {size_after:.2f} GB")
```

From the measurements, record:

- `S_total` = sum of all component sizes
- `S_max` = size of the largest single component
- `S_block` = size of one transformer block
- `S_leaf` = size of the largest leaf module
- `S_total_lc` = `S_total` after applying layerwise casting to castable components (measured, not estimated — norm/embed layers are skipped, so it's not exactly half)
- `S_max_lc` = size of the largest component after layerwise casting (measured)
- `A` = activation memory during the forward pass (cannot be measured ahead of time — estimate conservatively):
  - **Video models**: `A` scales with resolution and number of frames. A 5-second 960x544 video at 24fps can use ~7-8 GB. Higher resolution or more seconds = more activation memory.
  - **Image models**: `A` scales with image resolution. A 1024x1024 image might use 2-4 GB, but 2048x2048 could use 8-16 GB.
  - **Edit/inpainting models**: `A` includes the reference image(s) in addition to the generation activations, so budget extra.
  - When in doubt, estimate conservatively: `A ≈ 5-8 GB` for typical video workloads, `A ≈ 2-4 GB` for typical image workloads. For high-resolution or long video, increase accordingly.
## Step 2: Compute VRAM and RAM per strategy

### No optimization (all on GPU)

| | Estimate |
|---|---|
| **VRAM** | `S_total + A` |
| **RAM** | Minimal (just for loading) |
| **Speed** | Fastest — no transfers |
| **Quality** | Full precision |

### Model CPU offloading

| | Estimate |
|---|---|
| **VRAM** | `S_max + A` (only one component on GPU at a time) |
| **RAM** | `S_total` (all components stored on CPU) |
| **Speed** | Moderate — whole components transferred between CPU and GPU as the pipeline runs |
| **Quality** | Full precision |

### Group offloading: block_level (no stream)

| | Estimate |
|---|---|
| **VRAM** | `num_blocks_per_group * S_block + A` |
| **RAM** | `S_total` (all weights on CPU, no pinned copy) |
| **Speed** | Moderate — synchronous transfers per group |
| **Quality** | Full precision |

Tune `num_blocks_per_group` to fill available VRAM: `floor((VRAM - A) / S_block)`.

### Group offloading: block_level (with stream)

Streams force `num_blocks_per_group=1`. Prefetches the next block while the current one runs.

| | Estimate |
|---|---|
| **VRAM** | `2 * S_block + A` (current block + prefetched next block) |
| **RAM** | `~2.5-3 * S_total` (original weights + pinned copies + allocation overhead) |
| **Speed** | Fast — overlaps transfer and compute |
| **Quality** | Full precision |

With `low_cpu_mem_usage=True`: RAM drops to `~S_total` (pins tensors on the fly instead of pre-pinning), but slower.

With `record_stream=True`: slightly more VRAM (delays memory reclamation), slightly faster (avoids stream synchronization).

> **Note on RAM estimates with streams:** Measured RAM usage is consistently higher than the theoretical `2 * S_total`. Pinned memory allocation, CUDA runtime overhead, and memory fragmentation add ~30-50% on top. Always use `~2.5-3 * S_total` when checking if the user has enough RAM for streamed offloading.

### Group offloading: leaf_level (no stream)

| | Estimate |
|---|---|
| **VRAM** | `S_leaf + A` (single leaf module, typically very small) |
| **RAM** | `S_total` |
| **Speed** | Slow — synchronous transfer per leaf module (many transfers) |
| **Quality** | Full precision |

### Group offloading: leaf_level (with stream)

| | Estimate |
|---|---|
| **VRAM** | `2 * S_leaf + A` (current + prefetched leaf) |
| **RAM** | `~2.5-3 * S_total` (pinned copies + overhead — see note above) |
| **Speed** | Medium-fast — overlaps transfer/compute at leaf granularity |
| **Quality** | Full precision |

With `low_cpu_mem_usage=True`: RAM drops to `~S_total`, but slower.

### Sequential CPU offloading (legacy)

| | Estimate |
|---|---|
| **VRAM** | `S_leaf + A` (similar to leaf_level group offloading) |
| **RAM** | `S_total` |
| **Speed** | Very slow — no stream support, synchronous per-leaf |
| **Quality** | Full precision |

Group offloading `leaf_level + use_stream=True` is strictly better. Prefer that.

### Layerwise casting (fp8 storage)

Reduces weight memory by casting to fp8. Norm and embedding layers are automatically skipped, so the reduction is less than 50% — always measure with the snippet above.

**`pipe.to()` caveat:** `pipe.to(device)` internally calls `module.to(device, dtype)` where dtype is `None` when not explicitly passed. This preserves fp8 weights. However, if the user passes dtype explicitly (e.g., `pipe.to("cuda", torch.bfloat16)`) or the pipeline has internal dtype overrides, the fp8 storage will be overridden back to bf16. When in doubt, combine with `enable_model_cpu_offload()`, which safely moves one component at a time without dtype overrides.

**Case 1: Everything on GPU** (if `S_total_lc + A <= VRAM`)

| | Estimate |
|---|---|
| **VRAM** | `S_total_lc + A` (measured — use the layerwise casting measurement snippet) |
| **RAM** | Minimal |
| **Speed** | Near-native — small cast overhead per layer |
| **Quality** | Slight degradation (fp8 weights, norm layers kept full precision) |

Use `pipe.to("cuda")` (without explicit dtype) after applying layerwise casting. Or move each component individually.

**Case 2: With model CPU offloading** (if Case 1 doesn't fit but `S_max_lc + A <= VRAM`)

| | Estimate |
|---|---|
| **VRAM** | `S_max_lc + A` (largest component after layerwise casting, one on GPU at a time) |
| **RAM** | `S_total` (all components on CPU) |
| **Speed** | Fast — small cast overhead per layer, component transfer overhead between steps |
| **Quality** | Slight degradation (fp8 weights, norm layers kept full precision) |

Apply layerwise casting to the target components, then call `pipe.enable_model_cpu_offload()`.

### Layerwise casting + group offloading

Combines reduced weight size with offloading. The offloaded weights are in fp8, so transfers are faster and pinned copies smaller.

| | Estimate |
|---|---|
| **VRAM** | `num_blocks_per_group * S_block * 0.5 + A` (block_level) or `S_leaf * 0.5 + A` (leaf_level) |
| **RAM** | `S_total * 0.5` (no stream) or `~S_total` (with stream, pinned copy of fp8 weights) |
| **Speed** | Good — smaller transfers due to fp8 |
| **Quality** | Slight degradation from fp8 |

### Quantization (int4/nf4)

Quantization reduces weight memory but requires full-precision weights during loading. Always use `device_map="cpu"` so quantization happens on CPU.

Notation:
- `S_component_q` = quantized size of a component (int4/nf4 ≈ `S_component * 0.25`, int8 ≈ `S_component * 0.5`)
- `S_total_q` = total pipeline size after quantizing selected components
- `S_max_q` = size of the largest single component after quantization

**Loading (with `device_map="cpu"`):**

| | Estimate |
|---|---|
| **RAM (peak during loading)** | `S_largest_component_bf16` — full-precision weights of the largest component must fit in RAM during quantization |
| **RAM (after loading)** | `S_total_q` — all components at their final (quantized or bf16) sizes |

**Inference with `pipe.to(device)`:**

| | Estimate |
|---|---|
| **VRAM** | `S_total_q + A` (all components on GPU at once) |
| **RAM** | Minimal |
| **Speed** | Good — smaller model, may have dequantization overhead |
| **Quality** | Noticeable degradation possible, especially int4. Try int8 first. |

**Inference with `enable_model_cpu_offload()`:**

| | Estimate |
|---|---|
| **VRAM** | `S_max_q + A` (largest component on GPU at a time) |
| **RAM** | `S_total_q` (all components stored on CPU) |
| **Speed** | Moderate — component transfers between CPU/GPU |
| **Quality** | Depends on quantization level |
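The offloading formulas above can be collected into one small helper. This is a sketch, not part of diffusers: all sizes are in GB, the strategy names are shorthand invented here, and the example numbers are hypothetical:

```python
def estimate_vram_ram(strategy, S_total, S_max, S_block, S_leaf, A,
                      num_blocks_per_group=1):
    """Return (vram_gb, ram_gb) per the Step 2 tables above."""
    if strategy == "none":
        return S_total + A, 0.0
    if strategy == "model_cpu_offload":
        return S_max + A, S_total
    if strategy == "group_block":          # block_level, no stream
        return num_blocks_per_group * S_block + A, S_total
    if strategy == "group_block_stream":   # block_level + stream; budget 3x RAM
        return 2 * S_block + A, 3.0 * S_total
    if strategy == "group_leaf":           # leaf_level, no stream
        return S_leaf + A, S_total
    if strategy == "group_leaf_stream":    # leaf_level + stream; budget 3x RAM
        return 2 * S_leaf + A, 3.0 * S_total
    raise ValueError(f"unknown strategy: {strategy}")

# Hypothetical pipeline: 30 GB total, 20 GB transformer, 0.5 GB blocks, A = 6 GB
print(estimate_vram_ram("group_block_stream", 30, 20, 0.5, 0.02, 6))
```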
## Step 3: Pick the best strategy

Given `VRAM_available` and `RAM_available`, filter strategies by what fits, then rank by the user's preference.

### Algorithm

```
1. Measure S_total, S_max, S_block, S_leaf, S_total_lc, S_max_lc, A for the pipeline
2. For each strategy (offloading, casting, AND quantization), compute estimated VRAM and RAM
3. Filter out strategies where VRAM > VRAM_available or RAM > RAM_available
4. Present ALL viable strategies to the user grouped by approach (offloading/casting vs quantization)
5. Let the user pick based on their preference:
   - Quality: pick the one with the highest precision that fits
   - Speed: pick the one with the lowest transfer overhead
   - Memory: pick the one with the lowest VRAM usage
   - Balanced: pick the lightest technique that fits comfortably (target ~80% VRAM)
```
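The filter in steps 2-3 is a one-liner once the estimates exist. A sketch with hypothetical numbers (each entry is `(name, vram_gb, ram_gb)`):

```python
def viable(strategies, vram_avail, ram_avail):
    """Keep only strategies whose VRAM and RAM estimates fit the hardware."""
    return [name for name, vram, ram in strategies
            if vram <= vram_avail and ram <= ram_avail]

# Hypothetical estimates for a 30 GB pipeline with A = 6 GB
options = [
    ("none", 36.0, 0.0),
    ("model_cpu_offload", 26.0, 30.0),
    ("group_block_stream", 7.0, 90.0),
]
print(viable(options, vram_avail=32.0, ram_avail=64.0))
```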
### Quantization size estimates

Always compute these alongside the offloading strategies — don't treat quantization as a last resort.
Pick the largest components worth quantizing (typically the transformer, plus the text encoder if it's LLM-based):

```
S_component_int8 = S_component * 0.5
S_component_nf4 = S_component * 0.25

S_total_int8 = sum of quantized components (int8) + remaining components (bf16)
S_total_nf4 = sum of quantized components (nf4) + remaining components (bf16)
S_max_int8 = max single component after int8 quantization
S_max_nf4 = max single component after nf4 quantization
```

RAM requirement for quantization loading: `RAM >= S_largest_component_bf16` (full-precision weights must fit during quantization). If this doesn't hold, quantization is not viable unless pre-quantized checkpoints are available.
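A worked example of the estimates above, using hypothetical sizes (20 GB transformer, 9 GB text encoder, 1 GB of other bf16 components):

```python
S_transformer, S_text_encoder, S_rest = 20.0, 9.0, 1.0  # GB, hypothetical

# Quantize transformer + text encoder to int8; the rest stays bf16
S_total_int8 = 0.5 * (S_transformer + S_text_encoder) + S_rest            # 15.5 GB
S_max_int8 = max(0.5 * S_transformer, 0.5 * S_text_encoder, S_rest)       # 10.0 GB

# Same components quantized to nf4
S_total_nf4 = 0.25 * (S_transformer + S_text_encoder) + S_rest            # 8.25 GB
S_max_nf4 = max(0.25 * S_transformer, 0.25 * S_text_encoder, S_rest)      # 5.0 GB

print(S_total_int8, S_max_int8, S_total_nf4, S_max_nf4)
```

Loading requires `RAM >= 20 GB` here (the transformer's full bf16 weights during quantization).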
### Quick decision flowchart

Offloading / casting path:

```
VRAM >= S_total + A?
→ YES: No optimization needed (maybe an attention backend for speed)
→ NO:
  VRAM >= S_total_lc + A? (layerwise casting, everything on GPU)
  → YES: Layerwise casting, pipe.to("cuda") without explicit dtype
  → NO:
    VRAM >= S_max + A? (model CPU offload, full precision)
    → YES: Model CPU offloading
           - Want less VRAM? → add layerwise casting too
    → NO:
      VRAM >= S_max_lc + A? (layerwise casting + model CPU offload)
      → YES: Layerwise casting + model CPU offloading
      → NO: Need group offloading
        RAM >= 3 * S_total? (enough for pinned copies + overhead)
        → YES: group offload leaf_level + stream (fast)
        → NO:
          RAM >= S_total?
          → YES: group offload leaf_level + stream + low_cpu_mem_usage
                 or group offload block_level (no stream)
          → NO: Quantization required to reduce model size, then retry
```

Quantization path (evaluate in parallel with the above, not as a fallback):

```
RAM >= S_largest_component_bf16? (must fit full-precision weights during quantization)
→ NO: Cannot quantize — need more RAM or pre-quantized checkpoints
→ YES: Compute quantized sizes for target components (typically transformer + text_encoder)
  nf4 quantization:
    VRAM >= S_total_nf4 + A? → pipe.to("cuda"), fastest (no offloading overhead)
    VRAM >= S_max_nf4 + A?  → model CPU offload, moderate speed
  int8 quantization:
    VRAM >= S_total_int8 + A? → pipe.to("cuda"), fastest
    VRAM >= S_max_int8 + A?  → model CPU offload, moderate speed

Show all viable quantization options alongside the offloading options so the user can
compare quality/speed/memory tradeoffs across approaches.
```
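The offloading/casting branch of the flowchart translates directly into code. A sketch (not a diffusers API; all sizes in GB, example numbers hypothetical):

```python
def pick_offload_strategy(vram, ram, S_total, S_max, S_total_lc, S_max_lc, A):
    """Walk the offloading/casting branch of the flowchart above."""
    if vram >= S_total + A:
        return "none"
    if vram >= S_total_lc + A:
        return "layerwise_casting"
    if vram >= S_max + A:
        return "model_cpu_offload"
    if vram >= S_max_lc + A:
        return "layerwise_casting + model_cpu_offload"
    # Group offloading needed; pick by available RAM
    if ram >= 3 * S_total:
        return "group_offload leaf_level + stream"
    if ram >= S_total:
        return "group_offload leaf_level + stream + low_cpu_mem_usage"
    return "quantization required"

# Hypothetical: 24 GB GPU, 64 GB RAM, 30 GB pipeline, 20 GB largest component
print(pick_offload_strategy(24, 64, S_total=30, S_max=20,
                            S_total_lc=16, S_max_lc=11, A=4))
```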
180
.ai/skills/optimizations/quantization.md
Normal file
180
.ai/skills/optimizations/quantization.md
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
# Quantization

## Overview

Quantization reduces model weights from fp16/bf16 to lower precision (int8, int4, fp8), cutting memory usage and often improving throughput. Diffusers supports several quantization backends.

## Supported backends

| Backend | Precisions | Key features |
|---|---|---|
| **bitsandbytes** | int8, int4 (nf4/fp4) | Easiest to use, widely supported, QLoRA training |
| **torchao** | int8, int4, fp8 | PyTorch-native, good for inference, `autoquant` support |
| **GGUF** | Various (Q4_K_M, Q5_K_S, etc.) | Load GGUF checkpoints directly, community quantized models |
## Critical: Pipeline-level vs component-level quantization

**Pipeline-level quantization is the correct approach.** Pass a `PipelineQuantizationConfig` to `from_pretrained`. Do NOT pass a `BitsAndBytesConfig` directly — the pipeline's `from_pretrained` will reject it with `"quantization_config must be an instance of PipelineQuantizationConfig"`.

### Backend names in `PipelineQuantizationConfig`

The `quant_backend` string must match one of the registered backend keys. These are NOT the same as the config class names:

| `quant_backend` value | Notes |
|---|---|
| `"bitsandbytes_4bit"` | NOT `"bitsandbytes"` — the `_4bit` suffix is required |
| `"bitsandbytes_8bit"` | NOT `"bitsandbytes"` — the `_8bit` suffix is required |
| `"gguf"` | |
| `"torchao"` | |
| `"modelopt"` | |
### `quant_kwargs` for bitsandbytes

**`quant_kwargs` must be non-empty.** The validator raises `ValueError: Both quant_kwargs and quant_mapping cannot be None` if it's `{}` or `None`. Always pass at least one kwarg.

For `bitsandbytes_4bit`, the quantizer class is selected by backend name — `load_in_4bit=True` is redundant (the quantizer ignores it) but harmless. Pass the bnb-specific options instead:

```python
quant_kwargs={"bnb_4bit_compute_dtype": torch.bfloat16, "bnb_4bit_quant_type": "nf4"}
```

For `bitsandbytes_8bit`, there are no bnb_8bit-specific kwargs, so pass the flag explicitly to satisfy the non-empty requirement:

```python
quant_kwargs={"load_in_8bit": True}
```
## Usage patterns

### bitsandbytes (pipeline-level, recommended)

```python
import torch
from diffusers import PipelineQuantizationConfig, DiffusionPipeline

quantization_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"bnb_4bit_compute_dtype": torch.bfloat16, "bnb_4bit_quant_type": "nf4"},
    components_to_quantize=["transformer"],  # specify which components to quantize
)

pipe = DiffusionPipeline.from_pretrained(
    "model_id",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cpu",  # load on CPU first to avoid OOM during quantization
)
```
### torchao (pipeline-level)

```python
import torch
from diffusers import PipelineQuantizationConfig, DiffusionPipeline

quantization_config = PipelineQuantizationConfig(
    quant_backend="torchao",
    quant_kwargs={"quant_type": "int8_weight_only"},
    components_to_quantize=["transformer"],
)

pipe = DiffusionPipeline.from_pretrained(
    "model_id",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
```
### GGUF (pipeline-level)

```python
import torch
from diffusers import PipelineQuantizationConfig, DiffusionPipeline

quantization_config = PipelineQuantizationConfig(
    quant_backend="gguf",
    quant_kwargs={"compute_dtype": torch.bfloat16},
)

pipe = DiffusionPipeline.from_pretrained(
    "model_id",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
```
## Loading: memory requirements and `device_map="cpu"`

Quantization is NOT free at load time. The full-precision (bf16/fp16) weights must be loaded into memory first, then compressed. This means:

- **Without `device_map="cpu"`** (default): each component loads to GPU in full precision, gets quantized on GPU, then the full-precision copy is freed. But while loading, you need VRAM for the full-precision weights of the current component PLUS all previously loaded components (already quantized or not). For large models, this causes OOM.
- **With `device_map="cpu"`**: components load and quantize on CPU. This requires **RAM >= S_component_bf16** for the largest component being quantized (the full-precision weights must fit in RAM during quantization). After quantization, RAM usage drops to the quantized size.

**Always pass `device_map="cpu"` when using quantization.** Then choose how to move to GPU:

1. **`pipe.to(device)`** — moves everything to GPU at once. Only works if all components (quantized + non-quantized) fit in VRAM simultaneously: `VRAM >= S_total_after_quant`.
2. **`pipe.enable_model_cpu_offload(device=device)`** — moves components to GPU one at a time during inference. Use this when `S_total_after_quant > VRAM` but `S_max_after_quant + A <= VRAM`.

### Memory check before recommending quantization

Before recommending quantization, verify:

- **RAM >= S_largest_component_bf16** — the full-precision weights of the largest component to be quantized must fit in RAM during loading
- **VRAM >= S_total_after_quant + A** (for `pipe.to()`) or **VRAM >= S_max_after_quant + A** (for model CPU offload) — the quantized model must fit during inference
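As a rough arithmetic sketch of this check (the 4x and 2x compression ratios are rule-of-thumb assumptions, not measured values; real quantized sizes vary with block size and which layers the backend leaves in high precision):

```python
def quantized_size_gb(bf16_size_gb: float, scheme: str) -> float:
    # Rule-of-thumb compression ratios (assumption, not measured per model).
    return bf16_size_gb * {"nf4": 0.25, "int8": 0.5}[scheme]

def placement(vram_gb: float, component_sizes_gb: list[float], activations_gb: float) -> str:
    """Pick the cheapest viable placement given post-quantization sizes."""
    s_total, s_max = sum(component_sizes_gb), max(component_sizes_gb)
    if vram_gb >= s_total + activations_gb:
        return "pipe.to('cuda')"
    if vram_gb >= s_max + activations_gb:
        return "enable_model_cpu_offload()"
    return "group offloading (or a smaller quant)"
```

For example, a 12 GB bf16 transformer is roughly 3 GB at nf4, so on a 24 GB card it typically fits alongside a quantized text encoder with room for activations.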
## `components_to_quantize`

Use this parameter to control which pipeline components get quantized. Common choices:

- `["transformer"]` — quantize only the denoising model
- `["transformer", "text_encoder"]` — also quantize the text encoder (see below)
- `["transformer", "text_encoder", "text_encoder_2"]` — for dual-encoder models (FLUX.1, SD3, etc.) when both encoders are large
- Omit the parameter to quantize all compatible components

The VAE and vocoder are typically small enough that quantizing them gives little benefit and can hurt quality.
### Text encoder quantization

**Quantizing the text encoder is a first-class optimization, not an afterthought.** Many modern models use LLM-based text encoders that are as large as or larger than the transformer itself:

| Model family | Text encoder | Size (bf16) |
|---|---|---|
| FLUX.2 Klein | Qwen3 | ~9 GB |
| FLUX.1 | T5-XXL | ~10 GB |
| SD3 | T5-XXL + CLIP-L + CLIP-G | ~11 GB total |
| CogVideoX | T5-XXL | ~10 GB |

Newer models (FLUX.2 Klein, etc.) use a **single LLM-based text encoder** — check the pipeline definition for `text_encoder` vs `text_encoder_2`. Never assume a CLIP+T5 dual-encoder layout.

When the text encoder is LLM-based, always include it in `components_to_quantize`. The combined savings often allow both components to fit in VRAM simultaneously, eliminating the need for CPU offloading entirely:

```python
import torch
from diffusers import PipelineQuantizationConfig, DiffusionPipeline

# Both transformer (~4.5 GB) + Qwen3 text encoder (~4.5 GB) fit in VRAM at int4
quantization_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"bnb_4bit_compute_dtype": torch.bfloat16, "bnb_4bit_quant_type": "nf4"},
    components_to_quantize=["transformer", "text_encoder"],
)
pipe = DiffusionPipeline.from_pretrained("model_id", quantization_config=quantization_config, device_map="cpu")
pipe.to("cuda")  # everything fits — no offloading needed
```

vs. transformer-only quantization, which may still require offloading because the unquantized text encoder alone exceeds available VRAM.
## Choosing a backend

- **Just want it to work**: bitsandbytes nf4 (`bitsandbytes_4bit`)
- **Best inference speed**: torchao int8 or fp8 (on supported hardware)
- **Using community GGUF files**: GGUF
- **Need to fine-tune**: bitsandbytes (QLoRA support)
## Common issues

- **OOM during loading**: You forgot `device_map="cpu"`. See the loading section above.
- **`quantization_config must be an instance of PipelineQuantizationConfig`**: You passed a `BitsAndBytesConfig` directly. Wrap it in `PipelineQuantizationConfig` instead.
- **`quant_backend not found`**: The backend name is wrong. Use `bitsandbytes_4bit` or `bitsandbytes_8bit`, not `bitsandbytes`. See the backend names table above.
- **`Both quant_kwargs and quant_mapping cannot be None`**: `quant_kwargs` is empty or `None`. Always pass at least one kwarg — see the `quant_kwargs` section above.
- **OOM during `pipe.to(device)` after loading**: Even quantized, all components don't fit in VRAM at once. Use `enable_model_cpu_offload()` instead of `pipe.to(device)`.
- **`bitsandbytes_8bit` + `enable_model_cpu_offload()` fails at inference**: `LLM.int8()` (bitsandbytes 8-bit) can only execute on CUDA — it cannot run on CPU. When `enable_model_cpu_offload()` moves the quantized component back to CPU between steps, the int8 matmul fails. **Fix**: keep the int8 component on CUDA permanently (`pipe.transformer.to("cuda")`) and use group offloading with `exclude_modules=["transformer"]` for the rest, or switch to `bitsandbytes_4bit`, which supports device moves.
- **Quality degradation**: int4 can produce noticeable artifacts for some models. Try int8 first, then drop to int4 only if memory requires it.
- **Slow first inference**: Some backends (torchao) compile/calibrate on the first run. Subsequent runs are faster.
- **Incompatible layers**: Not all layer types support all quantization schemes. Check the backend docs for supported module types.
- **Training**: Only bitsandbytes supports training (via QLoRA). The other backends are inference-only.
213  .ai/skills/optimizations/reduce-memory.md  Normal file
@@ -0,0 +1,213 @@
# Reduce Memory

## Overview

Large diffusion models can exceed GPU VRAM. Diffusers provides several techniques to reduce peak memory, each with different speed/memory tradeoffs.

## Techniques (ordered by ease of use)
### 1. Model CPU offloading

Moves entire models to CPU when not in use, loading each one to GPU just before its forward pass.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
# Do NOT call pipe.to("cuda") — the hook handles device placement
```

- **Memory savings**: Significant — only one model on GPU at a time
- **Speed cost**: Moderate — full model transfers between CPU and GPU
- **When to use**: First thing to try when hitting OOM
- **Limitation**: If the single largest component (e.g. the transformer) exceeds VRAM, this won't help — you need group offloading or layerwise casting instead.
### 2. Group offloading

Offloads groups of internal layers to CPU, loading them to GPU only during their forward pass. More granular than model offloading, faster than sequential offloading.

**Two offload types:**

- `block_level` — offloads groups of N layers at a time. Lower memory, moderate speed.
- `leaf_level` — offloads individual leaf modules. Equivalent to sequential offloading but can be made faster with CUDA streams.

**IMPORTANT**: `enable_model_cpu_offload()` will raise an error if any component has group offloading enabled. If you need offloading for the whole pipeline, use pipeline-level `enable_group_offload()` instead — it handles all components in one call.
#### Pipeline-level group offloading

Applies group offloading to ALL components in the pipeline at once. Simplest approach.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16)

# Option A: leaf_level with CUDA streams (recommended — fast + low memory)
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

# Option B: block_level (more memory savings, slower)
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
```
#### Component-level group offloading

Apply group offloading selectively to specific components. Useful when only the transformer is too large for VRAM but the other components fit fine.

For Diffusers model components (those inheriting from `ModelMixin`), use `enable_group_offload`:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16)

# Group offload the transformer (the largest component)
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

# Group offload the VAE too if needed
pipe.vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
)
```

For non-Diffusers components (e.g. text encoders from the transformers library), use the functional API:

```python
import torch
from diffusers.hooks import apply_group_offloading

apply_group_offloading(
    pipe.text_encoder,
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
```
#### CUDA streams for faster group offloading

When `use_stream=True`, the next layer is prefetched to GPU while the current layer runs, overlapping data transfer with computation. Requires ~2x the model's size in CPU memory.

```python
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    record_stream=True,  # slightly more speed, slightly more memory
)
```

If using `block_level` with `use_stream=True`, set `num_blocks_per_group=1` (a warning is raised otherwise).
#### Full parameter reference

Parameters available across the three group offloading APIs:

| Parameter | Pipeline | Model | `apply_group_offloading` | Description |
|---|---|---|---|---|
| `onload_device` | yes | yes | yes | Device to load layers onto for computation (e.g. `torch.device("cuda")`) |
| `offload_device` | yes | yes | yes | Device to offload layers to when idle (default: `torch.device("cpu")`) |
| `offload_type` | yes | yes | yes | `"block_level"` (groups of N layers) or `"leaf_level"` (individual modules) |
| `num_blocks_per_group` | yes | yes | yes | Required for `block_level` — how many layers per group |
| `non_blocking` | yes | yes | yes | Non-blocking data transfer between devices |
| `use_stream` | yes | yes | yes | Overlap data transfer and computation via CUDA streams. Requires ~2x the model's size in CPU RAM |
| `record_stream` | yes | yes | yes | With `use_stream`, marks tensors for the stream. Faster but slightly more memory |
| `low_cpu_mem_usage` | yes | yes | yes | Pins tensors on the fly instead of pre-pinning. Saves CPU RAM when using streams, but slower |
| `offload_to_disk_path` | yes | yes | yes | Path to offload weights to disk instead of CPU RAM. Useful when system RAM is also limited |
| `exclude_modules` | **yes** | no | no | Pipeline-only: list of component names to skip (they are placed on `onload_device` instead) |
| `block_modules` | no | **yes** | **yes** | Override which submodules are treated as blocks for `block_level` offloading |
| `exclude_kwargs` | no | **yes** | **yes** | Kwarg keys that should not be moved between devices (e.g. mutable cache state) |
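Of these, `exclude_modules` is the only one without an example elsewhere in this file. A configuration sketch (GPU required; `pipe` is assumed to be a loaded pipeline and `"vae"` an illustrative component name):

```python
import torch

# Keep a small component resident on GPU while group-offloading everything else.
# exclude_modules is accepted only by the pipeline-level API.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    exclude_modules=["vae"],  # placed on onload_device instead of being offloaded
)
```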
### 3. Sequential CPU offloading

Moves individual layers to GPU one at a time during the forward pass.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
# Do NOT call pipe.to("cuda") first — it saves minimal memory if you do
```

- **Memory savings**: Maximum — only one layer on GPU at a time
- **Speed cost**: Very high — many small transfers per forward pass
- **When to use**: Last resort when group offloading with streams isn't enough
- **Note**: Group offloading with `leaf_level` + `use_stream=True` is essentially the same idea but faster. Prefer that.
### 4. VAE slicing

Processes VAE encode/decode in slices along the batch dimension.

```python
pipe.vae.enable_slicing()
```

- **Memory savings**: Reduces VAE peak memory for batch sizes > 1
- **Speed cost**: Minimal
- **When to use**: When generating multiple images/videos in a batch
- **Note**: `AutoencoderKLWan` and `AsymmetricAutoencoderKL` don't support slicing.
- **API note**: The pipeline-level `pipe.enable_vae_slicing()` is deprecated since v0.40.0. Use `pipe.vae.enable_slicing()`.
### 5. VAE tiling

Processes VAE encode/decode in spatial tiles. This is a **VRAM optimization** — only use it when the VAE decode/encode would OOM without it.

```python
pipe.vae.enable_tiling()
```

- **Memory savings**: Bounds VAE peak memory by tile size rather than full resolution
- **Speed cost**: Some overhead from tile overlap processing
- **When to use** (only when VAE decode would OOM):
  - **Image models**: Typically needed above ~1.5 MP on ≤16 GB GPUs, or ~4 MP on ≤32 GB GPUs
  - **Video models**: When `H × W × num_frames` is large relative to the VRAM remaining after denoising
- **When NOT to use**: At standard resolutions where the VAE fits comfortably — tiling adds overhead for no benefit
- **Note**: `AutoencoderKLWan` and `AsymmetricAutoencoderKL` don't support tiling.
- **API note**: The pipeline-level `pipe.enable_vae_tiling()` is deprecated since v0.40.0. Use `pipe.vae.enable_tiling()`.
- **Tip for group offloading with streams**: If combining VAE tiling with group offloading (`use_stream=True`), do a dummy forward pass first to avoid device mismatch errors.
### 6. Attention slicing (legacy)

```python
pipe.enable_attention_slicing()
```

- Largely superseded by `torch_sdpa` and FlashAttention
- Still useful on very old GPUs without SDPA support
## Combining techniques

Compatible combinations:

- Group offloading (pipeline-level) + VAE tiling — good general setup
- Group offloading (pipeline-level, `exclude_modules=["small_component"]`) — keeps small models on GPU, offloads large ones
- Model CPU offloading + VAE tiling — simple and effective when the largest component fits in VRAM
- Layerwise casting + group offloading — maximum savings (see [layerwise-casting.md](layerwise-casting.md))
- Layerwise casting + model CPU offloading — also works
- Quantization + model CPU offloading — works well
- Per-component group offloading with different configs — e.g. `block_level` for the transformer, `leaf_level` for the VAE

**Incompatible combinations:**

- `enable_model_cpu_offload()` on a pipeline where ANY component has group offloading — raises a ValueError
- `enable_sequential_cpu_offload()` on a pipeline where ANY component has group offloading — same error
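The "layerwise casting + group offloading" combination can be sketched concretely. A configuration fragment (GPU required; `pipe` is assumed loaded, and the fp8 storage dtype is one common choice, not the only one):

```python
import torch

# Maximum-savings combo on the largest component: store weights in fp8,
# compute in bf16, and stream leaf modules between CPU and GPU.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
```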
## Debugging OOM

1. Check which stage OOMs: loading, encoding, denoising, or decoding
2. If OOM during `.to("cuda")` — the full pipeline doesn't fit. Use model CPU offloading or group offloading
3. If OOM during denoising with model CPU offloading — the transformer alone exceeds VRAM. Use layerwise casting (see [layerwise-casting.md](layerwise-casting.md)) or group offloading instead
4. If still OOM during VAE decode, add `pipe.vae.enable_tiling()`
5. Consider quantization (see [quantization.md](quantization.md)) as a complementary approach
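Step 1 can be automated with a small wrapper that labels which stage ran out of memory. A sketch, assuming a CUDA build of PyTorch; the stage names and the example call are illustrative:

```python
import torch

def run_stage(name, fn, *args, **kwargs):
    """Run one pipeline stage and tag any OOM with the stage name."""
    try:
        return fn(*args, **kwargs)
    except torch.cuda.OutOfMemoryError as e:
        raise RuntimeError(f"OOM during stage {name!r}; see the matching step above") from e

# e.g. image = run_stage("denoise+decode", pipe, prompt)   # pipe assumed loaded
```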
72  .ai/skills/optimizations/torch-compile.md  Normal file
@@ -0,0 +1,72 @@
# torch.compile

## Overview

`torch.compile` traces a model's forward pass and compiles it to optimized machine code (via Triton or other backends). For diffusers, it typically speeds up the denoising loop by 20-50% after a warmup period.
## Full model compilation

Compile individual components, not the whole pipeline:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16).to("cuda")

pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
# Optionally compile the VAE decoder too
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="reduce-overhead", fullgraph=True)
```

The first 1-3 inference calls are slow (compilation/warmup). Subsequent calls are fast. Always do a warmup run before benchmarking.
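A warmup-aware timing helper makes this distinction concrete. A minimal sketch in plain Python (the names are illustrative, and it assumes your callable is synchronous; add a `torch.cuda.synchronize()` inside the loop when timing GPU work):

```python
import time

def bench(fn, *args, warmup=3, iters=10):
    """Time steady-state calls only: warmup runs absorb compilation cost."""
    for _ in range(warmup):
        fn(*args)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) / iters
```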
## Regional compilation (preferred)

Regional compilation compiles only the frequently repeated sub-modules (the transformer blocks) instead of the whole model. It provides the same runtime speedup with ~8-10x faster compile time and better compatibility with offloading.

Diffusers models declare their repeated blocks via the `_repeated_blocks` class attribute (a list of class name strings). Most modern transformers define this:

```python
# FluxTransformer defines:
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
```

Use `compile_repeated_blocks()` to compile them:

```python
pipe = DiffusionPipeline.from_pretrained("model_id", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.compile_repeated_blocks(fullgraph=True)
```

**Always guard before calling** — it raises `ValueError` if `_repeated_blocks` is empty or the named classes aren't found. Use this pattern universally, whether or not you're using offloading:

```python
# Works with or without enable_model_cpu_offload() / enable_group_offload()
if getattr(pipe.transformer, "_repeated_blocks", None):
    pipe.transformer.compile_repeated_blocks(fullgraph=True)
else:
    pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
```

`torch.compile` is compatible with diffusers' offloading methods — the offloading hooks use `@torch.compiler.disable()` on device-transfer operations, so they run natively outside the compiled graph. Regional compilation is preferred when combining with offloading because it avoids compiling the parts that interact with the hooks.

Models with `_repeated_blocks` defined include: Flux, Flux2, HunyuanVideo, LTX2Video, Wan, CogVideo, SD3, UNet2DConditionModel, and most other modern architectures.
## Compile modes

| Mode | Speed gain | Compile time | Notes |
|---|---|---|---|
| `"default"` | Moderate | Fast | Safe starting point |
| `"reduce-overhead"` | Good | Moderate | Reduces Python overhead via CUDA graphs |
| `"max-autotune"` | Best | Very slow | Tries many kernel configs; best for repeated inference |
## `fullgraph=True`

Requires the entire forward pass to be compilable as a single graph. Most diffusers transformers support this. If you get a `torch._dynamo` graph break error, remove `fullgraph=True` to allow partial compilation.
## Limitations

- **Dynamic shapes**: Changing resolution between calls triggers recompilation. Use `torch.compile(..., dynamic=True)` for variable resolutions, at some speed cost.
- **First call is slow**: Budget 1-3 minutes for initial compilation, depending on model size.
- **Windows**: The `reduce-overhead` and `max-autotune` modes may have issues. Use `"default"` if you hit errors.
2  setup.py
@@ -146,6 +146,7 @@ _deps = [
     "phonemizer",
     "opencv-python",
     "timm",
+    "flashpack",
 ]

 # this is a lookup table with items like:
@@ -250,6 +251,7 @@ extras["gguf"] = deps_list("gguf", "accelerate")
 extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
 extras["torchao"] = deps_list("torchao", "accelerate")
 extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
+extras["flashpack"] = deps_list("flashpack")

 if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
@@ -53,4 +53,5 @@ deps = {
     "phonemizer": "phonemizer",
     "opencv-python": "opencv-python",
     "timm": "timm",
+    "flashpack": "flashpack",
 }
@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
         )

-    if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
+    if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
         raise RuntimeError(
             f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
         )
@@ -42,6 +42,7 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    FLASHPACK_WEIGHTS_NAME,
     FLAX_WEIGHTS_NAME,
     HF_ENABLE_PARALLEL_LOADING,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -55,6 +56,7 @@ from ..utils import (
     is_accelerate_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,
+    is_flashpack_available,
     is_peft_available,
     is_torch_version,
     logging,
@@ -673,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         variant: str | None = None,
         max_shard_size: int | str = "10GB",
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -725,7 +728,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 " the logger on the traceback to understand the reason why the quantized model is not serializable."
             )

-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = WEIGHTS_NAME
+        if use_flashpack:
+            weights_name = FLASHPACK_WEIGHTS_NAME
+        elif safe_serialization:
+            weights_name = SAFETENSORS_WEIGHTS_NAME
+
         weights_name = _add_variant(weights_name, variant)
         weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
             ".safetensors", "{suffix}.safetensors"
@@ -752,58 +760,74 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
# Save the model
|
# Save the model
|
||||||
state_dict = model_to_save.state_dict()
|
state_dict = model_to_save.state_dict()
|
||||||
|
|
||||||
# Save the model
|
if use_flashpack:
|
||||||
state_dict_split = split_torch_state_dict_into_shards(
|
if is_flashpack_available():
|
||||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
import flashpack
|
||||||
)
|
|
||||||
|
|
||||||
# Clean the folder from a previous save
|
|
||||||
if is_main_process:
|
|
||||||
for filename in os.listdir(save_directory):
|
|
||||||
if filename in state_dict_split.filename_to_tensors.keys():
|
|
||||||
continue
|
|
||||||
full_filename = os.path.join(save_directory, filename)
|
|
||||||
if not os.path.isfile(full_filename):
|
|
||||||
continue
|
|
||||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
|
||||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
|
||||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
|
||||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
|
||||||
if (
|
|
||||||
filename.startswith(weights_without_ext)
|
|
||||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
|
||||||
):
|
|
||||||
os.remove(full_filename)
|
|
||||||
|
|
||||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
|
||||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
|
||||||
filepath = os.path.join(save_directory, filename)
|
|
||||||
if safe_serialization:
|
|
||||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
|
||||||
# joyfulness), but for now this enough.
|
|
||||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
|
||||||
else:
|
else:
|
||||||
torch.save(shard, filepath)
|
logger.error(
|
||||||
|
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
|
||||||
|
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
|
||||||
|
)
|
||||||
|
raise ImportError("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch.")
|
||||||
|
|
||||||
if state_dict_split.is_sharded:
|
flashpack.serialization.pack_to_file(
|
||||||
index = {
|
state_dict_or_model=state_dict,
|
||||||
"metadata": state_dict_split.metadata,
|
destination_path=os.path.join(save_directory, weights_name),
|
||||||
"weight_map": state_dict_split.tensor_to_filename,
|
target_dtype=self.dtype,
|
||||||
}
|
|
||||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
|
||||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
|
||||||
# Save the index as well
|
|
||||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
|
||||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
|
||||||
f.write(content)
|
|
||||||
logger.info(
|
|
||||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
|
||||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
|
||||||
f"index located at {save_index_file}."
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path_to_weights = os.path.join(save_directory, weights_name)
|
# Save the model
|
||||||
logger.info(f"Model weights saved in {path_to_weights}")
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
|
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean the folder from a previous save
|
||||||
|
if is_main_process:
|
||||||
|
for filename in os.listdir(save_directory):
|
||||||
|
if filename in state_dict_split.filename_to_tensors.keys():
|
||||||
|
continue
|
||||||
|
full_filename = os.path.join(save_directory, filename)
|
||||||
|
if not os.path.isfile(full_filename):
|
||||||
|
continue
|
||||||
|
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||||
|
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||||
|
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||||
|
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||||
|
if (
|
||||||
|
filename.startswith(weights_without_ext)
|
||||||
|
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||||
|
):
|
||||||
|
os.remove(full_filename)
|
||||||
|
|
||||||
|
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||||
|
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||||
|
filepath = os.path.join(save_directory, filename)
|
||||||
|
if safe_serialization:
|
||||||
|
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||||
|
# joyfulness), but for now this enough.
|
||||||
|
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||||
|
else:
|
||||||
|
torch.save(shard, filepath)
|
||||||
|
|
||||||
|
if state_dict_split.is_sharded:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
|
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||||
|
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||||
|
# Save the index as well
|
||||||
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||||
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
|
f.write(content)
|
||||||
|
logger.info(
|
||||||
|
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||||
|
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||||
|
f"index located at {save_index_file}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path_to_weights = os.path.join(save_directory, weights_name)
|
||||||
|
logger.info(f"Model weights saved in {path_to_weights}")
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
# Create a new empty model card and eventually tag it
|
# Create a new empty model card and eventually tag it
|
||||||
@@ -940,6 +964,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||||
|
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||||
|
If set to `True`, the model is loaded from `flashpack` weights.
|
||||||
|
flashpack_kwargs(`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||||
|
Kwargs passed to
|
||||||
|
[`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
|
||||||
|
|
||||||
|
|
||||||
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
|
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
|
||||||
with `hf > auth login`. You can also activate the special >
|
with `hf > auth login`. You can also activate the special >
|
||||||
@@ -984,6 +1014,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
|
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
|
||||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||||
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
|
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
|
||||||
|
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||||
|
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
|
||||||
|
|
||||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||||
@@ -1212,30 +1244,37 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
subfolder=subfolder or "",
|
subfolder=subfolder or "",
|
||||||
dduf_entries=dduf_entries,
|
dduf_entries=dduf_entries,
|
||||||
)
|
)
|
||||||
elif use_safetensors:
|
else:
|
||||||
try:
|
if use_flashpack:
|
||||||
resolved_model_file = _get_model_file(
|
weights_name = FLASHPACK_WEIGHTS_NAME
|
||||||
pretrained_model_name_or_path,
|
elif use_safetensors:
|
||||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
weights_name = _add_variant(SAFETENSORS_WEIGHTS_NAME, variant)
|
||||||
cache_dir=cache_dir,
|
else:
|
||||||
force_download=force_download,
|
weights_name = None
|
||||||
proxies=proxies,
|
if weights_name is not None:
|
||||||
local_files_only=local_files_only,
|
try:
|
||||||
token=token,
|
resolved_model_file = _get_model_file(
|
||||||
revision=revision,
|
pretrained_model_name_or_path,
|
||||||
subfolder=subfolder,
|
weights_name=weights_name,
|
||||||
user_agent=user_agent,
|
cache_dir=cache_dir,
|
||||||
commit_hash=commit_hash,
|
force_download=force_download,
|
||||||
dduf_entries=dduf_entries,
|
proxies=proxies,
|
||||||
)
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
commit_hash=commit_hash,
|
||||||
|
dduf_entries=dduf_entries,
|
||||||
|
)
|
||||||
|
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||||
if not allow_pickle:
|
if not allow_pickle:
|
||||||
raise
|
raise
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
if resolved_model_file is None and not is_sharded:
|
if resolved_model_file is None and not is_sharded:
|
||||||
resolved_model_file = _get_model_file(
|
resolved_model_file = _get_model_file(
|
||||||
@@ -1275,6 +1314,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
with ContextManagers(init_contexts):
|
with ContextManagers(init_contexts):
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
|
|
||||||
|
if use_flashpack:
|
||||||
|
if is_flashpack_available():
|
||||||
|
import flashpack
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
|
||||||
|
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
|
||||||
|
)
|
||||||
|
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
|
||||||
|
|
||||||
|
if device_map is None:
|
||||||
|
logger.warning(
|
||||||
|
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
|
||||||
|
"the benefit of FlashPack."
|
||||||
|
)
|
||||||
|
flashpack_device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
device = device_map[""]
|
||||||
|
if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
||||||
|
raise ValueError(
|
||||||
|
"FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
|
||||||
|
)
|
||||||
|
flashpack_device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||||
|
|
||||||
|
flashpack.mixin.assign_from_file(
|
||||||
|
model=model,
|
||||||
|
path=resolved_model_file[0],
|
||||||
|
device=flashpack_device,
|
||||||
|
**flashpack_kwargs,
|
||||||
|
)
|
||||||
|
if dtype_orig is not None:
|
||||||
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
if output_loading_info:
|
||||||
|
logger.warning("`output_loading_info` is not supported with FlashPack.")
|
||||||
|
return model, {}
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
if dtype_orig is not None:
|
if dtype_orig is not None:
|
||||||
torch.set_default_dtype(dtype_orig)
|
torch.set_default_dtype(dtype_orig)
|
||||||
|
|
||||||
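The save path above picks the checkpoint filename by a simple precedence: FlashPack first, then safetensors, then plain pickle. A standalone sketch of that precedence (the constant values mirror the ones this PR adds to `constants.py`; `select_weights_name` is an illustrative helper, not a diffusers API):

```python
# Illustrative constants mirroring those added/used in constants.py.
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
FLASHPACK_WEIGHTS_NAME = "model.flashpack"


def select_weights_name(use_flashpack: bool, safe_serialization: bool) -> str:
    # FlashPack wins over safetensors; plain pickle (.bin) is the fallback.
    weights_name = WEIGHTS_NAME
    if use_flashpack:
        weights_name = FLASHPACK_WEIGHTS_NAME
    elif safe_serialization:
        weights_name = SAFETENSORS_WEIGHTS_NAME
    return weights_name
```

Note that `use_flashpack=True` short-circuits `safe_serialization`, so a FlashPack save never produces a safetensors index or shards; the whole state dict goes into a single `model.flashpack` file.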
ErnieImageTransformer2DModel (transformer model file):

@@ -44,7 +44,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput):
 
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos, omega)
     return out.float()

@@ -400,8 +400,8 @@ class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
         ]
 
         # AdaLN
-        sample = self.time_proj(timestep.to(dtype))
-        sample = sample.to(self.time_embedding.linear_1.weight.dtype)
+        sample = self.time_proj(timestep)
+        sample = sample.to(dtype=dtype)
         c = self.time_embedding(sample)
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
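The `rope` change above only swaps the dtype of the frequency scale from float64 to float32 (float64 is unsupported or slow on some backends); the math is unchanged. A pure-Python sketch of the frequency computation, with `rope_frequencies` as a hypothetical stand-in for the tensor version:

```python
def rope_frequencies(dim: int, theta: int) -> list[float]:
    # scale = arange(0, dim, 2) / dim ; omega = 1 / theta**scale
    assert dim % 2 == 0
    return [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]


# dim=8, theta=10000 gives four geometrically decaying frequencies,
# starting at exactly 1.0 for the lowest index.
freqs = rope_frequencies(8, 10000)
```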
FluxPipeline (pipeline_flux.py):

@@ -877,10 +877,7 @@ class FluxPipeline(
             self.scheduler.config.get("max_shift", 1.15),
         )
 
-        if XLA_AVAILABLE:
-            timestep_device = "cpu"
-        else:
-            timestep_device = device
+        timestep_device = device
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
pipeline_loading_utils.py:

@@ -28,6 +28,7 @@ from packaging import version
 
 from .. import __version__
 from ..utils import (
+    FLASHPACK_WEIGHTS_NAME,
     FLAX_WEIGHTS_NAME,
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,

@@ -194,6 +195,7 @@ def filter_model_files(filenames):
         FLAX_WEIGHTS_NAME,
         ONNX_WEIGHTS_NAME,
         ONNX_EXTERNAL_WEIGHTS_NAME,
+        FLASHPACK_WEIGHTS_NAME,
     ]
 
     if is_transformers_available():

@@ -413,6 +415,9 @@ def get_class_obj_and_candidates(
     """Simple helper method to retrieve class object of module as well as potential parent class objects"""
     component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
 
+    if class_name.startswith("FlashPack"):
+        class_name = class_name.removeprefix("FlashPack")
+
     if is_pipeline_module:
         pipeline_module = getattr(pipelines, library_name)

@@ -760,6 +765,7 @@ def load_sub_model(
     provider_options: Any,
     disable_mmap: bool,
     quantization_config: Any | None = None,
+    use_flashpack: bool = False,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
     from ..quantizers import PipelineQuantizationConfig

@@ -838,6 +844,9 @@ def load_sub_model(
         loading_kwargs["variant"] = model_variants.pop(name, None)
         loading_kwargs["use_safetensors"] = use_safetensors
 
+        if is_diffusers_model:
+            loading_kwargs["use_flashpack"] = use_flashpack
+
         if from_flax:
             loading_kwargs["from_flax"] = True

@@ -887,7 +896,7 @@ def load_sub_model(
         # else load from the root directory
         loaded_sub_model = load_method(cached_folder, **loading_kwargs)
 
-    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
+    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
         # remove hooks
         remove_hook_from_module(loaded_sub_model, recurse=True)
         needs_offloading_to_cpu = device_map[""] == "cpu"

@@ -1093,6 +1102,7 @@ def _get_ignore_patterns(
     allow_pickle: bool,
     use_onnx: bool,
     is_onnx: bool,
+    use_flashpack: bool,
     variant: str | None = None,
 ) -> list[str]:
     if (

@@ -1118,6 +1128,9 @@ def _get_ignore_patterns(
         if not use_onnx:
             ignore_patterns += ["*.onnx", "*.pb"]
 
+    elif use_flashpack:
+        ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
+
     else:
         ignore_patterns = ["*.safetensors", "*.msgpack"]
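Two small behaviors from the hunks above, restated as a standalone sketch (the class name is a plausible example, not taken from this diff):

```python
# 1) A component class name carrying a "FlashPack" prefix is stripped
#    before the class lookup, so lookup resolves to the base model class.
class_name = "FlashPackFluxTransformer2DModel"
if class_name.startswith("FlashPack"):
    class_name = class_name.removeprefix("FlashPack")

# 2) When FlashPack weights are requested, every other weight format is
#    excluded from the Hub download via ignore patterns.
flashpack_ignore = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
```

The ignore list is what makes `use_flashpack=True` cheap: only `*.flashpack` files (plus configs) are fetched, instead of downloading safetensors weights that would go unused.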
pipeline_utils.py:

@@ -244,6 +244,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         variant: str | None = None,
         max_shard_size: int | str | None = None,
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """

@@ -341,6 +342,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
             save_method_accept_variant = "variant" in save_method_signature.parameters
             save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
+            save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
             save_method_accept_peft_format = "save_peft_format" in save_method_signature.parameters
 
             save_kwargs = {}

@@ -351,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if save_method_accept_max_shard_size and max_shard_size is not None:
                 # max_shard_size is expected to not be None in ModelMixin
                 save_kwargs["max_shard_size"] = max_shard_size
+            if save_method_accept_flashpack:
+                save_kwargs["use_flashpack"] = use_flashpack
             if save_method_accept_peft_format:
                 # Set save_peft_format=False for transformers>=5.0.0 compatibility
                 # In transformers 5.0.0+, the default save_peft_format=True adds "base_model.model" prefix

@@ -781,6 +785,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         quantization_config = kwargs.pop("quantization_config", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
         if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):

@@ -1071,6 +1076,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     provider_options=provider_options,
                     disable_mmap=disable_mmap,
                     quantization_config=quantization_config,
+                    use_flashpack=use_flashpack,
                 )
                 logger.info(
                     f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."

@@ -1576,6 +1582,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
                 option should only be set to `True` for repositories you trust and in which you have read the code, as
                 it will execute code present on the Hub on your local machine.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, FlashPack weights will always be downloaded if present. If set to `False`, FlashPack
+                weights will never be downloaded.
 
         Returns:
             `os.PathLike`:

@@ -1600,6 +1609,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         trust_remote_code = kwargs.pop("trust_remote_code", False)
         dduf_file: dict[str, DDUFEntry] | None = kwargs.pop("dduf_file", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
 
         if dduf_file:
             if custom_pipeline:

@@ -1719,6 +1729,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             allow_pickle,
             use_onnx,
             pipeline_class._is_onnx,
+            use_flashpack,
             variant,
         )
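The `save_method_accept_flashpack` check above is the standard capability probe the pipeline uses for each component: `use_flashpack` is only forwarded to sub-models whose own save method declares that parameter, so non-diffusers components (e.g. tokenizers) are unaffected. A minimal sketch with a hypothetical sub-model save method:

```python
import inspect


# Hypothetical component save method; diffusers ModelMixin.save_pretrained
# gains this parameter in the PR, other components may not have it.
def save_pretrained(save_directory, safe_serialization=True, use_flashpack=False):
    return {"use_flashpack": use_flashpack}


save_method_signature = inspect.signature(save_pretrained)
save_kwargs = {}
if "use_flashpack" in save_method_signature.parameters:
    save_kwargs["use_flashpack"] = True
```

If the component's save method lacks the parameter, `save_kwargs` simply never carries it, and the call site stays backward compatible.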
utils/__init__.py:

@@ -24,6 +24,8 @@ from .constants import (
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
     DIFFUSERS_LOAD_ID_FIELDS,
+    FLASHPACK_FILE_EXTENSION,
+    FLASHPACK_WEIGHTS_NAME,
     FLAX_WEIGHTS_NAME,
     GGUF_FILE_EXTENSION,
     HF_ENABLE_PARALLEL_LOADING,

@@ -76,6 +78,7 @@ from .import_utils import (
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_flashpack_available,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,
utils/constants.py:

@@ -34,6 +34,8 @@ ONNX_WEIGHTS_NAME = "model.onnx"
 SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
 SAFETENSORS_FILE_EXTENSION = "safetensors"
+FLASHPACK_WEIGHTS_NAME = "model.flashpack"
+FLASHPACK_FILE_EXTENSION = "flashpack"
 GGUF_FILE_EXTENSION = "gguf"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
utils/import_utils.py:

@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
 _aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
+_flashpack_available, _flashpack_version = _is_package_available("flashpack")
 _av_available, _av_version = _is_package_available("av")

@@ -361,6 +362,10 @@ def is_gguf_available():
     return _gguf_available
 
 
+def is_flashpack_available():
+    return _flashpack_available
+
+
 def is_torchao_available():
     return _torchao_available
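`is_flashpack_available()` follows the same pattern as the other availability probes: the module is resolved once at import time without actually importing it. A simplified sketch of that idea (the real `_is_package_available` helper also records the installed version via `importlib.metadata`):

```python
import importlib.util


def is_package_available(pkg: str) -> bool:
    # True when the module can be located on sys.path, without importing it.
    return importlib.util.find_spec(pkg) is not None


# Stdlib modules resolve; a missing optional dependency returns False
# instead of raising, which is what lets the decorators skip tests cleanly.
stdlib_found = is_package_available("json")
missing_found = is_package_available("definitely_not_installed_pkg_xyz")
```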
tests/others/test_flashpack.py (new file, 74 lines):

@@ -0,0 +1,74 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+import tempfile
+import unittest
+
+from diffusers import AutoPipelineForText2Image
+from diffusers.models.auto_model import AutoModel
+
+from ..testing_utils import is_torch_available, require_flashpack, require_torch_gpu
+
+
+if is_torch_available():
+    import torch
+
+
+class FlashPackTests(unittest.TestCase):
+    model_id: str = "hf-internal-testing/tiny-flux-pipe"
+
+    @require_flashpack
+    def test_save_load_model(self):
+        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model.save_pretrained(temp_dir, use_flashpack=True)
+            self.assertTrue((pathlib.Path(temp_dir) / "model.flashpack").exists())
+            model = AutoModel.from_pretrained(temp_dir, use_flashpack=True)
+
+    @require_flashpack
+    def test_save_load_pipeline(self):
+        pipeline = AutoPipelineForText2Image.from_pretrained(self.model_id)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            pipeline.save_pretrained(temp_dir, use_flashpack=True)
+            self.assertTrue((pathlib.Path(temp_dir) / "transformer" / "model.flashpack").exists())
+            self.assertTrue((pathlib.Path(temp_dir) / "vae" / "model.flashpack").exists())
+            pipeline = AutoPipelineForText2Image.from_pretrained(temp_dir, use_flashpack=True)
+
+    @require_torch_gpu
+    @require_flashpack
+    def test_load_model_device_str(self):
+        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model.save_pretrained(temp_dir, use_flashpack=True)
+            model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "cuda"})
+            self.assertTrue(model.device.type == "cuda")
+
+    @require_torch_gpu
+    @require_flashpack
+    def test_load_model_device(self):
+        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model.save_pretrained(temp_dir, use_flashpack=True)
+            model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": torch.device("cuda")})
+            self.assertTrue(model.device.type == "cuda")
+
+    @require_flashpack
+    def test_load_model_device_auto(self):
+        model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
+        with tempfile.TemporaryDirectory() as temp_dir:
+            model.save_pretrained(temp_dir, use_flashpack=True)
+            with self.assertRaises(ValueError):
+                model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "auto"})
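The last test above exercises the `device_map` rule FlashPack loading enforces: a concrete device is required, accelerate-style placement strategies are rejected, and a missing `device_map` silently falls back to CPU. A pure-Python restatement of that rule (`resolve_flashpack_device` is a simplified stand-in; the real code wraps the result in `torch.device`):

```python
def resolve_flashpack_device(device_map):
    # No device_map: warn-and-fall-back-to-CPU behavior in the real code.
    if device_map is None:
        return "cpu"
    device = device_map[""]
    # Placement strategies need accelerate's dispatch machinery, which the
    # direct FlashPack assignment path bypasses, so they are rejected.
    if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
        raise ValueError("FlashPack `device_map` should not be an accelerate placement strategy")
    return device
```

This is why the tests pass either `device_map={"": "cuda"}` or `device_map={"": torch.device("cuda")}`, and why `{"": "auto"}` is expected to raise `ValueError`.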
testing_utils.py:

@@ -34,6 +34,7 @@ from diffusers.utils.import_utils import (
     is_accelerate_available,
     is_bitsandbytes_available,
     is_compel_available,
+    is_flashpack_available,
     is_flax_available,
     is_gguf_available,
     is_kernels_available,

@@ -737,6 +738,13 @@ def require_accelerate(test_case):
     return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
 
 
+def require_flashpack(test_case):
+    """
+    Decorator marking a test that requires flashpack. These tests are skipped when flashpack isn't installed.
+    """
+    return pytest.mark.skipif(not is_flashpack_available(), reason="test requires flashpack")(test_case)
+
+
 def require_peft_version_greater(peft_version):
     """
     Decorator marking a test that requires PEFT backend with a specific version, this would require some specific