Compare commits

...

20 Commits

Author SHA1 Message Date
sayakpaul
f60afe5cba error out for the offload to disk option. 2026-03-30 13:19:12 +05:30
Sayak Paul
06509796dd Merge branch 'main' into fix-torchao-groupoffloading 2026-03-30 11:48:22 +05:30
Steven Liu
a93f7f137a [docs] refactor model skill (#13334)
* refactor

* feedback

* feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-29 23:13:52 -07:00
Sayak Paul
59c1b2534a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-30 11:24:51 +05:30
sayakpaul
7eaeb99fcd address feedback. 2026-03-30 11:24:40 +05:30
Sayak Paul
10ec3040a2 [ci] move to assert instead of self.Assert* (#13366)
move to assert instead of self.Assert*
2026-03-30 11:09:14 +05:30
Sayak Paul
867192364c Merge branch 'main' into fix-torchao-groupoffloading 2026-03-30 10:53:47 +05:30
Howard Zhang
f2be8bd6b3 change minimum version guard for torchao to 0.15.0 (#13355) 2026-03-28 09:11:51 +05:30
Sayak Paul
a8cef0740a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-27 21:16:15 +05:30
Sayak Paul
7da22b9db5 [ci] include checkout step in claude review workflow (#13352)
up
2026-03-27 17:28:31 +05:30
Howard Zhang
1fe2125802 remove str option for quantization config in torchao (#13291)
* remove str option for quantization config in torchao

* Apply style fixes

* minor fixes

* Added AOBaseConfig docs to torchao.md

* minor fixes for removing str option torchao

* minor change to add back int and uint check

* minor fixes

* minor fixes to tests

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/quantization/torchao.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* version=2 update to test_torchao.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-27 08:52:37 +05:30
dg845
7298f5be93 Update LTX-2 Docs to Cover LTX-2.3 Models (#13337)
* Update LTX-2 docs to cover multimodal guidance and prompt enhancement

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply reviewer feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-26 17:51:29 -07:00
Sayak Paul
b757035df6 fix claude workflow to include id-token with write. (#13338) 2026-03-26 15:39:10 +05:30
Sayak Paul
70067734a2 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-26 11:29:51 +05:30
Sayak Paul
6125a4f540 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-25 08:07:01 +05:30
Sayak Paul
d2666a9d0a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-24 09:06:42 +05:30
sayakpaul
9b9e2e17a6 up 2026-03-23 11:22:36 +05:30
sayakpaul
1a959dc26f switch to swap_tensors. 2026-03-23 10:56:16 +05:30
Sayak Paul
8797398d3b Merge branch 'main' into fix-torchao-groupoffloading 2026-03-23 09:05:37 +05:30
sayakpaul
019a9deafb fix group offloading when using torchao 2026-03-17 10:40:03 +05:30
15 changed files with 698 additions and 711 deletions

View File

@@ -10,24 +10,34 @@ Strive to write code as simple and explicit as possible.
---
### Dependencies
- No new mandatory dependency without discussion (e.g. `einops`)
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
## Code formatting
- `make style` and `make fix-copies` should be run as the final step before opening a PR
### Copied Code
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
- Remove the header to intentionally break the link
### Models
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
### Pipelines & Schedulers
- Pipelines inherit from `DiffusionPipeline`
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
## Skills
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).

76
.ai/models.md Normal file
View File

@@ -0,0 +1,76 @@
# Model conventions and rules
Shared reference for model-related conventions, patterns, and gotchas.
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
## Coding style
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- No new mandatory dependency without discussion (e.g. `einops`). Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`.
## Common model conventions
- Models use `ModelMixin` with `register_to_config` for config serialization
## Attention pattern
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
```python
# transformer_mymodel.py
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.

View File

@@ -3,8 +3,8 @@
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
Before reviewing, read and apply the guidelines in:
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
- [AGENTS.md](AGENTS.md) — coding style, copied code
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)

View File

@@ -65,89 +65,19 @@ docs/source/en/api/
- [ ] Run `make style` and `make quality`
- [ ] Test parity with reference implementation (see `parity-testing` skill)
### Attention pattern
### Model conventions, attention pattern, and implementation rules
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
See [../../models.md](../../models.md) for the attention pattern, implementation rules, common conventions, dependencies, and gotchas. These apply to all model work.
```python
# transformer_mymodel.py
### Model integration specific rules
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
### Implementation rules
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
**Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
### Test setup
- Slow tests gated with `@slow` and `RUN_SLOW=1`
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
### Common diffusers conventions
- Pipelines inherit from `DiffusionPipeline`
- Models use `ModelMixin` with `register_to_config` for config serialization
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
---
## Modular Pipeline Conversion

View File

@@ -10,6 +10,7 @@ permissions:
contents: write
pull-requests: write
issues: read
id-token: write
jobs:
claude-review:
@@ -31,6 +32,9 @@ jobs:
)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 1
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}

View File

@@ -41,16 +41,15 @@ The quantized CogVideoX 5B model below requires ~16GB of VRAM.
```py
import torch
from diffusers import CogVideoXPipeline, AutoModel
from diffusers import CogVideoXPipeline, AutoModel, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from torchao.quantization import Int8WeightOnlyConfig
# quantize weights to int8 with torchao
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="torchao",
quant_kwargs={"quant_type": "int8wo"},
components_to_quantize="transformer"
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
)
# fp8 layerwise weight-casting

View File

@@ -18,7 +18,7 @@
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
[LTX-2](https://hf.co/papers/2601.03233) is a DiT-based foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
@@ -293,6 +293,7 @@ import torch
from diffusers import LTX2ConditionPipeline
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
from diffusers.utils import load_image, load_video
device = "cuda"
@@ -315,19 +316,6 @@ prompt = (
"landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the "
"solitude and beauty of a winter drive through a mountainous region."
)
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
cond_video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
@@ -343,7 +331,7 @@ frame_rate = 24.0
video, audio = pipe(
conditions=conditions,
prompt=prompt,
negative_prompt=negative_prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
@@ -366,6 +354,154 @@ encode_video(
Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.
## Multimodal Guidance
LTX-2.X pipelines support multimodal guidance. It is composed of three terms, all using a CFG-style update rule:
1. Classifier-Free Guidance (CFG): standard [CFG](https://huggingface.co/papers/2207.12598) where the perturbed ("weaker") output is generated using the negative prompt.
2. Spatio-Temporal Guidance (STG): [STG](https://huggingface.co/papers/2411.18664) moves away from a perturbed output created from short-cutting self-attention operations and substitutes in the attention values instead. The idea is that this creates sharper videos and better spatiotemporal consistency.
3. Modality Isolation Guidance: moves away from a perturbed output created from disabling cross-modality (audio-to-video and video-to-audio) cross attention. This guidance is more specific to [LTX-2.X](https://huggingface.co/papers/2601.03233) models, with the idea that this produces better consistency between the generated audio and video.
These are controlled by the `guidance_scale`, `stg_scale`, and `modality_scale` arguments and can be set separately for video and audio. Additionally, for STG the transformer block indices where self-attention is skipped needs to be specified via the `spatio_temporal_guidance_blocks` argument. The LTX-2.X pipelines also support [guidance rescaling](https://huggingface.co/papers/2305.08891) to help reduce over-exposure, which can be a problem when the guidance scales are set to high values.
```py
import torch
from diffusers import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
from diffusers.utils import load_image
device = "cuda"
width = 768
height = 512
random_seed = 42
frame_rate = 24.0
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "dg845/LTX-2.3-Diffusers"
pipe = LTX2ImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload(device=device)
pipe.vae.enable_tiling()
prompt = (
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
"breath-taking, movie-like shot."
)
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
)
video, audio = pipe(
image=image,
prompt=prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=30,
guidance_scale=3.0, # Recommended LTX-2.3 guidance parameters
stg_scale=1.0, # Note that 0.0 (not 1.0) means that STG is disabled (all other guidance is disabled at 1.0)
modality_scale=3.0,
guidance_rescale=0.7,
audio_guidance_scale=7.0, # Note that a higher CFG guidance scale is recommended for audio
audio_stg_scale=1.0,
audio_modality_scale=3.0,
audio_guidance_rescale=0.7,
spatio_temporal_guidance_blocks=[28],
use_cross_timestep=True,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_3_i2v_stage_1.mp4",
)
```
## Prompt Enhancement
The LTX-2.X models are sensitive to prompting style. Refer to the [official prompting guide](https://ltx.io/model/model-blog/prompting-guide-for-ltx-2) for recommendations on how to write a good prompt. Using prompt enhancement, where the supplied prompts are enhanced using the pipeline's text encoder (by default a [Gemma 3](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized) model) given a system prompt, can also improve sample quality. The optional `processor` pipeline component needs to be present to use prompt enhancement. Enable prompt enhancement by supplying a `system_prompt` argument:
```py
import torch
from transformers import Gemma3Processor
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT, T2V_DEFAULT_SYSTEM_PROMPT
device = "cuda"
width = 768
height = 512
random_seed = 42
frame_rate = 24.0
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "dg845/LTX-2.3-Diffusers"
pipe = LTX2Pipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()
if getattr(pipe, "processor", None) is None:
processor = Gemma3Processor.from_pretrained("google/gemma-3-12b-it-qat-q4_0-unquantized")
pipe.processor = processor
prompt = (
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
"breath-taking, movie-like shot."
)
video, audio = pipe(
prompt=prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=30,
guidance_scale=3.0,
stg_scale=1.0,
modality_scale=3.0,
guidance_rescale=0.7,
audio_guidance_scale=7.0,
audio_stg_scale=1.0,
audio_modality_scale=3.0,
audio_guidance_rescale=0.7,
spatio_temporal_guidance_blocks=[28],
use_cross_timestep=True,
system_prompt=T2V_DEFAULT_SYSTEM_PROMPT,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_3_t2v_stage_1.mp4",
)
```
## LTX2Pipeline
[[autodoc]] LTX2Pipeline

View File

@@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
from torchao.quantization import Int8WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```
For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig("int8wo")}
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128, version=2))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -91,18 +74,15 @@ Weight-only quantization stores the model weights in a specific low-bit data typ
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
The quantization methods supported are as follows:
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods. An exhaustive list of configuration options are available [here](https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows).
| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some example popular quantization configurations are as follows:
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
| **Category** | **Configuration Classes** |
|---|---|
| **Integer quantization** | [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int4WeightOnlyConfig.html), [`Int8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8WeightOnlyConfig.html), [`Int8DynamicActivationInt8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8DynamicActivationInt8WeightConfig.html) |
| **Floating point 8-bit quantization** | [`Float8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8WeightOnlyConfig.html), [`Float8DynamicActivationFloat8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) |
| **Unsigned integer quantization** | [`IntxWeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.IntxWeightOnlyConfig.html) |
## Serializing and Deserializing quantized models
@@ -111,8 +91,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
```python
import torch
from diffusers import AutoModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
@@ -137,18 +118,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
from torchao.quantization import IntxWeightOnlyConfig
# Serialize the model
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,6 +35,54 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
"""
torch.utils.swap_tensors(param, source)
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
`cpu_param_dict`).
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -124,6 +172,13 @@ class ModuleGroup:
else torch.cuda
)
@staticmethod
def _to_cpu(tensor, low_cpu_mem_usage):
# For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
# (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
return t if low_cpu_mem_usage else t.pin_memory()
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -131,17 +186,15 @@ class ModuleGroup:
for module in self.modules:
for param in module.parameters():
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
for buffer in module.buffers():
cpu_param_dict[buffer] = (
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
)
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
for param in self.parameters:
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
for buffer in self.buffers:
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
return cpu_param_dict
@@ -157,9 +210,16 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -178,7 +238,19 @@ class ModuleGroup:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, default_stream)
def _check_disk_offload_torchao(self):
all_tensors = list(self.tensor_to_key.keys())
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
if has_torchao:
raise ValueError(
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
"setting `offload_to_disk_path`."
)
def _onload_from_disk(self):
self._check_disk_offload_torchao()
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -221,6 +293,8 @@ class ModuleGroup:
self._process_tensors_from_modules(None)
def _offload_to_disk(self):
self._check_disk_offload_torchao()
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
@@ -245,18 +319,35 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(param):
moved = param.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):

View File

@@ -1,6 +1,155 @@
# Copyright 2026 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Pre-trained sigma values for distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
# Default negative prompt from
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py#L131-L143
DEFAULT_NEGATIVE_PROMPT = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
# System prompts for prompt enhancement
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt#L1
# Disable line-too-long rule in ruff to keep the prompts exactly the same (e.g. in terms of newlines)
# Supported in ruff>=0.15.0
# ruff: disable[E501]
T2V_DEFAULT_SYSTEM_PROMPT = """
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed
video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions,
actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested).
Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g.,
"ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with
voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if
unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is
requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological
scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits
or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video
generation.
#### Example Input: "A woman at a coffee shop talking on the phone" Output: Style: realistic with cinematic lighting.
In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the
window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone
to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle
clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a
soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd
love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her
chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully,
lowering the phone.
"""
# ruff: enable[E501]
# ruff: disable[E501]
I2V_DEFAULT_SYSTEM_PROMPT = """
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and
user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in
conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image
to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause
scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio
intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when
requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The
tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation
mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should
include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just
saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room
underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion
only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style
(optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio
descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output: Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a
cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the
counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine
hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
"""
# ruff: enable[E501]

View File

@@ -23,20 +23,17 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
from __future__ import annotations
import copy
import dataclasses
import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable
from packaging import version
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import deprecate, is_torch_available, is_torchao_version, logging
if is_torch_available():
@@ -53,16 +50,6 @@ class QuantizationMethod(str, Enum):
MODELOPT = "modelopt"
if is_torchao_available():
from torchao.quantization.quant_primitives import MappingType
class TorchAoJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, MappingType):
return obj.name
return super().default(obj)
@dataclass
class QuantizationConfigMixin:
"""
@@ -446,49 +433,21 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`str` | AOBaseConfig):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
`int8_weight_only`, `int8_dynamic_activation_int8_weight`
- Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
- **Floating point 8-bit quantization:**
- Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
`float8_static_activation_float8_weight`
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
be satisfied for a given shorthand notation.
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
- An AOBaseConfig instance: for more advanced configuration options.
quant_type (`AOBaseConfig`):
An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao
documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for
available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`,
`Float8DynamicActivationFloat8WeightConfig`, etc.).
modules_to_not_convert (`list[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
kwargs (`dict[str, Any]`, *optional*):
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
documentation of arguments can be found in
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
Example:
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig
# AOBaseConfig-based configuration
from torchao.quantization import Int8WeightOnlyConfig
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
# String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
@@ -500,7 +459,7 @@ class TorchAoConfig(QuantizationConfigMixin):
def __init__(
self,
quant_type: str | "AOBaseConfig", # noqa: F821
quant_type: "AOBaseConfig", # noqa: F821
modules_to_not_convert: list[str] | None = None,
**kwargs,
) -> None:
@@ -508,102 +467,39 @@ class TorchAoConfig(QuantizationConfigMixin):
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
# When we load from serialized config, "quant_type_kwargs" will be the key
if "quant_type_kwargs" in kwargs:
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
else:
self.quant_type_kwargs = kwargs
self.post_init()
def post_init(self):
if not isinstance(self.quant_type, str):
if is_torchao_version("<=", "0.9.0"):
raise ValueError(
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
)
if is_torchao_version("<", "0.15.0"):
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
from torchao.quantization.quant_api import AOBaseConfig
from torchao.quantization.quant_api import AOBaseConfig
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
elif isinstance(self.quant_type, str):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floatx_quant_type = self.quant_type.startswith("fp")
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
raise ValueError(
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
)
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()
if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
if is_dataclass(d["quant_type_kwargs"]["layout"]):
d["quant_type_kwargs"]["layout"] = [
d["quant_type_kwargs"]["layout"].__class__.__name__,
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
]
if isinstance(d["quant_type_kwargs"]["layout"], list):
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
else:
raise ValueError("layout must be a list")
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# For now we assume there is 1 config per Transformer, however in the future
# We may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
# For now we assume there is 1 config per Transformer, however in the future
# we may want to support a config per fqn.
# See: https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.quantize_.html
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
return d
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
if not is_torchao_version(">=", "0.15.0"):
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
if isinstance(quant_type, str):
return cls(quant_type=quant_type, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (
@@ -618,210 +514,13 @@ class TorchAoConfig(QuantizationConfigMixin):
return cls(quant_type=quant_type, **config_dict)
@classmethod
def _get_torchao_quant_type_to_method(cls):
r"""
Returns supported torchao quantization types with all commonly used notations.
"""
if is_torchao_available():
# TODO(aryan): Support sparsify
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
if is_torchao_version("<=", "0.14.1"):
from torchao.quantization import fpx_weight_only
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
def generate_float8dq_types(dtype: torch.dtype):
name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
types = {}
for granularity_cls in [PerTensor, PerRow]:
# Note: Activation and Weights cannot have different granularities
granularity_name = "tensor" if granularity_cls is PerTensor else "row"
types[f"float8dq_{name}_{granularity_name}"] = partial(
float8_dynamic_activation_float8_weight,
activation_dtype=dtype,
weight_dtype=dtype,
granularity=(granularity_cls(), granularity_cls()),
)
return types
def generate_fpx_quantization_types(bits: int):
if is_torchao_version("<=", "0.14.1"):
types = {}
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
return types
else:
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
"int4wo": int4_weight_only,
"int4_weight_only": int4_weight_only,
# int4 weight + int8 activation
"int4dq": int8_dynamic_activation_int4_weight,
"int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
}
INT8_QUANTIZATION_TYPES = {
# int8 weight + bfloat16/float16 activation
"int8wo": int8_weight_only,
"int8_weight_only": int8_weight_only,
# int8 weight + int8 activation
"int8dq": int8_dynamic_activation_int8_weight,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
}
# TODO(aryan): handle torch 2.2/2.3
FLOATX_QUANTIZATION_TYPES = {
# float8_e5m2 weight + bfloat16/float16 activation
"float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
"float8_weight_only": float8_weight_only,
"float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
# float8_e4m3 weight + bfloat16/float16 activation
"float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
# float8_e5m2 weight + float8 activation (dynamic)
"float8dq": float8_dynamic_activation_float8_weight,
"float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
# ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
# However, changing activation_dtype=torch.float8_e4m3 might work here =====
# "float8dq_e5m2": partial(
# float8_dynamic_activation_float8_weight,
# activation_dtype=torch.float8_e5m2,
# weight_dtype=torch.float8_e5m2,
# ),
# **generate_float8dq_types(torch.float8_e5m2),
# ===== =====
# float8_e4m3 weight + float8 activation (dynamic)
"float8dq_e4m3": partial(
float8_dynamic_activation_float8_weight,
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
),
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
}
if is_torchao_version("<=", "0.14.1"):
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
"uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
"uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
"uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
"uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
"uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
"uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
# "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
}
QUANTIZATION_TYPES = {}
QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
if cls._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
return QUANTIZATION_TYPES
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)
@staticmethod
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
elif torch.xpu.is_available():
return True
else:
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
"""Create the appropriate quantization method based on configuration."""
if not isinstance(self.quant_type, str):
return self.quant_type
else:
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
and quant_type_kwargs.get("layout", None) is None
):
if torch.xpu.is_available():
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
"0.11.0"
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain
quant_type_kwargs["layout"] = Int4XPULayout()
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
else:
raise ValueError(
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
)
else:
from torchao.dtypes import Int4CPULayout
quant_type_kwargs["layout"] = Int4CPULayout()
return methods[self.quant_type](**quant_type_kwargs)
return self.quant_type
def __repr__(self):
r"""
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
```
TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "uint4wo",
"quant_type_kwargs": {
"group_size": 32
}
}
```
"""
config_dict = self.to_dict()
return (
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
)
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
@dataclass

View File

@@ -20,7 +20,6 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
import importlib
import re
import types
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any
from packaging import version
@@ -114,7 +113,7 @@ if (
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.7.0")
and is_torchao_version(">=", "0.15.0")
):
_update_torch_safe_globals()
@@ -169,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
torchao_version = version.parse(importlib.metadata.version("torchao"))
if torchao_version < version.parse("0.15.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)
self.offload = False
@@ -199,13 +198,13 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
)
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
config_name = self.quantization_config.quant_type.__class__.__name__
is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for integer quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
if torch_dtype is None:
# We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
@@ -219,45 +218,16 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
return torch_dtype
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
from accelerate.utils import CustomDtype
if isinstance(quant_type, str):
if quant_type.startswith("int8"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type.startswith("int4"):
return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
quant_type = self.quantization_config.quant_type
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
elif is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
if size_digit == "4":
return CustomDtype.INT4
else:
return torch.int8
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
@@ -337,29 +307,14 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
- Use a division factor of 8 for int4 weights
- Use a division factor of 4 for int8 weights
"""
# Original mapping for non-AOBaseConfig types
# For the uint types, this is a best guess. Once these types become more used
# we can look into their nuances.
if is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
if size_digit == "4":
return 8
else:
return 4
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
quant_type = self.quantization_config.quant_type
for pattern, target_dtype in map_to_target_dtype.items():
if fnmatch(quant_type, pattern):
return target_dtype
raise ValueError(f"Unsupported quant_type: {quant_type!r}")
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
if size_digit == "4":
return 8
else:
return 4
def _process_model_before_weight_loading(
self,
@@ -415,9 +370,17 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
return _is_torchao_serializable
_TRAINABLE_QUANTIZATION_CONFIGS = (
"Int8WeightOnlyConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8StaticActivationInt8WeightConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
)
@property
def is_trainable(self):
return self.quantization_config.quant_type.startswith("int8")
return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
@property
def is_compileable(self) -> bool:

View File

@@ -44,9 +44,9 @@ class AutoencoderTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):

View File

@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import (
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -63,8 +62,7 @@ if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
import torchao.quantization as _torchao_quantization
class LoRALayer(torch.nn.Module):
@@ -806,9 +804,9 @@ class TorchAoConfigMixin:
"""
TORCHAO_QUANT_TYPES = {
"int4wo": {"quant_type": "int4_weight_only"},
"int8wo": {"quant_type": "int8_weight_only"},
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
"int4wo": "Int4WeightOnlyConfig",
"int8wo": "Int8WeightOnlyConfig",
"int8dq": "Int8DynamicActivationInt8WeightConfig",
}
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
@@ -817,8 +815,13 @@ class TorchAoConfigMixin:
"int8dq": 1.5,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = TorchAoConfig(**config_kwargs)
@staticmethod
def _get_quant_config(config_name):
config_cls = getattr(_torchao_quantization, config_name)
return TorchAoConfig(config_cls())
def _create_quantized_model(self, config_name, **extra_kwargs):
config = self._get_quant_config(config_name)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs["device_map"] = str(torch_device)

View File

@@ -14,13 +14,11 @@
# limitations under the License.
import gc
import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -55,6 +53,20 @@ from ..test_torch_compile_utils import QuantCompileTests
enable_full_determinism()
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
if is_torch_available():
import torch
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
elif torch.xpu.is_available():
return True
return False
if is_torch_available():
import torch
import torch.nn as nn
@@ -64,75 +76,56 @@ if is_torch_available():
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import (
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
)
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
from torchao.quantization import Int8WeightOnlyConfig
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Makes sure the config format is properly set
"""
quantization_config = TorchAoConfig("int4_weight_only")
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(version=2))
torchao_orig_config = quantization_config.to_dict()
for key in torchao_orig_config:
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
self.assertIn("quant_type", torchao_orig_config)
self.assertIn("quant_method", torchao_orig_config)
def test_post_init_check(self):
"""
Test kwargs validations in TorchAoConfig
Test that non-AOBaseConfig types are rejected
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "is not supported"):
_ = TorchAoConfig("uint8")
_ = TorchAoConfig(Int4WeightOnlyConfig())
with self.assertRaises(TypeError):
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
_ = TorchAoConfig("int4_weight_only", group_size1=32)
with self.assertRaises(TypeError):
_ = TorchAoConfig(42)
def test_repr(self):
"""
Check that there is no error in the repr
"""
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": [
"conv"
],
"quant_method": "torchao",
"quant_type": "int4_weight_only",
"quant_type_kwargs": {
"group_size": 8
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "int4dq",
"quant_type_kwargs": {
"act_mapping_type": "SYMMETRIC",
"group_size": 64
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=2), modules_to_not_convert=["conv"])
quantization_repr = repr(quantization_config)
self.assertIn("TorchAoConfig", quantization_repr)
self.assertIn("torchao", quantization_repr)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
@@ -234,79 +227,30 @@ class TorchAoTest(unittest.TestCase):
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
(Int4WeightOnlyConfig(version=2), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
(Int8DynamicActivationIntxWeightConfig(version=2), np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
(Int8WeightOnlyConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
(Int8DynamicActivationInt8WeightConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
(IntxWeightOnlyConfig(dtype=torch.uint4, group_size=16, version=2), np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
(IntxWeightOnlyConfig(dtype=torch.uint7, group_size=16, version=2), np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if _is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e5m2), np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice, model_id)
@unittest.skip("Skipping floatx quantization tests")
def test_floatx_quantization(self):
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
self._test_quant_type(
quantization_config,
np.array(
[
0.4648,
0.5195,
0.5547,
0.4180,
0.4434,
0.6445,
0.4316,
0.4531,
0.5625,
]
),
model_id,
)
else:
# Make sure the correct error is thrown
with self.assertRaisesRegex(ValueError, "Please downgrade"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
def test_int4wo_quant_bfloat16_conversion(self):
"""
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
"""
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -361,7 +305,7 @@ class TorchAoTest(unittest.TestCase):
else:
expected_slice = expected_slice_offload
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -385,7 +329,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
@@ -406,7 +350,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), modules_to_not_convert=["transformer_blocks.0"])
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -422,7 +366,7 @@ class TorchAoTest(unittest.TestCase):
quantized_layer = quantized_model_with_not_convert.proj_out
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
quantization_config = TorchAoConfig("int8_weight_only")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -436,7 +380,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(size_quantized < size_quantized_with_not_convert)
def test_training(self):
quantization_config = TorchAoConfig("int8_weight_only")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -470,7 +414,7 @@ class TorchAoTest(unittest.TestCase):
def test_torch_compile(self):
r"""Test that verifies if torch.compile works with torchao quantization."""
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
quantization_config = TorchAoConfig("int8_weight_only")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config, model_id=model_id)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device)
@@ -491,11 +435,15 @@ class TorchAoTest(unittest.TestCase):
memory footprint of the converted model and the class type of the linear layers of the converted models
"""
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
transformer_int4wo = self.get_dummy_components(TorchAoConfig(Int4WeightOnlyConfig()), model_id=model_id)[
"transformer"
]
transformer_int4wo_gs32 = self.get_dummy_components(
TorchAoConfig("int4wo", group_size=32), model_id=model_id
TorchAoConfig(Int4WeightOnlyConfig(group_size=32)), model_id=model_id
)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
"transformer"
]
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
# Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
@@ -553,20 +501,22 @@ class TorchAoTest(unittest.TestCase):
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
del transformer_bf16
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
"transformer"
]
transformer_int8wo.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
def test_wrong_config(self):
with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
self.get_dummy_components(TorchAoConfig("int42"))
def test_sequential_cpu_offload(self):
r"""
A test that checks if inference runs as expected when sequential cpu offloading is enabled.
"""
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.enable_sequential_cpu_offload()
@@ -574,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
@require_torchao_version_greater_or_equal("0.9.0")
@require_torchao_version_greater_or_equal("0.15.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
@@ -587,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -595,8 +545,8 @@ class TorchAoSerializationTest(unittest.TestCase):
gc.collect()
backend_empty_cache(torch_device)
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
def get_dummy_model(self, quant_type, device=None):
quantization_config = TorchAoConfig(quant_type)
quantized_model = FluxTransformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
@@ -632,8 +582,8 @@ class TorchAoSerializationTest(unittest.TestCase):
"timestep": timestep,
}
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
def _test_original_model_expected_slice(self, quant_type, expected_slice):
quantized_model = self.get_dummy_model(quant_type, torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
@@ -641,8 +591,8 @@ class TorchAoSerializationTest(unittest.TestCase):
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
def _check_serialization_expected_slice(self, quant_type, expected_slice, device):
quantized_model = self.get_dummy_model(quant_type, device)
with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
@@ -662,43 +612,42 @@ class TorchAoSerializationTest(unittest.TestCase):
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def test_int_a8w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
quant_type = Int8DynamicActivationInt8WeightConfig()
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
def test_int_a16w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
quant_type = Int8WeightOnlyConfig()
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
def test_int_a8w8_cpu(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
quant_type = Int8DynamicActivationInt8WeightConfig()
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
def test_int_a16w8_cpu(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
quant_type = Int8WeightOnlyConfig()
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
@require_torchao_version_greater_or_equal("0.9.0")
def test_aobase_config(self):
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
quant_type = Int8WeightOnlyConfig()
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
@@ -744,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -817,29 +766,25 @@ class SlowTorchAoTests(unittest.TestCase):
def test_quantization(self):
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
(Int8WeightOnlyConfig(), np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
(Int8DynamicActivationInt8WeightConfig(), np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
]
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if _is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice)
gc.collect()
backend_empty_cache(torch_device)
backend_synchronize(torch_device)
def test_serialization_int8wo(self):
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()
@@ -876,7 +821,7 @@ class SlowTorchAoTests(unittest.TestCase):
def test_memory_footprint_int4wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 6.0
quantization_config = TorchAoConfig("int4wo")
quantization_config = TorchAoConfig(Int4WeightOnlyConfig())
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -891,7 +836,7 @@ class SlowTorchAoTests(unittest.TestCase):
def test_memory_footprint_int8wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 12.0
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -906,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):