mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-31 21:06:45 +08:00
Compare commits
16 Commits
hunyuan-te
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39d7b1aa41 | ||
|
|
e231b433a3 | ||
|
|
7e463ea4cc | ||
|
|
7f2b34bced | ||
|
|
e1e7d58a4a | ||
|
|
a93f7f137a | ||
|
|
10ec3040a2 | ||
|
|
f2be8bd6b3 | ||
|
|
7da22b9db5 | ||
|
|
1fe2125802 | ||
|
|
7298f5be93 | ||
|
|
b757035df6 | ||
|
|
41e1003316 | ||
|
|
85ffcf1db2 | ||
|
|
cbf4d9a3c3 | ||
|
|
426daabad9 |
@@ -10,24 +10,34 @@ Strive to write code as simple and explicit as possible.
|
||||
|
||||
---
|
||||
|
||||
### Dependencies
|
||||
- No new mandatory dependency without discussion (e.g. `einops`)
|
||||
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
|
||||
|
||||
## Code formatting
|
||||
|
||||
- `make style` and `make fix-copies` should be run as the final step before opening a PR
|
||||
|
||||
### Copied Code
|
||||
|
||||
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
|
||||
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
|
||||
- Remove the header to intentionally break the link
|
||||
|
||||
### Models
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
|
||||
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
|
||||
|
||||
- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
|
||||
- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
|
||||
|
||||
### Pipelines & Schedulers
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
|
||||
|
||||
## Skills
|
||||
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
|
||||
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
|
||||
|
||||
- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
|
||||
- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).
|
||||
|
||||
76
.ai/models.md
Normal file
76
.ai/models.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Model conventions and rules
|
||||
|
||||
Shared reference for model-related conventions, patterns, and gotchas.
|
||||
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
|
||||
|
||||
## Coding style
|
||||
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
|
||||
- No new mandatory dependency without discussion (e.g. `einops`). Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`.
|
||||
|
||||
## Common model conventions
|
||||
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
|
||||
## Attention pattern
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
11
.ai/review-rules.md
Normal file
11
.ai/review-rules.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# PR Review Rules
|
||||
|
||||
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
|
||||
|
||||
Before reviewing, read and apply the guidelines in:
|
||||
- [AGENTS.md](AGENTS.md) — coding style, copied code
|
||||
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
|
||||
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
|
||||
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
|
||||
|
||||
## Common mistakes (add new rules below this line)
|
||||
@@ -65,89 +65,19 @@ docs/source/en/api/
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test parity with reference implementation (see `parity-testing` skill)
|
||||
|
||||
### Attention pattern
|
||||
### Model conventions, attention pattern, and implementation rules
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
See [../../models.md](../../models.md) for the attention pattern, implementation rules, common conventions, dependencies, and gotchas. These apply to all model work.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
### Model integration specific rules
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Implementation rules
|
||||
|
||||
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
|
||||
**Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
|
||||
### Test setup
|
||||
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
|
||||
### Common diffusers conventions
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
|
||||
---
|
||||
|
||||
## Modular Pipeline Conversion
|
||||
|
||||
42
.github/workflows/claude_review.yml
vendored
Normal file
42
.github/workflows/claude_review.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Claude PR Review
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
types: [created]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
if: |
|
||||
(
|
||||
github.event_name == 'issue_comment' &&
|
||||
github.event.issue.pull_request &&
|
||||
github.event.issue.state == 'open' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR')
|
||||
) || (
|
||||
github.event_name == 'pull_request_review_comment' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR')
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."
|
||||
@@ -161,6 +161,8 @@
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
title: Methods
|
||||
- local: training/nemo_automodel
|
||||
title: NeMo Automodel
|
||||
title: Training
|
||||
- isExpanded: false
|
||||
sections:
|
||||
|
||||
@@ -41,16 +41,15 @@ The quantized CogVideoX 5B model below requires ~16GB of VRAM.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, AutoModel
|
||||
from diffusers import CogVideoXPipeline, AutoModel, TorchAoConfig
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
# quantize weights to int8 with torchao
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="torchao",
|
||||
quant_kwargs={"quant_type": "int8wo"},
|
||||
components_to_quantize="transformer"
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
|
||||
)
|
||||
|
||||
# fp8 layerwise weight-casting
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
|
||||
[LTX-2](https://hf.co/papers/2601.03233) is a DiT-based foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
|
||||
|
||||
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
|
||||
|
||||
@@ -293,6 +293,7 @@ import torch
|
||||
from diffusers import LTX2ConditionPipeline
|
||||
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
|
||||
from diffusers.utils import load_image, load_video
|
||||
|
||||
device = "cuda"
|
||||
@@ -315,19 +316,6 @@ prompt = (
|
||||
"landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the "
|
||||
"solitude and beauty of a winter drive through a mountainous region."
|
||||
)
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
|
||||
cond_video = load_video(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
|
||||
@@ -343,7 +331,7 @@ frame_rate = 24.0
|
||||
video, audio = pipe(
|
||||
conditions=conditions,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=121,
|
||||
@@ -366,6 +354,154 @@ encode_video(
|
||||
|
||||
Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.
|
||||
|
||||
## Multimodal Guidance
|
||||
|
||||
LTX-2.X pipelines support multimodal guidance. It is composed of three terms, all using a CFG-style update rule:
|
||||
|
||||
1. Classifier-Free Guidance (CFG): standard [CFG](https://huggingface.co/papers/2207.12598) where the perturbed ("weaker") output is generated using the negative prompt.
|
||||
2. Spatio-Temporal Guidance (STG): [STG](https://huggingface.co/papers/2411.18664) moves away from a perturbed output created from short-cutting self-attention operations and substitutes in the attention values instead. The idea is that this creates sharper videos and better spatiotemporal consistency.
|
||||
3. Modality Isolation Guidance: moves away from a perturbed output created from disabling cross-modality (audio-to-video and video-to-audio) cross attention. This guidance is more specific to [LTX-2.X](https://huggingface.co/papers/2601.03233) models, with the idea that this produces better consistency between the generated audio and video.
|
||||
|
||||
These are controlled by the `guidance_scale`, `stg_scale`, and `modality_scale` arguments and can be set separately for video and audio. Additionally, for STG the transformer block indices where self-attention is skipped needs to be specified via the `spatio_temporal_guidance_blocks` argument. The LTX-2.X pipelines also support [guidance rescaling](https://huggingface.co/papers/2305.08891) to help reduce over-exposure, which can be a problem when the guidance scales are set to high values.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import LTX2ImageToVideoPipeline
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
|
||||
from diffusers.utils import load_image
|
||||
|
||||
device = "cuda"
|
||||
width = 768
|
||||
height = 512
|
||||
random_seed = 42
|
||||
frame_rate = 24.0
|
||||
generator = torch.Generator(device).manual_seed(random_seed)
|
||||
model_path = "dg845/LTX-2.3-Diffusers"
|
||||
|
||||
pipe = LTX2ImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_sequential_cpu_offload(device=device)
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
|
||||
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
|
||||
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
|
||||
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
|
||||
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
|
||||
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
|
||||
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
|
||||
"breath-taking, movie-like shot."
|
||||
)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
|
||||
)
|
||||
|
||||
video, audio = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=121,
|
||||
frame_rate=frame_rate,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=3.0, # Recommended LTX-2.3 guidance parameters
|
||||
stg_scale=1.0, # Note that 0.0 (not 1.0) means that STG is disabled (all other guidance is disabled at 1.0)
|
||||
modality_scale=3.0,
|
||||
guidance_rescale=0.7,
|
||||
audio_guidance_scale=7.0, # Note that a higher CFG guidance scale is recommended for audio
|
||||
audio_stg_scale=1.0,
|
||||
audio_modality_scale=3.0,
|
||||
audio_guidance_rescale=0.7,
|
||||
spatio_temporal_guidance_blocks=[28],
|
||||
use_cross_timestep=True,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
fps=frame_rate,
|
||||
audio=audio[0].float().cpu(),
|
||||
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
|
||||
output_path="ltx2_3_i2v_stage_1.mp4",
|
||||
)
|
||||
```
|
||||
|
||||
## Prompt Enhancement
|
||||
|
||||
The LTX-2.X models are sensitive to prompting style. Refer to the [official prompting guide](https://ltx.io/model/model-blog/prompting-guide-for-ltx-2) for recommendations on how to write a good prompt. Using prompt enhancement, where the supplied prompts are enhanced using the pipeline's text encoder (by default a [Gemma 3](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized) model) given a system prompt, can also improve sample quality. The optional `processor` pipeline component needs to be present to use prompt enhancement. Enable prompt enhancement by supplying a `system_prompt` argument:
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import Gemma3Processor
|
||||
from diffusers import LTX2Pipeline
|
||||
from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT, T2V_DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
device = "cuda"
|
||||
width = 768
|
||||
height = 512
|
||||
random_seed = 42
|
||||
frame_rate = 24.0
|
||||
generator = torch.Generator(device).manual_seed(random_seed)
|
||||
model_path = "dg845/LTX-2.3-Diffusers"
|
||||
|
||||
pipe = LTX2Pipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload(device=device)
|
||||
pipe.vae.enable_tiling()
|
||||
if getattr(pipe, "processor", None) is None:
|
||||
processor = Gemma3Processor.from_pretrained("google/gemma-3-12b-it-qat-q4_0-unquantized")
|
||||
pipe.processor = processor
|
||||
|
||||
prompt = (
|
||||
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
|
||||
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
|
||||
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
|
||||
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
|
||||
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
|
||||
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
|
||||
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
|
||||
"breath-taking, movie-like shot."
|
||||
)
|
||||
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
||||
width=width,
|
||||
height=height,
|
||||
num_frames=121,
|
||||
frame_rate=frame_rate,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=3.0,
|
||||
stg_scale=1.0,
|
||||
modality_scale=3.0,
|
||||
guidance_rescale=0.7,
|
||||
audio_guidance_scale=7.0,
|
||||
audio_stg_scale=1.0,
|
||||
audio_modality_scale=3.0,
|
||||
audio_guidance_rescale=0.7,
|
||||
spatio_temporal_guidance_blocks=[28],
|
||||
use_cross_timestep=True,
|
||||
system_prompt=T2V_DEFAULT_SYSTEM_PROMPT,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
fps=frame_rate,
|
||||
audio=audio[0].float().cpu(),
|
||||
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
|
||||
output_path="ltx2_3_t2v_stage_1.mp4",
|
||||
)
|
||||
```
|
||||
|
||||
## LTX2Pipeline
|
||||
|
||||
[[autodoc]] LTX2Pipeline
|
||||
|
||||
@@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be
|
||||
|
||||
The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
|
||||
|
||||
## Kernels
|
||||
|
||||
[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
|
||||
|
||||
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.
|
||||
|
||||
> [!TIP]
|
||||
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.
|
||||
|
||||
For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
## Dynamic quantization
|
||||
|
||||
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
|
||||
|
||||
@@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig("int8wo")}
|
||||
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128, version=2))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
@@ -91,18 +74,15 @@ Weight-only quantization stores the model weights in a specific low-bit data typ
|
||||
|
||||
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
|
||||
|
||||
The quantization methods supported are as follows:
|
||||
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods. An exhaustive list of configuration options are available [here](https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows).
|
||||
|
||||
| **Category** | **Full Function Names** | **Shorthands** |
|
||||
|--------------|-------------------------|----------------|
|
||||
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
|
||||
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
|
||||
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
|
||||
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
|
||||
Some example popular quantization configurations are as follows:
|
||||
|
||||
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
|
||||
|
||||
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
| **Category** | **Configuration Classes** |
|
||||
|---|---|
|
||||
| **Integer quantization** | [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int4WeightOnlyConfig.html), [`Int8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8WeightOnlyConfig.html), [`Int8DynamicActivationInt8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8DynamicActivationInt8WeightConfig.html) |
|
||||
| **Floating point 8-bit quantization** | [`Float8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8WeightOnlyConfig.html), [`Float8DynamicActivationFloat8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) |
|
||||
| **Unsigned integer quantization** | [`IntxWeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.IntxWeightOnlyConfig.html) |
|
||||
|
||||
## Serializing and Deserializing quantized models
|
||||
|
||||
@@ -111,8 +91,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, TorchAoConfig
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
@@ -137,18 +118,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
|
||||
from torchao.quantization import IntxWeightOnlyConfig
|
||||
|
||||
# Serialize the model
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=TorchAoConfig("uint4wo"),
|
||||
quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
|
||||
|
||||
378
docs/source/en/training/nemo_automodel.md
Normal file
378
docs/source/en/training/nemo_automodel.md
Normal file
@@ -0,0 +1,378 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# NeMo Automodel
|
||||
|
||||
[NeMo Automodel](https://github.com/NVIDIA-NeMo/Automodel) is a PyTorch DTensor-native training library from NVIDIA for fine-tuning and pretraining diffusion models at scale. It is Hugging Face native — train any Diffusers-format model from the Hub with no checkpoint conversion. The same YAML recipe and hackable training script runs on any scale from 1 GPU to hundreds of nodes, with [FSDP2](https://pytorch.org/docs/stable/fsdp.html) distributed training, multiresolution bucketed dataloading, and pre-encoded latent space training for maximum GPU utilization. It uses [flow matching](https://huggingface.co/papers/2210.02747) for training and is fully open source (Apache 2.0), NVIDIA-supported, and actively maintained.
|
||||
|
||||
NeMo Automodel integrates directly with Diffusers. It loads pretrained models from the Hugging Face Hub using Diffusers model classes and generates outputs with the [`DiffusionPipeline`].
|
||||
|
||||
The typical workflow is to install NeMo Automodel (pip or Docker), prepare your data by encoding it into `.meta` files, configure a YAML recipe, launch training with `torchrun`, and run inference with the resulting checkpoint.
|
||||
|
||||
## Supported models
|
||||
|
||||
| Model | Hugging Face ID | Task | Parameters | Use case |
|
||||
|-------|----------------|------|------------|----------|
|
||||
| Wan 2.1 T2V 1.3B | [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) | Text-to-Video | 1.3B | video generation on limited hardware (fits on single 40GB A100) |
|
||||
| FLUX.1-dev | [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) | Text-to-Image | 12B | high-quality image generation |
|
||||
| HunyuanVideo 1.5 | [hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v) | Text-to-Video | 13B | high-quality video generation |
|
||||
|
||||
## Installation
|
||||
|
||||
### Hardware requirements
|
||||
|
||||
| Component | Minimum | Recommended |
|
||||
|-----------|---------|-------------|
|
||||
| GPU | A100 40GB | A100 80GB / H100 |
|
||||
| GPUs | 4 | 8+ |
|
||||
| RAM | 128 GB | 256 GB+ |
|
||||
| Storage | 500 GB SSD | 2 TB NVMe |
|
||||
|
||||
Install NeMo Automodel with pip. For the full set of installation methods (including from source), see the [NeMo Automodel installation guide](https://docs.nvidia.com/nemo/automodel/latest/guides/installation.html).
|
||||
|
||||
```bash
|
||||
pip3 install nemo-automodel
|
||||
```
|
||||
|
||||
Alternatively, use the pre-built Docker container which includes all dependencies.
|
||||
|
||||
```bash
|
||||
docker pull nvcr.io/nvidia/nemo-automodel:26.02.00
|
||||
docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/nemo-automodel:26.02.00
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Checkpoints are lost when the container exits unless you bind-mount the checkpoint directory to the host. For example, add `-v /host/path/checkpoints:/workspace/checkpoints` to the `docker run` command.
|
||||
|
||||
|
||||
## Data preparation
|
||||
|
||||
NeMo Automodel trains diffusion models in latent space. Raw images or videos must be preprocessed into `.meta` files containing VAE latents and text embeddings before training. This avoids re-encoding on every training step.
|
||||
|
||||
Use the built-in preprocessing tool to encode your data. The tool automatically distributes work across all available GPUs.
|
||||
|
||||
<hfoptions id="data-prep">
|
||||
<hfoption id="video preprocessing">
|
||||
|
||||
The video preprocessing command is the same for both Wan 2.1 and HunyuanVideo, but the flags differ. Wan 2.1 uses `--processor wan` with `--resolution_preset` and `--caption_format sidecar`, while HunyuanVideo uses `--processor hunyuan` with `--target_frames` to set the frame count and `--caption_format meta_json`.
|
||||
|
||||
**Wan 2.1:**
|
||||
|
||||
```bash
|
||||
python -m tools.diffusion.preprocessing_multiprocess video \
|
||||
--video_dir /data/videos \
|
||||
--output_dir /cache \
|
||||
--processor wan \
|
||||
--resolution_preset 512p \
|
||||
--caption_format sidecar
|
||||
```
|
||||
|
||||
**HunyuanVideo:**
|
||||
|
||||
```bash
|
||||
python -m tools.diffusion.preprocessing_multiprocess video \
|
||||
--video_dir /data/videos \
|
||||
--output_dir /cache \
|
||||
--processor hunyuan \
|
||||
--target_frames 121 \
|
||||
--caption_format meta_json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="image preprocessing">
|
||||
|
||||
```bash
|
||||
python -m tools.diffusion.preprocessing_multiprocess image \
|
||||
--image_dir /data/images \
|
||||
--output_dir /cache \
|
||||
--processor flux \
|
||||
--resolution_preset 512p
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Output format
|
||||
|
||||
Preprocessing produces a cache directory organized by resolution bucket. NeMo Automodel supports multi-resolution training through bucketed sampling. Samples are grouped by spatial resolution so each batch contains same-size samples, avoiding padding waste.
|
||||
|
||||
```
|
||||
/cache/
|
||||
├── 512x512/ # Resolution bucket
|
||||
│ ├── <hash1>.meta # VAE latents + text embeddings
|
||||
│ ├── <hash2>.meta
|
||||
│ └── ...
|
||||
├── 832x480/ # Another resolution bucket
|
||||
│ └── ...
|
||||
├── metadata.json # Global config (processor, model, total items)
|
||||
└── metadata_shard_0000.json # Per-sample metadata (paths, resolutions, captions)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> See the [Diffusion Dataset Preparation](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/dataset.html) guide for caption formats, input data requirements, and all available preprocessing arguments.
|
||||
|
||||
## Training configuration
|
||||
|
||||
Fine-tuning is driven by two components:
|
||||
|
||||
1. A recipe script ([finetune.py](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/diffusion/finetune/finetune.py)) is a Python entry point that contains the training loop: loading the model, building the dataloader, running forward/backward passes, computing the flow matching loss, checkpointing, and logging.
|
||||
2. A YAML configuration file specifies all settings the recipe uses: which model to fine-tune, where the data lives, optimizer hyperparameters, parallelism strategy, and more. You customize training by editing this file rather than modifying code, allowing you to scale from 1 to hundreds of GPUs.
|
||||
|
||||
Any YAML field can also be overridden from the CLI:
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node=8 examples/diffusion/finetune/finetune.py \
|
||||
-c examples/diffusion/finetune/wan2_1_t2v_flow.yaml \
|
||||
--optim.learning_rate 1e-5 \
|
||||
--step_scheduler.num_epochs 50
|
||||
```
|
||||
|
||||
Below is the annotated config for fine-tuning Wan 2.1 T2V 1.3B, with each section explained.
|
||||
|
||||
```yaml
|
||||
seed: 42
|
||||
|
||||
# ── Experiment tracking (optional) ──────────────────────────────────────────
|
||||
# Weights & Biases integration for logging metrics, losses, and learning rates.
|
||||
# Set mode: "disabled" to turn off.
|
||||
wandb:
|
||||
project: wan-t2v-flow-matching
|
||||
mode: online
|
||||
name: wan2_1_t2v_fm
|
||||
|
||||
# ── Model ───────────────────────────────────────────────────────────────────
|
||||
# pretrained_model_name_or_path: any Hugging Face model ID or local path.
|
||||
# mode: "finetune" loads pretrained weights; "pretrain" trains from scratch.
|
||||
model:
|
||||
pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
mode: finetune
|
||||
|
||||
# ── Training schedule ───────────────────────────────────────────────────────
|
||||
# global_batch_size: effective batch across all GPUs.
|
||||
# Gradient accumulation is computed automatically: global / (local × num_gpus).
|
||||
step_scheduler:
|
||||
global_batch_size: 8
|
||||
local_batch_size: 1
|
||||
ckpt_every_steps: 1000 # Save a checkpoint every N steps
|
||||
num_epochs: 100
|
||||
log_every: 2 # Log metrics every N steps
|
||||
|
||||
# ── Data ────────────────────────────────────────────────────────────────────
|
||||
# _target_: the dataloader factory function.
|
||||
# Use build_video_multiresolution_dataloader for video models (Wan, HunyuanVideo).
|
||||
# Use build_text_to_image_multiresolution_dataloader for image models (FLUX).
|
||||
# model_type: "wan" or "hunyuan" (selects the correct latent format).
|
||||
# base_resolution: target resolution for multiresolution bucketing.
|
||||
data:
|
||||
dataloader:
|
||||
_target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader
|
||||
cache_dir: PATH_TO_YOUR_DATA
|
||||
model_type: wan
|
||||
base_resolution: [512, 512]
|
||||
dynamic_batch_size: false # When true, adjusts batch per bucket to maintain constant memory
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
num_workers: 0
|
||||
|
||||
# ── Optimizer ───────────────────────────────────────────────────────────────
|
||||
# learning_rate: 5e-6 is a good starting point for fine-tuning.
|
||||
# Adjust weight_decay and betas for your dataset.
|
||||
optim:
|
||||
learning_rate: 5e-6
|
||||
optimizer:
|
||||
weight_decay: 0.01
|
||||
betas: [0.9, 0.999]
|
||||
|
||||
# ── Learning rate scheduler ─────────────────────────────────────────────────
|
||||
# Supports cosine, linear, and constant schedules.
|
||||
lr_scheduler:
|
||||
lr_decay_style: cosine
|
||||
lr_warmup_steps: 0
|
||||
min_lr: 1e-6
|
||||
|
||||
# ── Flow matching ───────────────────────────────────────────────────────────
|
||||
# adapter_type: model-specific adapter — must match the model:
|
||||
# "simple" for Wan 2.1, "flux" for FLUX.1-dev, "hunyuan" for HunyuanVideo.
|
||||
# timestep_sampling: "uniform" for Wan, "logit_normal" for FLUX and HunyuanVideo.
|
||||
# flow_shift: shifts the flow schedule (model-dependent).
|
||||
# i2v_prob: probability of image-to-video conditioning during training (video models).
|
||||
flow_matching:
|
||||
adapter_type: "simple"
|
||||
adapter_kwargs: {}
|
||||
timestep_sampling: "uniform"
|
||||
logit_mean: 0.0
|
||||
logit_std: 1.0
|
||||
flow_shift: 3.0
|
||||
num_train_timesteps: 1000
|
||||
i2v_prob: 0.3
|
||||
use_loss_weighting: true
|
||||
|
||||
# ── FSDP2 distributed training ──────────────────────────────────────────────
|
||||
# dp_size: number of GPUs for data parallelism (typically = total GPUs on node).
|
||||
# tp_size, cp_size, pp_size: tensor, context, and pipeline parallelism.
|
||||
# For most fine-tuning, dp_size is all you need; leave others at 1.
|
||||
fsdp:
|
||||
tp_size: 1
|
||||
cp_size: 1
|
||||
pp_size: 1
|
||||
dp_replicate_size: 1
|
||||
dp_size: 8
|
||||
|
||||
# ── Checkpointing ──────────────────────────────────────────────────────────
|
||||
# checkpoint_dir: where to save checkpoints (use a persistent path with Docker).
|
||||
# restore_from: path to resume training from a previous checkpoint.
|
||||
checkpoint:
|
||||
enabled: true
|
||||
checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
|
||||
model_save_format: torch_save
|
||||
save_consolidated: false
|
||||
restore_from: null
|
||||
```
|
||||
|
||||
### Config field reference
|
||||
|
||||
The table below lists the minimal required configs. See the [NeMo Automodel examples](https://github.com/NVIDIA-NeMo/Automodel/tree/main/examples/diffusion/finetune) have full example configs for all models.
|
||||
|
||||
| Section | Required? | What to Change |
|
||||
|---------|-----------|----------------|
|
||||
| `model` | Yes | Set `pretrained_model_name_or_path` to the Hugging Face model ID. Set `mode: finetune` or `mode: pretrain`. |
|
||||
| `step_scheduler` | Yes | `global_batch_size` is the effective batch size across all GPUs. `ckpt_every_steps` controls checkpoint frequency. Gradient accumulation is computed automatically. |
|
||||
| `data` | Yes | Set `cache_dir` to the path containing your preprocessed `.meta` files. Change `_target_` and `model_type` for different models. |
|
||||
| `optim` | Yes | `learning_rate: 5e-6` is a good default for fine-tuning. Adjust for your dataset and model. |
|
||||
| `lr_scheduler` | Yes | Choose `cosine`, `linear`, or `constant` for `lr_decay_style`. Set `lr_warmup_steps` for gradual warmup. |
|
||||
| `flow_matching` | Yes | `adapter_type` must match the model (`simple` for Wan, `flux` for FLUX, `hunyuan` for HunyuanVideo). See model-specific configs for `adapter_kwargs`. |
|
||||
| `fsdp` | Yes | Set `dp_size` to the number of GPUs. For multi-node, set to total GPUs across all nodes. |
|
||||
| `checkpoint` | Recommended | Set `checkpoint_dir` to a persistent path, especially in Docker. Use `restore_from` to resume from a previous checkpoint. |
|
||||
| `wandb` | Optional | Configure to enable Weights & Biases experiment tracking. Set `mode: disabled` to turn off. |
|
||||
|
||||
|
||||
|
||||
## Launch training
|
||||
|
||||
<hfoptions id="launch-training">
|
||||
<hfoption id="single-node">
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node=8 \
|
||||
examples/diffusion/finetune/finetune.py \
|
||||
-c examples/diffusion/finetune/wan2_1_t2v_flow.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="multi-node">
|
||||
|
||||
Run the following on each node, setting `NODE_RANK` accordingly:
|
||||
|
||||
```bash
|
||||
export MASTER_ADDR=node0.hostname
|
||||
export MASTER_PORT=29500
|
||||
export NODE_RANK=0 # 0 on master, 1 on second node, etc.
|
||||
|
||||
torchrun \
|
||||
--nnodes=2 \
|
||||
--nproc-per-node=8 \
|
||||
--node_rank=${NODE_RANK} \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
|
||||
examples/diffusion/finetune/finetune.py \
|
||||
-c examples/diffusion/finetune/wan2_1_t2v_flow_multinode.yaml
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For multi-node training, set `fsdp.dp_size` in the YAML to the **total** number of GPUs across all nodes (e.g., 16 for 2 nodes with 8 GPUs each).
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Generation
|
||||
|
||||
After training, generate videos or images from text prompts using the fine-tuned checkpoint.
|
||||
|
||||
<hfoptions id="generation">
|
||||
<hfoption id="Wan 2.1">
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_wan.yaml
|
||||
```
|
||||
|
||||
With a fine-tuned checkpoint:
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_wan.yaml \
|
||||
--model.checkpoint ./checkpoints/step_1000 \
|
||||
--inference.prompts '["A dog running on a beach"]'
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="FLUX">
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_flux.yaml
|
||||
```
|
||||
|
||||
With a fine-tuned checkpoint:
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_flux.yaml \
|
||||
--model.checkpoint ./checkpoints/step_1000 \
|
||||
--inference.prompts '["A dog running on a beach"]'
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="HunyuanVideo">
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_hunyuan.yaml
|
||||
```
|
||||
|
||||
With a fine-tuned checkpoint:
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_hunyuan.yaml \
|
||||
--model.checkpoint ./checkpoints/step_1000 \
|
||||
--inference.prompts '["A dog running on a beach"]'
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Diffusers integration
|
||||
|
||||
NeMo Automodel is built on top of Diffusers and uses it as the backbone for model loading and inference. It loads models directly from the Hugging Face Hub using Diffusers model classes such as [`WanTransformer3DModel`], [`FluxTransformer2DModel`], and [`HunyuanVideoTransformer3DModel`], and generates outputs via Diffusers pipelines like [`WanPipeline`] and [`FluxPipeline`].
|
||||
|
||||
This integration provides several benefits for Diffusers users:
|
||||
|
||||
- **No checkpoint conversion**: pretrained weights from the Hub work out of the box. Point `pretrained_model_name_or_path` at any Diffusers-format model ID and start training immediately.
|
||||
- **Day-0 model support**: when a new diffusion model is added to Diffusers and uploaded to the Hub, it can be fine-tuned with NeMo Automodel without waiting for a dedicated training script.
|
||||
- **Pipeline-compatible outputs**: fine-tuned checkpoints are saved in a format that can be loaded directly back into Diffusers pipelines for inference, sharing on the Hub, or further optimization with tools like quantization and compilation.
|
||||
- **Scalable training for Diffusers models**: NeMo Automodel adds distributed training capabilities (FSDP2, multi-node, multiresolution bucketing) that go beyond what the built-in Diffusers training scripts provide, while keeping the same model and pipeline interfaces.
|
||||
- **Shared ecosystem**: any model, LoRA adapter, or pipeline component from the Diffusers ecosystem remains compatible throughout the training and inference workflow.
|
||||
|
||||
## NVIDIA Team
|
||||
|
||||
- Pranav Prashant Thombre, pthombre@nvidia.com
|
||||
- Linnan Wang, linnanw@nvidia.com
|
||||
- Alexandros Koumparoulis, akoumparouli@nvidia.com
|
||||
|
||||
## Resources
|
||||
|
||||
- [NeMo Automodel GitHub](https://github.com/NVIDIA-NeMo/Automodel)
|
||||
- [Diffusion Fine-Tuning Guide](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/finetune.html)
|
||||
- [Diffusion Dataset Preparation](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/dataset.html)
|
||||
- [Diffusion Model Coverage](https://docs.nvidia.com/nemo/automodel/latest/model-coverage/diffusion.html)
|
||||
- [NeMo Automodel for Transformers (LLM/VLM fine-tuning)](https://huggingface.co/docs/transformers/en/community_integrations/nemo_automodel_finetuning)
|
||||
@@ -347,16 +347,17 @@ When LoRA was first adapted from language models to diffusion models, it was app
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out,norm_out.linear,norm_out.proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
> [!NOTE]
|
||||
In FLUX2, the q, k, and v projections are fused into a single linear layer named attn.to_qkv_mlp_proj within the single transformer block. Also, the attention output is just attn.to_out, not attn.to_out.0 — it’s no longer a ModuleList like in transformer block.
|
||||
|
||||
## Training Image-to-Image
|
||||
|
||||
|
||||
@@ -1256,7 +1256,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1206,7 +1206,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1249,7 +1249,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1200,7 +1200,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1105,7 +1105,7 @@ def main(args):
|
||||
|
||||
# text encoding.
|
||||
captions = batch["captions"]
|
||||
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
|
||||
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
|
||||
captions, prompt_2=None
|
||||
|
||||
@@ -1251,7 +1251,7 @@ def main(args):
|
||||
|
||||
# text encoding.
|
||||
captions = batch["captions"]
|
||||
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
|
||||
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
|
||||
captions, prompt_2=None
|
||||
|
||||
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
with torch.enable_grad():
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
|
||||
@@ -1,6 +1,155 @@
|
||||
# Copyright 2026 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Pre-trained sigma values for distilled model are taken from
|
||||
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
|
||||
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
|
||||
|
||||
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
|
||||
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
|
||||
|
||||
|
||||
# Default negative prompt from
|
||||
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py#L131-L143
|
||||
DEFAULT_NEGATIVE_PROMPT = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
|
||||
|
||||
# System prompts for prompt enhancement
|
||||
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt#L1
|
||||
# Disable line-too-long rule in ruff to keep the prompts exactly the same (e.g. in terms of newlines)
|
||||
# Supported in ruff>=0.15.0
|
||||
# ruff: disable[E501]
|
||||
T2V_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed
|
||||
video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||
|
||||
#### Guidelines
|
||||
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions,
|
||||
actions, camera movement, audio).
|
||||
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
|
||||
movements.
|
||||
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested).
|
||||
Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g.,
|
||||
"ambient sound is present").
|
||||
- Speech (only when requested):
|
||||
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with
|
||||
voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||
- Specify language if not English and accent if relevant.
|
||||
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if
|
||||
unspecified. Omit if unclear.
|
||||
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||
|
||||
#### Important notes:
|
||||
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is
|
||||
requested.
|
||||
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological
|
||||
scene description.
|
||||
- Format: DO NOT start your response with special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits
|
||||
or introduce new elements. Add/enhance audio descriptions if missing.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single continuous paragraph in natural language (English).
|
||||
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video
|
||||
generation.
|
||||
|
||||
#### Example Input: "A woman at a coffee shop talking on the phone" Output: Style: realistic with cinematic lighting.
|
||||
In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the
|
||||
window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone
|
||||
to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle
|
||||
clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a
|
||||
soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd
|
||||
love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her
|
||||
chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully,
|
||||
lowering the phone.
|
||||
"""
|
||||
# ruff: enable[E501]
|
||||
|
||||
# ruff: disable[E501]
|
||||
I2V_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and
|
||||
user Raw Input Prompt, generate a prompt to guide video generation from that image.
|
||||
|
||||
#### Guidelines:
|
||||
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
|
||||
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in
|
||||
conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image
|
||||
to user's scene).
|
||||
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause
|
||||
scene cuts.
|
||||
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
|
||||
movements.
|
||||
- Chronological flow: Use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio
|
||||
intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when
|
||||
requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
|
||||
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The
|
||||
tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation
|
||||
mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should
|
||||
include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just
|
||||
saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room
|
||||
underscores his animated speech.")
|
||||
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
|
||||
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
|
||||
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
|
||||
|
||||
#### Important notes:
|
||||
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion
|
||||
only if specified in the input.
|
||||
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
|
||||
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style
|
||||
(optional) and chronological scene description.
|
||||
- Format: Never start output with punctuation marks or special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio
|
||||
descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
#### Example output: Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a
|
||||
cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the
|
||||
counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine
|
||||
hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
|
||||
"""
|
||||
# ruff: enable[E501]
|
||||
|
||||
@@ -23,20 +23,17 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -53,16 +50,6 @@ class QuantizationMethod(str, Enum):
|
||||
MODELOPT = "modelopt"
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization.quant_primitives import MappingType
|
||||
|
||||
class TorchAoJSONEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, MappingType):
|
||||
return obj.name
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationConfigMixin:
|
||||
"""
|
||||
@@ -446,49 +433,21 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""This is a config class for torchao quantization/sparsity techniques.
|
||||
|
||||
Args:
|
||||
quant_type (`str` | AOBaseConfig):
|
||||
The type of quantization we want to use, currently supporting:
|
||||
- **Integer quantization:**
|
||||
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
|
||||
`int8_weight_only`, `int8_dynamic_activation_int8_weight`
|
||||
- Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
|
||||
|
||||
- **Floating point 8-bit quantization:**
|
||||
- Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
|
||||
`float8_static_activation_float8_weight`
|
||||
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
|
||||
`float8_e4m3_tensor`, `float8_e4m3_row`,
|
||||
|
||||
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
|
||||
- Full function names: `fpx_weight_only`
|
||||
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
|
||||
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
|
||||
be satisfied for a given shorthand notation.
|
||||
|
||||
- **Unsigned Integer quantization:**
|
||||
- Full function names: `uintx_weight_only`
|
||||
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
|
||||
- An AOBaseConfig instance: for more advanced configuration options.
|
||||
quant_type (`AOBaseConfig`):
|
||||
An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao
|
||||
documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for
|
||||
available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`,
|
||||
`Float8DynamicActivationFloat8WeightConfig`, etc.).
|
||||
modules_to_not_convert (`list[str]`, *optional*, default to `None`):
|
||||
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
|
||||
modules left in their original precision.
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
|
||||
supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
|
||||
documentation of arguments can be found in
|
||||
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
|
||||
|
||||
Example:
|
||||
```python
|
||||
from diffusers import FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
# AOBaseConfig-based configuration
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
|
||||
# String-based config
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/Flux.1-Dev",
|
||||
subfolder="transformer",
|
||||
@@ -500,7 +459,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_type: str | "AOBaseConfig", # noqa: F821
|
||||
quant_type: "AOBaseConfig", # noqa: F821
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -508,102 +467,39 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.quant_type = quant_type
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
# When we load from serialized config, "quant_type_kwargs" will be the key
|
||||
if "quant_type_kwargs" in kwargs:
|
||||
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
|
||||
else:
|
||||
self.quant_type_kwargs = kwargs
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if not isinstance(self.quant_type, str):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError(
|
||||
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
|
||||
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
|
||||
)
|
||||
if is_torchao_version("<", "0.15.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
|
||||
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
if not isinstance(self.quant_type, AOBaseConfig):
|
||||
raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
|
||||
|
||||
elif isinstance(self.quant_type, str):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
|
||||
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
||||
is_floatx_quant_type = self.quant_type.startswith("fp")
|
||||
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
|
||||
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
||||
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
||||
)
|
||||
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
|
||||
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
|
||||
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
|
||||
)
|
||||
|
||||
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
|
||||
signature = inspect.signature(method)
|
||||
all_kwargs = {
|
||||
param.name
|
||||
for param in signature.parameters.values()
|
||||
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
|
||||
}
|
||||
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
|
||||
|
||||
if len(unsupported_kwargs) > 0:
|
||||
raise ValueError(
|
||||
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
|
||||
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
|
||||
)
|
||||
if not isinstance(self.quant_type, AOBaseConfig):
|
||||
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert configuration to a dictionary."""
|
||||
d = super().to_dict()
|
||||
|
||||
if isinstance(self.quant_type, str):
|
||||
# Handle layout serialization if present
|
||||
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
|
||||
if is_dataclass(d["quant_type_kwargs"]["layout"]):
|
||||
d["quant_type_kwargs"]["layout"] = [
|
||||
d["quant_type_kwargs"]["layout"].__class__.__name__,
|
||||
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
|
||||
]
|
||||
if isinstance(d["quant_type_kwargs"]["layout"], list):
|
||||
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
|
||||
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
|
||||
else:
|
||||
raise ValueError("layout must be a list")
|
||||
else:
|
||||
# Handle AOBaseConfig serialization
|
||||
from torchao.core.config import config_to_dict
|
||||
# Handle AOBaseConfig serialization
|
||||
from torchao.core.config import config_to_dict
|
||||
|
||||
# For now we assume there is 1 config per Transformer, however in the future
|
||||
# We may want to support a config per fqn.
|
||||
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
|
||||
# For now we assume there is 1 config per Transformer, however in the future
|
||||
# we may want to support a config per fqn.
|
||||
# See: https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.quantize_.html
|
||||
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
|
||||
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
if not is_torchao_version(">", "0.9.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
|
||||
if not is_torchao_version(">=", "0.15.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
if isinstance(quant_type, str):
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
# Check if we only have one key which is "default"
|
||||
# In the future we may update this
|
||||
assert len(quant_type) == 1 and "default" in quant_type, (
|
||||
@@ -618,210 +514,13 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
|
||||
return cls(quant_type=quant_type, **config_dict)
|
||||
|
||||
@classmethod
|
||||
def _get_torchao_quant_type_to_method(cls):
|
||||
r"""
|
||||
Returns supported torchao quantization types with all commonly used notations.
|
||||
"""
|
||||
|
||||
if is_torchao_available():
|
||||
# TODO(aryan): Support sparsify
|
||||
from torchao.quantization import (
|
||||
float8_dynamic_activation_float8_weight,
|
||||
float8_static_activation_float8_weight,
|
||||
float8_weight_only,
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int4_weight,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
int8_weight_only,
|
||||
uintx_weight_only,
|
||||
)
|
||||
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
from torchao.quantization import fpx_weight_only
|
||||
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
|
||||
from torchao.quantization.observer import PerRow, PerTensor
|
||||
|
||||
def generate_float8dq_types(dtype: torch.dtype):
|
||||
name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
|
||||
types = {}
|
||||
|
||||
for granularity_cls in [PerTensor, PerRow]:
|
||||
# Note: Activation and Weights cannot have different granularities
|
||||
granularity_name = "tensor" if granularity_cls is PerTensor else "row"
|
||||
types[f"float8dq_{name}_{granularity_name}"] = partial(
|
||||
float8_dynamic_activation_float8_weight,
|
||||
activation_dtype=dtype,
|
||||
weight_dtype=dtype,
|
||||
granularity=(granularity_cls(), granularity_cls()),
|
||||
)
|
||||
|
||||
return types
|
||||
|
||||
def generate_fpx_quantization_types(bits: int):
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
types = {}
|
||||
|
||||
for ebits in range(1, bits):
|
||||
mbits = bits - ebits - 1
|
||||
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
|
||||
|
||||
non_sign_bits = bits - 1
|
||||
default_ebits = (non_sign_bits + 1) // 2
|
||||
default_mbits = non_sign_bits - default_ebits
|
||||
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
|
||||
|
||||
return types
|
||||
else:
|
||||
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
|
||||
|
||||
INT4_QUANTIZATION_TYPES = {
|
||||
# int4 weight + bfloat16/float16 activation
|
||||
"int4wo": int4_weight_only,
|
||||
"int4_weight_only": int4_weight_only,
|
||||
# int4 weight + int8 activation
|
||||
"int4dq": int8_dynamic_activation_int4_weight,
|
||||
"int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
|
||||
}
|
||||
|
||||
INT8_QUANTIZATION_TYPES = {
|
||||
# int8 weight + bfloat16/float16 activation
|
||||
"int8wo": int8_weight_only,
|
||||
"int8_weight_only": int8_weight_only,
|
||||
# int8 weight + int8 activation
|
||||
"int8dq": int8_dynamic_activation_int8_weight,
|
||||
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
|
||||
}
|
||||
|
||||
# TODO(aryan): handle torch 2.2/2.3
|
||||
FLOATX_QUANTIZATION_TYPES = {
|
||||
# float8_e5m2 weight + bfloat16/float16 activation
|
||||
"float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
|
||||
"float8_weight_only": float8_weight_only,
|
||||
"float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
|
||||
# float8_e4m3 weight + bfloat16/float16 activation
|
||||
"float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
|
||||
# float8_e5m2 weight + float8 activation (dynamic)
|
||||
"float8dq": float8_dynamic_activation_float8_weight,
|
||||
"float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
|
||||
# ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
|
||||
# However, changing activation_dtype=torch.float8_e4m3 might work here =====
|
||||
# "float8dq_e5m2": partial(
|
||||
# float8_dynamic_activation_float8_weight,
|
||||
# activation_dtype=torch.float8_e5m2,
|
||||
# weight_dtype=torch.float8_e5m2,
|
||||
# ),
|
||||
# **generate_float8dq_types(torch.float8_e5m2),
|
||||
# ===== =====
|
||||
# float8_e4m3 weight + float8 activation (dynamic)
|
||||
"float8dq_e4m3": partial(
|
||||
float8_dynamic_activation_float8_weight,
|
||||
activation_dtype=torch.float8_e4m3fn,
|
||||
weight_dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
**generate_float8dq_types(torch.float8_e4m3fn),
|
||||
# float8 weight + float8 activation (static)
|
||||
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
|
||||
}
|
||||
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
|
||||
|
||||
UINTX_QUANTIZATION_DTYPES = {
|
||||
"uintx_weight_only": uintx_weight_only,
|
||||
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
|
||||
"uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
|
||||
"uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
|
||||
"uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
|
||||
"uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
|
||||
"uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
|
||||
"uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
|
||||
# "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
|
||||
}
|
||||
|
||||
QUANTIZATION_TYPES = {}
|
||||
QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
|
||||
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
|
||||
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
|
||||
|
||||
if cls._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
|
||||
|
||||
return QUANTIZATION_TYPES
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
return minor >= 9
|
||||
return major >= 9
|
||||
elif torch.xpu.is_available():
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
"""Create the appropriate quantization method based on configuration."""
|
||||
if not isinstance(self.quant_type, str):
|
||||
return self.quant_type
|
||||
else:
|
||||
methods = self._get_torchao_quant_type_to_method()
|
||||
quant_type_kwargs = self.quant_type_kwargs.copy()
|
||||
if (
|
||||
not torch.cuda.is_available()
|
||||
and is_torchao_available()
|
||||
and self.quant_type == "int4_weight_only"
|
||||
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
|
||||
and quant_type_kwargs.get("layout", None) is None
|
||||
):
|
||||
if torch.xpu.is_available():
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
|
||||
"0.11.0"
|
||||
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
|
||||
from torchao.dtypes import Int4XPULayout
|
||||
from torchao.quantization.quant_primitives import ZeroPointDomain
|
||||
|
||||
quant_type_kwargs["layout"] = Int4XPULayout()
|
||||
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
|
||||
)
|
||||
else:
|
||||
from torchao.dtypes import Int4CPULayout
|
||||
|
||||
quant_type_kwargs["layout"] = Int4CPULayout()
|
||||
|
||||
return methods[self.quant_type](**quant_type_kwargs)
|
||||
return self.quant_type
|
||||
|
||||
def __repr__(self):
|
||||
r"""
|
||||
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
|
||||
|
||||
```
|
||||
TorchAoConfig {
|
||||
"modules_to_not_convert": null,
|
||||
"quant_method": "torchao",
|
||||
"quant_type": "uint4wo",
|
||||
"quant_type_kwargs": {
|
||||
"group_size": 32
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
return (
|
||||
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
|
||||
)
|
||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -20,7 +20,6 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
|
||||
import importlib
|
||||
import re
|
||||
import types
|
||||
from fnmatch import fnmatch
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from packaging import version
|
||||
@@ -114,7 +113,7 @@ if (
|
||||
is_torch_available()
|
||||
and is_torch_version(">=", "2.6.0")
|
||||
and is_torchao_available()
|
||||
and is_torchao_version(">=", "0.7.0")
|
||||
and is_torchao_version(">=", "0.15.0")
|
||||
):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
@@ -169,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
torchao_version = version.parse(importlib.metadata.version("torchao"))
|
||||
if torchao_version < version.parse("0.15.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
@@ -199,13 +198,13 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
)
|
||||
|
||||
def update_torch_dtype(self, torch_dtype):
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
|
||||
if torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
|
||||
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
config_name = self.quantization_config.quant_type.__class__.__name__
|
||||
is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
|
||||
if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
|
||||
logger.warning(
|
||||
f"You are trying to set torch_dtype to {torch_dtype} for integer quantization, but "
|
||||
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
|
||||
if torch_dtype is None:
|
||||
# We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
|
||||
@@ -219,45 +218,16 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
return torch_dtype
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
quant_type = self.quantization_config.quant_type
|
||||
from accelerate.utils import CustomDtype
|
||||
|
||||
if isinstance(quant_type, str):
|
||||
if quant_type.startswith("int8"):
|
||||
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
|
||||
return torch.int8
|
||||
elif quant_type.startswith("int4"):
|
||||
return CustomDtype.INT4
|
||||
elif quant_type == "uintx_weight_only":
|
||||
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
|
||||
elif quant_type.startswith("uint"):
|
||||
return {
|
||||
1: torch.uint1,
|
||||
2: torch.uint2,
|
||||
3: torch.uint3,
|
||||
4: torch.uint4,
|
||||
5: torch.uint5,
|
||||
6: torch.uint6,
|
||||
7: torch.uint7,
|
||||
}[int(quant_type[4])]
|
||||
elif quant_type.startswith("float") or quant_type.startswith("fp"):
|
||||
return torch.bfloat16
|
||||
quant_type = self.quantization_config.quant_type
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
elif is_torchao_version(">", "0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
# Map the extracted digit to appropriate dtype
|
||||
if size_digit == "4":
|
||||
return CustomDtype.INT4
|
||||
else:
|
||||
# Default to int8
|
||||
return torch.int8
|
||||
if size_digit == "4":
|
||||
return CustomDtype.INT4
|
||||
else:
|
||||
return torch.int8
|
||||
|
||||
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
|
||||
return target_dtype
|
||||
@@ -337,29 +307,14 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
- Use a division factor of 8 for int4 weights
|
||||
- Use a division factor of 4 for int8 weights
|
||||
"""
|
||||
# Original mapping for non-AOBaseConfig types
|
||||
# For the uint types, this is a best guess. Once these types become more used
|
||||
# we can look into their nuances.
|
||||
if is_torchao_version(">", "0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
if size_digit == "4":
|
||||
return 8
|
||||
else:
|
||||
return 4
|
||||
|
||||
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
|
||||
quant_type = self.quantization_config.quant_type
|
||||
for pattern, target_dtype in map_to_target_dtype.items():
|
||||
if fnmatch(quant_type, pattern):
|
||||
return target_dtype
|
||||
raise ValueError(f"Unsupported quant_type: {quant_type!r}")
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
if size_digit == "4":
|
||||
return 8
|
||||
else:
|
||||
return 4
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
@@ -415,9 +370,17 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
|
||||
return _is_torchao_serializable
|
||||
|
||||
_TRAINABLE_QUANTIZATION_CONFIGS = (
|
||||
"Int8WeightOnlyConfig",
|
||||
"Int8DynamicActivationInt8WeightConfig",
|
||||
"Int8StaticActivationInt8WeightConfig",
|
||||
"Float8WeightOnlyConfig",
|
||||
"Float8DynamicActivationFloat8WeightConfig",
|
||||
)
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
return self.quantization_config.quant_type.startswith("int8")
|
||||
return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
|
||||
|
||||
@property
|
||||
def is_compileable(self) -> bool:
|
||||
|
||||
@@ -13,24 +13,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKLWan
|
||||
|
||||
def get_autoencoder_kl_wan_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
@@ -39,54 +44,51 @@ class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.T
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
image = torch.randn(batch_size, num_channels, num_frames, *sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def dummy_input_tiling(self):
|
||||
# Bridge for AutoencoderTesterMixin which still uses the old interface
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return self.get_init_dict(), self.get_dummy_inputs()
|
||||
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (128, 128)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
return self.get_init_dict(), {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input_tiling
|
||||
return init_dict, inputs_dict
|
||||
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderKLWan."""
|
||||
|
||||
@unittest.skip("Gradient checkpointing has not been implemented yet")
|
||||
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_forward_with_norm_groups(self):
|
||||
|
||||
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKLWan."""
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, AutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKLWan."""
|
||||
|
||||
@@ -44,9 +44,9 @@ class AutoencoderTesterMixin:
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
|
||||
@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _context_parallel_backward_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
|
||||
):
|
||||
"""Worker function for context parallel backward pass testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
# Create model in training mode
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward and backward pass
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Check that backward actually produced at least one valid gradient
|
||||
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
|
||||
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["has_valid_grads"] = bool(has_valid_grads)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
@@ -204,6 +262,51 @@ class ContextParallelTesterMixin:
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_backward_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_backward(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
|
||||
@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import (
|
||||
is_nvidia_modelopt_available,
|
||||
is_optimum_quanto_available,
|
||||
is_torchao_available,
|
||||
is_torchao_version,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -63,8 +62,7 @@ if is_gguf_available():
|
||||
pass
|
||||
|
||||
if is_torchao_available():
|
||||
if is_torchao_version(">=", "0.9.0"):
|
||||
pass
|
||||
import torchao.quantization as _torchao_quantization
|
||||
|
||||
|
||||
class LoRALayer(torch.nn.Module):
|
||||
@@ -806,9 +804,9 @@ class TorchAoConfigMixin:
|
||||
"""
|
||||
|
||||
TORCHAO_QUANT_TYPES = {
|
||||
"int4wo": {"quant_type": "int4_weight_only"},
|
||||
"int8wo": {"quant_type": "int8_weight_only"},
|
||||
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
|
||||
"int4wo": "Int4WeightOnlyConfig",
|
||||
"int8wo": "Int8WeightOnlyConfig",
|
||||
"int8dq": "Int8DynamicActivationInt8WeightConfig",
|
||||
}
|
||||
|
||||
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
@@ -817,8 +815,13 @@ class TorchAoConfigMixin:
|
||||
"int8dq": 1.5,
|
||||
}
|
||||
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = TorchAoConfig(**config_kwargs)
|
||||
@staticmethod
|
||||
def _get_quant_config(config_name):
|
||||
config_cls = getattr(_torchao_quantization, config_name)
|
||||
return TorchAoConfig(config_cls())
|
||||
|
||||
def _create_quantized_model(self, config_name, **extra_kwargs):
|
||||
config = self._get_quant_config(config_name)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs["device_map"] = str(torch_device)
|
||||
|
||||
@@ -12,53 +12,71 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanVideo15Transformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig):
|
||||
class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideo15Transformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
model_split_percents = [0.99, 0.99, 0.99]
|
||||
|
||||
text_embed_dim = 16
|
||||
text_embed_2_dim = 8
|
||||
image_embed_dim = 12
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return HunyuanVideo15Transformer3DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 8
|
||||
width = 8
|
||||
sequence_length = 6
|
||||
sequence_length_2 = 4
|
||||
image_sequence_length = 3
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device)
|
||||
encoder_hidden_states_2 = torch.randn(
|
||||
(batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device
|
||||
)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device)
|
||||
encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device)
|
||||
# All zeros for inducing T2V path in the model.
|
||||
image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.99, 0.99, 0.99]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 1, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 1, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"encoder_hidden_states_2": encoder_hidden_states_2,
|
||||
"encoder_attention_mask_2": encoder_attention_mask_2,
|
||||
"image_embeds": image_embeds,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -75,40 +93,9 @@ class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig):
|
||||
"target_size": 16,
|
||||
"task_type": "t2v",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 8
|
||||
width = 8
|
||||
sequence_length = 6
|
||||
sequence_length_2 = 4
|
||||
image_sequence_length = 3
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, self.text_embed_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states_2": randn_tensor(
|
||||
(batch_size, sequence_length_2, self.text_embed_2_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_attention_mask": torch.ones((batch_size, sequence_length), device=torch_device),
|
||||
"encoder_attention_mask_2": torch.ones((batch_size, sequence_length_2), device=torch_device),
|
||||
"image_embeds": torch.zeros(
|
||||
(batch_size, image_sequence_length, self.image_embed_dim), device=torch_device
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestHunyuanVideo15Transformer(HunyuanVideo15TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestHunyuanVideo15TransformerTraining(HunyuanVideo15TransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideo15Transformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@@ -13,100 +13,51 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanDiT2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanDiTTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return HunyuanDiT2DModel
|
||||
class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanDiT2DModel
|
||||
main_input_name = "hidden_states"
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-hunyuan-dit-pipe"
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (8, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"sample_size": 8,
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 2,
|
||||
"cross_attention_dim": 8,
|
||||
"cross_attention_dim_t5": 8,
|
||||
"pooled_projection_dim": 4,
|
||||
"hidden_size": 16,
|
||||
"text_len": 4,
|
||||
"text_len_t5": 4,
|
||||
"activation_fn": "gelu-approximate",
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 4
|
||||
sequence_length_t5 = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device)
|
||||
encoder_hidden_states_t5 = randn_tensor(
|
||||
(batch_size, sequence_length_t5, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device)
|
||||
text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).float().to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device)
|
||||
|
||||
original_size = [1024, 1024]
|
||||
target_size = [16, 16]
|
||||
crops_coords_top_left = [0, 0]
|
||||
add_time_ids = list(original_size + target_size + crops_coords_top_left)
|
||||
add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=torch.float32).to(torch_device)
|
||||
add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device)
|
||||
style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device)
|
||||
image_rotary_emb = [
|
||||
torch.ones(size=(1, 8), dtype=torch.float32),
|
||||
torch.zeros(size=(1, 8), dtype=torch.float32),
|
||||
torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype),
|
||||
torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype),
|
||||
]
|
||||
|
||||
return {
|
||||
@@ -121,26 +72,42 @@ class HunyuanDiTTesterConfig(BaseModelTesterConfig):
|
||||
"image_rotary_emb": image_rotary_emb,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (8, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 8,
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 2,
|
||||
"cross_attention_dim": 8,
|
||||
"cross_attention_dim_t5": 8,
|
||||
"pooled_projection_dim": 4,
|
||||
"hidden_size": 16,
|
||||
"text_len": 4,
|
||||
"text_len_t5": 4,
|
||||
"activation_fn": "gelu-approximate",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
|
||||
def test_output(self):
|
||||
batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
|
||||
super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
|
||||
super().test_output(
|
||||
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
|
||||
)
|
||||
|
||||
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
|
||||
def test_set_xformers_attn_processor_for_determinism(self):
|
||||
pass
|
||||
|
||||
class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanDiT2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestHunyuanDiTCompile(HunyuanDiTTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestHunyuanDiTBitsAndBytes(HunyuanDiTTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for HunyuanDiT."""
|
||||
|
||||
|
||||
class TestHunyuanDiTTorchAo(HunyuanDiTTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for HunyuanDiT."""
|
||||
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
pass
|
||||
|
||||
@@ -12,59 +12,64 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# ======================== HunyuanVideo Text-to-Video ========================
|
||||
|
||||
|
||||
class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return HunyuanVideoTransformer3DModel
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-random-hunyuanvideo"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -80,9 +85,30 @@ class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 8
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
@@ -90,74 +116,105 @@ class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(
|
||||
torch_device, dtype=torch.float32
|
||||
),
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (8, 1, 16, 16)
|
||||
|
||||
class TestHunyuanVideoTransformer(HunyuanVideoTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 8,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 10,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"num_refiner_layers": 1,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 16,
|
||||
"pooled_projection_dim": 8,
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": None,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(1, *self.output_shape))
|
||||
|
||||
class TestHunyuanVideoTransformerTraining(HunyuanVideoTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestHunyuanVideoTransformerCompile(HunyuanVideoTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class TestHunyuanVideoTransformerBitsAndBytes(HunyuanVideoTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for HunyuanVideo Transformer."""
|
||||
|
||||
|
||||
class TestHunyuanVideoTransformerTorchAo(HunyuanVideoTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for HunyuanVideo Transformer."""
|
||||
|
||||
|
||||
# ======================== HunyuanVideo Image-to-Video (Latent Concat) ========================
|
||||
|
||||
|
||||
class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return HunyuanVideoTransformer3DModel
|
||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 2 * 4 + 1
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
def input_shape(self):
|
||||
return (8, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 2 * 4 + 1,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -173,9 +230,33 @@ class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": "latent_concat",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 2 * 4 + 1
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(1, *self.output_shape))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 2
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
@@ -183,58 +264,32 @@ class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(1, *self.output_shape))
|
||||
|
||||
|
||||
class TestHunyuanVideoI2VTransformerCompile(HunyuanVideoI2VTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
# ======================== HunyuanVideo Token Replace Image-to-Video ========================
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return HunyuanVideoTransformer3DModel
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
def input_shape(self):
|
||||
return (8, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 2,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -250,42 +305,19 @@ class HunyuanVideoTokenReplaceTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"rope_axes_dim": (2, 4, 4),
|
||||
"image_condition_type": "token_replace",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 2
|
||||
num_frames = 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(
|
||||
torch_device, dtype=torch.float32
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(1, *self.output_shape))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestHunyuanVideoTokenReplaceTransformerCompile(
|
||||
HunyuanVideoTokenReplaceTransformerTesterConfig, TorchCompileTesterMixin
|
||||
):
|
||||
pass
|
||||
|
||||
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -12,49 +12,84 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanVideoFramepackTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoFramepackTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return HunyuanVideoFramepackTransformer3DModel
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoFramepackTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
model_split_percents = [0.5, 0.7, 0.9]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 3
|
||||
height = 4
|
||||
width = 4
|
||||
text_encoder_embedding_dim = 16
|
||||
image_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.5, 0.7, 0.9]
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
|
||||
image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
|
||||
indices_latents = torch.ones((3,)).to(torch_device)
|
||||
latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
|
||||
indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
|
||||
latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
|
||||
indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
|
||||
latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
|
||||
torch_device
|
||||
)
|
||||
indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 3, 4, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 3, 4, 4)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_projections,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"guidance": guidance,
|
||||
"image_embeds": image_embeds,
|
||||
"indices_latents": indices_latents,
|
||||
"latents_clean": latents_clean,
|
||||
"indices_latents_clean": indices_latents_clean,
|
||||
"latents_history_2x": latents_history_2x,
|
||||
"indices_latents_history_2x": indices_latents_history_2x,
|
||||
"latents_history_4x": latents_history_4x,
|
||||
"indices_latents_history_4x": indices_latents_history_4x,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 3, 4, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 3, 4, 4)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_attention_heads": 2,
|
||||
@@ -73,64 +108,9 @@ class HunyuanVideoFramepackTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"image_proj_dim": 16,
|
||||
"has_clean_x_embedder": True,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 3
|
||||
height = 4
|
||||
width = 4
|
||||
text_encoder_embedding_dim = 16
|
||||
image_encoder_embedding_dim = 16
|
||||
pooled_projection_dim = 8
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
|
||||
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"image_embeds": randn_tensor(
|
||||
(batch_size, sequence_length, image_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"indices_latents": torch.ones((num_frames,)).to(torch_device),
|
||||
"latents_clean": randn_tensor(
|
||||
(batch_size, num_channels, num_frames - 1, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"indices_latents_clean": torch.ones((num_frames - 1,)).to(torch_device),
|
||||
"latents_history_2x": randn_tensor(
|
||||
(batch_size, num_channels, num_frames - 1, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"indices_latents_history_2x": torch.ones((num_frames - 1,)).to(torch_device),
|
||||
"latents_history_4x": randn_tensor(
|
||||
(batch_size, num_channels, (num_frames - 1) * 4, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"indices_latents_history_4x": torch.ones(((num_frames - 1) * 4,)).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestHunyuanVideoFramepackTransformer(HunyuanVideoFramepackTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestHunyuanVideoFramepackTransformerTraining(HunyuanVideoFramepackTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
242
tests/modular_pipelines/test_conditional_pipeline_blocks.py
Normal file
242
tests/modular_pipelines/test_conditional_pipeline_blocks.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
AutoPipelineBlocks,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
ModularPipelineBlocks,
|
||||
)
|
||||
|
||||
|
||||
class TextToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "text2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "text-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "text2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ImageToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "img2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "image-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "img2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class InpaintBlock(ModularPipelineBlocks):
|
||||
model_name = "inpaint"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "inpaint workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "inpaint"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ConditionalImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Conditional image blocks for testing"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name
|
||||
|
||||
|
||||
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = None # no default; block can be skipped
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Optional conditional blocks (skippable)"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None
|
||||
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto image blocks for testing"
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksSelectBlock:
|
||||
def test_select_block_with_mask(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="something") == "inpaint"
|
||||
|
||||
def test_select_block_with_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(image="something") == "img2img"
|
||||
|
||||
def test_select_block_with_mask_and_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
def test_select_block_no_triggers_returns_none(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_select_block_explicit_none_values(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask=None, image=None) is None
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksWorkflowSelection:
|
||||
def test_default_workflow_when_no_triggers(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is not None
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_mask_trigger_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_image_trigger_selects_img2img(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
def test_mask_and_image_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True, image=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_skippable_block_returns_none(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is None
|
||||
|
||||
def test_skippable_block_still_selects_when_triggered(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksSelectBlock:
|
||||
def test_auto_select_mask(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m") == "inpaint"
|
||||
|
||||
def test_auto_select_image(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(image="i") == "img2img"
|
||||
|
||||
def test_auto_select_default(self):
|
||||
blocks = AutoImageBlocks()
|
||||
# No trigger -> returns None -> falls back to default (text2img)
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_auto_select_priority_order(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksWorkflowSelection:
|
||||
def test_auto_default_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_auto_mask_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_auto_image_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksStructure:
|
||||
def test_block_names_accessible(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
|
||||
|
||||
def test_sub_block_types(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert isinstance(sub["inpaint"], InpaintBlock)
|
||||
assert isinstance(sub["img2img"], ImageToImageBlock)
|
||||
assert isinstance(sub["text2img"], TextToImageBlock)
|
||||
|
||||
def test_description(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert "Conditional" in blocks.description
|
||||
@@ -10,11 +10,6 @@ from huggingface_hub import hf_hub_download
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import (
|
||||
ConditionalPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -25,7 +20,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
@@ -498,117 +492,6 @@ class ModularGuiderTesterMixin:
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
|
||||
@@ -24,14 +24,18 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
OutputParam,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
def _create_tiny_model_dir(model_dir):
|
||||
@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch
|
||||
|
||||
@@ -14,13 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -55,6 +53,20 @@ from ..test_torch_compile_utils import QuantCompileTests
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
return minor >= 9
|
||||
return major >= 9
|
||||
elif torch.xpu.is_available():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -64,75 +76,56 @@ if is_torch_available():
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.quantization import (
|
||||
Float8WeightOnlyConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8DynamicActivationIntxWeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
IntxWeightOnlyConfig,
|
||||
)
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.quantization.quant_primitives import MappingType
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
Makes sure the config format is properly set
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only")
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(version=2))
|
||||
torchao_orig_config = quantization_config.to_dict()
|
||||
|
||||
for key in torchao_orig_config:
|
||||
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
|
||||
self.assertIn("quant_type", torchao_orig_config)
|
||||
self.assertIn("quant_method", torchao_orig_config)
|
||||
|
||||
def test_post_init_check(self):
|
||||
"""
|
||||
Test kwargs validations in TorchAoConfig
|
||||
Test that non-AOBaseConfig types are rejected
|
||||
"""
|
||||
_ = TorchAoConfig("int4_weight_only")
|
||||
with self.assertRaisesRegex(ValueError, "is not supported"):
|
||||
_ = TorchAoConfig("uint8")
|
||||
_ = TorchAoConfig(Int4WeightOnlyConfig())
|
||||
with self.assertRaises(TypeError):
|
||||
_ = TorchAoConfig("int4_weight_only")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
|
||||
_ = TorchAoConfig("int4_weight_only", group_size1=32)
|
||||
with self.assertRaises(TypeError):
|
||||
_ = TorchAoConfig(42)
|
||||
|
||||
def test_repr(self):
|
||||
"""
|
||||
Check that there is no error in the repr
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
|
||||
expected_repr = """TorchAoConfig {
|
||||
"modules_to_not_convert": [
|
||||
"conv"
|
||||
],
|
||||
"quant_method": "torchao",
|
||||
"quant_type": "int4_weight_only",
|
||||
"quant_type_kwargs": {
|
||||
"group_size": 8
|
||||
}
|
||||
}""".replace(" ", "").replace("\n", "")
|
||||
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
|
||||
self.assertEqual(quantization_repr, expected_repr)
|
||||
|
||||
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
|
||||
expected_repr = """TorchAoConfig {
|
||||
"modules_to_not_convert": null,
|
||||
"quant_method": "torchao",
|
||||
"quant_type": "int4dq",
|
||||
"quant_type_kwargs": {
|
||||
"act_mapping_type": "SYMMETRIC",
|
||||
"group_size": 64
|
||||
}
|
||||
}""".replace(" ", "").replace("\n", "")
|
||||
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
|
||||
self.assertEqual(quantization_repr, expected_repr)
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=2), modules_to_not_convert=["conv"])
|
||||
quantization_repr = repr(quantization_config)
|
||||
self.assertIn("TorchAoConfig", quantization_repr)
|
||||
self.assertIn("torchao", quantization_repr)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@@ -234,79 +227,30 @@ class TorchAoTest(unittest.TestCase):
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
(Int4WeightOnlyConfig(version=2), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
|
||||
(Int8DynamicActivationIntxWeightConfig(version=2), np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
|
||||
(Int8WeightOnlyConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
(Int8DynamicActivationInt8WeightConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
(IntxWeightOnlyConfig(dtype=torch.uint4, group_size=16, version=2), np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
(IntxWeightOnlyConfig(dtype=torch.uint7, group_size=16, version=2), np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
if _is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
# =====
|
||||
# The following lead to an internal torch error:
|
||||
# RuntimeError: mat2 shape (32x4 must be divisible by 16
|
||||
# Skip these for now; TODO(aryan): investigate later
|
||||
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
# Cutlass fails to initialize for below
|
||||
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
|
||||
# =====
|
||||
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e5m2), np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quant_kwargs = {}
|
||||
if quantization_name in ["uint4wo", "uint7wo"]:
|
||||
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
|
||||
quant_kwargs.update({"group_size": 16})
|
||||
quantization_config = TorchAoConfig(
|
||||
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
|
||||
)
|
||||
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
|
||||
self._test_quant_type(quantization_config, expected_slice, model_id)
|
||||
|
||||
@unittest.skip("Skipping floatx quantization tests")
|
||||
def test_floatx_quantization(self):
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
|
||||
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
|
||||
self._test_quant_type(
|
||||
quantization_config,
|
||||
np.array(
|
||||
[
|
||||
0.4648,
|
||||
0.5195,
|
||||
0.5547,
|
||||
0.4180,
|
||||
0.4434,
|
||||
0.6445,
|
||||
0.4316,
|
||||
0.4531,
|
||||
0.5625,
|
||||
]
|
||||
),
|
||||
model_id,
|
||||
)
|
||||
else:
|
||||
# Make sure the correct error is thrown
|
||||
with self.assertRaisesRegex(ValueError, "Please downgrade"):
|
||||
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
|
||||
|
||||
def test_int4wo_quant_bfloat16_conversion(self):
|
||||
"""
|
||||
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -361,7 +305,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
else:
|
||||
expected_slice = expected_slice_offload
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -385,7 +329,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as offload_folder:
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-sharded",
|
||||
subfolder="transformer",
|
||||
@@ -406,7 +350,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
|
||||
|
||||
def test_modules_to_not_convert(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), modules_to_not_convert=["transformer_blocks.0"])
|
||||
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -422,7 +366,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
quantized_layer = quantized_model_with_not_convert.proj_out
|
||||
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
|
||||
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -436,7 +380,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
self.assertTrue(size_quantized < size_quantized_with_not_convert)
|
||||
|
||||
def test_training(self):
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-flux-pipe",
|
||||
subfolder="transformer",
|
||||
@@ -470,7 +414,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
def test_torch_compile(self):
|
||||
r"""Test that verifies if torch.compile works with torchao quantization."""
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
quantization_config = TorchAoConfig("int8_weight_only")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config, model_id=model_id)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.to(device=torch_device)
|
||||
@@ -491,11 +435,15 @@ class TorchAoTest(unittest.TestCase):
|
||||
memory footprint of the converted model and the class type of the linear layers of the converted models
|
||||
"""
|
||||
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
|
||||
transformer_int4wo = self.get_dummy_components(TorchAoConfig(Int4WeightOnlyConfig()), model_id=model_id)[
|
||||
"transformer"
|
||||
]
|
||||
transformer_int4wo_gs32 = self.get_dummy_components(
|
||||
TorchAoConfig("int4wo", group_size=32), model_id=model_id
|
||||
TorchAoConfig(Int4WeightOnlyConfig(group_size=32)), model_id=model_id
|
||||
)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
|
||||
"transformer"
|
||||
]
|
||||
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
|
||||
|
||||
# Will not quantized all the layers by default due to the model weights shapes not being divisible by group_size=64
|
||||
@@ -553,20 +501,22 @@ class TorchAoTest(unittest.TestCase):
|
||||
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
|
||||
del transformer_bf16
|
||||
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
|
||||
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
|
||||
"transformer"
|
||||
]
|
||||
transformer_int8wo.to(torch_device)
|
||||
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
|
||||
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
|
||||
|
||||
def test_wrong_config(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(TypeError):
|
||||
self.get_dummy_components(TorchAoConfig("int42"))
|
||||
|
||||
def test_sequential_cpu_offload(self):
|
||||
r"""
|
||||
A test that checks if inference runs as expected when sequential cpu offloading is enabled.
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
@@ -574,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
def test_aobase_config(self):
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
@@ -587,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@@ -595,8 +545,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
|
||||
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
|
||||
def get_dummy_model(self, quant_type, device=None):
|
||||
quantization_config = TorchAoConfig(quant_type)
|
||||
quantized_model = FluxTransformer2DModel.from_pretrained(
|
||||
self.model_name,
|
||||
subfolder="transformer",
|
||||
@@ -632,8 +582,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
|
||||
def _test_original_model_expected_slice(self, quant_type, expected_slice):
|
||||
quantized_model = self.get_dummy_model(quant_type, torch_device)
|
||||
inputs = self.get_dummy_tensor_inputs(torch_device)
|
||||
output = quantized_model(**inputs)[0]
|
||||
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
|
||||
@@ -641,8 +591,8 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
|
||||
|
||||
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
|
||||
def _check_serialization_expected_slice(self, quant_type, expected_slice, device):
|
||||
quantized_model = self.get_dummy_model(quant_type, device)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
@@ -662,43 +612,42 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
|
||||
|
||||
def test_int_a8w8_accelerator(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
quant_type = Int8DynamicActivationInt8WeightConfig()
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
def test_int_a16w8_accelerator(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
quant_type = Int8WeightOnlyConfig()
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
def test_int_a8w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
|
||||
quant_type = Int8DynamicActivationInt8WeightConfig()
|
||||
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
def test_int_a16w8_cpu(self):
|
||||
quant_method, quant_method_kwargs = "int8_weight_only", {}
|
||||
quant_type = Int8WeightOnlyConfig()
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = "cpu"
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
def test_aobase_config(self):
|
||||
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
|
||||
quant_type = Int8WeightOnlyConfig()
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
self._test_original_model_expected_slice(quant_type, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -744,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -817,29 +766,25 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
def test_quantization(self):
|
||||
# fmt: off
|
||||
QUANTIZATION_TYPES_TO_TEST = [
|
||||
("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
|
||||
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
|
||||
(Int8WeightOnlyConfig(), np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
|
||||
(Int8DynamicActivationInt8WeightConfig(), np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
if _is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
])
|
||||
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
|
||||
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
|
||||
self._test_quant_type(quantization_config, expected_slice)
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_synchronize(torch_device)
|
||||
|
||||
def test_serialization_int8wo(self):
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components)
|
||||
pipe.enable_model_cpu_offload()
|
||||
@@ -876,7 +821,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
def test_memory_footprint_int4wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 6.0
|
||||
quantization_config = TorchAoConfig("int4wo")
|
||||
quantization_config = TorchAoConfig(Int4WeightOnlyConfig())
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
@@ -891,7 +836,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
def test_memory_footprint_int8wo(self):
|
||||
# The original checkpoints are in bf16 and about 24 GB
|
||||
expected_memory_in_gb = 12.0
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
cache_dir = None
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
@@ -906,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user