Compare commits

...

17 Commits

Author SHA1 Message Date
Pranav Thombre
7e463ea4cc [docs] Add NeMo Automodel training guide (#13306)
* [docs] Add NeMo Automodel training guide

Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>

* Update docs/source/en/training/nemo_automodel.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/training/nemo_automodel.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* adding contacts into the readme

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Address CR comments

Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>

* Update docs/source/en/training/nemo_automodel.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/training/nemo_automodel.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: linnan wang <wangnan318@gmail.com>
2026-03-30 10:21:58 -07:00
tcaimm
7f2b34bced Add train flux2 series lora config (#13011)
* feat(lora): support FLUX.2 single blocks + update README

* add img2img config & add explanatory comments

* simple modify

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2026-03-30 14:22:04 +03:00
Cheung Ka Wai
e1e7d58a4a Fix Ulysses SP backward with SDPA (#13328)
* add UT for backward

* fix SDPA attention backward
2026-03-30 15:15:27 +05:30
Steven Liu
a93f7f137a [docs] refactor model skill (#13334)
* refactor

* feedback

* feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-29 23:13:52 -07:00
Sayak Paul
10ec3040a2 [ci] move to assert instead of self.Assert* (#13366)
move to assert instead of self.Assert*
2026-03-30 11:09:14 +05:30
Howard Zhang
f2be8bd6b3 change minimum version guard for torchao to 0.15.0 (#13355) 2026-03-28 09:11:51 +05:30
Sayak Paul
7da22b9db5 [ci] include checkout step in claude review workflow (#13352)
up
2026-03-27 17:28:31 +05:30
Howard Zhang
1fe2125802 remove str option for quantization config in torchao (#13291)
* remove str option for quantization config in torchao

* Apply style fixes

* minor fixes

* Added AOBaseConfig docs to torchao.md

* minor fixes for removing str option torchao

* minor change to add back int and uint check

* minor fixes

* minor fixes to tests

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/quantization/torchao.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/quantization/torchao/test_torchao.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* version=2 update to test_torchao.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-27 08:52:37 +05:30
dg845
7298f5be93 Update LTX-2 Docs to Cover LTX-2.3 Models (#13337)
* Update LTX-2 docs to cover multimodal guidance and prompt enhancement

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply reviewer feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-26 17:51:29 -07:00
Sayak Paul
b757035df6 fix claude workflow to include id-token with write. (#13338) 2026-03-26 15:39:10 +05:30
kaixuanliu
41e1003316 avoid hardcode device in flux-control example (#13336)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2026-03-26 12:40:53 +05:30
Sayak Paul
85ffcf1db2 [tests] Tests for conditional pipeline blocks (#13247)
* implement test suite for conditional blocks.

* remove

* another fix.

* Revert "another fix."

This reverts commit ab07b603ab.
2026-03-26 08:48:16 +05:30
Steven Liu
cbf4d9a3c3 [docs] kernels (#13139)
* kernels

* feedback
2026-03-25 09:31:54 -07:00
Sayak Paul
426daabad9 [ci] claude in ci. (#13297)
* claude in ci.

* review feedback.
2026-03-25 21:30:06 +05:30
Kashif Rasul
762ae059fa [LLADA2] documentation fixes (#13333)
documentation fixes
2026-03-25 17:49:31 +05:30
Kashif Rasul
5d207e756e [Discrete Diffusion] Add LLaDA2 pipeline (#13226)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
- BlockRefinementPipeline: block-wise iterative refinement with confidence-based
  token commitment, supporting editing threshold for LLaDA2.1 models
- LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
- DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p,
  temperature) and prompt/prefix helpers
- compute_confidence_aware_loss: CAP-style training loss
- Examples: sampling scripts for LLaDA2 and block refinement, training scripts
  with Qwen causal LM
- Docs and tests included

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

Extract the confidence-based token commit logic from BlockRefinementPipeline
into a dedicated BlockRefinementScheduler, following diffusers conventions.

The scheduler owns:
- Transfer schedule computation (get_num_transfer_tokens)
- Timestep management (set_timesteps)
- Step logic: confidence-based mask-filling and optional token editing

The pipeline now delegates scheduling to self.scheduler.step() and accepts
a scheduler parameter in __init__.

* test: add unit tests for BlockRefinementScheduler

12 tests covering set_timesteps, get_num_transfer_tokens, step logic
(confidence-based commits, threshold behavior, editing, prompt masking,
batched inputs, tuple output).

* docs: add toctree entries and standalone scheduler doc page

- Add BlockRefinement and LLaDA2 to docs sidebar navigation
- Add BlockRefinementScheduler to schedulers sidebar navigation
- Move scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py

- Add --revision argument for loading model revisions from the Hub
- Replace deprecated torch_dtype with dtype for transformers 5.x compat

* fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat

LLaDA2 models expect a boolean-style (1/0) attention mask, not an
additive (0/-inf) mask. The model internally converts to additive,
so passing 0/-inf caused double-masking and gibberish output.

* refactor: consolidate training scripts into single train_block_refinement.py

- Remove toy train_block_refinement_cap.py (self-contained demo with tiny model)
- Rename train_block_refinement_qwen_cap.py to train_block_refinement.py
  (already works with any causal LM via AutoModelForCausalLM)
- Fix torch_dtype deprecation and update README with correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation

- Add usage examples with real model IDs and working code
- Add recommended parameters table for LLaDA2.1 quality/speed modes
- Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
- Remove misleading config defaults section from BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to recommended model parameters

- threshold: 0.95 -> 0.7 (quality mode)
- max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
- eos_early_stop: False -> True (stop at EOS token)

block_length=32, steps=32, temperature=0.0 were already correct.
editing_threshold remains None (users enable for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

LLaDA2.1 is the current generation. Users with LLaDA2.0 models can
disable editing by passing editing_threshold=None.

* fix: align sampling utilities with official LLaDA2 implementation

- top_p filtering: add shift-right to preserve at least one token above
  threshold (matches official code line 1210)
- temperature ordering: apply scaling before top-k/top-p filtering so
  filtering operates on scaled logits (matches official code lines 1232-1235)
- greedy branch: return argmax directly when temperature=0 without
  filtering (matches official code lines 1226-1230)

* refactor: remove duplicate prompt encoding, reuse mixin's _prepare_input_ids

LLaDA2Pipeline._prepare_prompt_ids was a near-copy of
DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate
and call the mixin method directly. Also simplify _extract_input_ids
since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings

- Update EXAMPLE_DOC_STRING to use dtype= and LLaDA2.1-mini model ID
- Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline

* cleanup

* fix readme

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* removed DiscreteDiffusionPipelineMixin

* add support for 2d masks for flash attn

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix issues from review

* added tests

* formatting

* add check_eos_finished to scheduler

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix renaming issues and types

* remove duplicate check

* Update docs/source/en/api/pipelines/llada2.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-25 16:17:50 +05:30
Sayak Paul
e358ddcce6 fix to device and to dtype tests. (#13323) 2026-03-25 11:47:02 +05:30
49 changed files with 4232 additions and 845 deletions


@@ -10,24 +10,34 @@ Strive to write code as simple and explicit as possible.
---
### Dependencies
- No new mandatory dependency without discussion (e.g. `einops`)
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
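As a minimal sketch of the optional-dependency rule above, assuming `scipy` as the optional package; `is_scipy_available` is one of the existing availability helpers in `diffusers.utils`, and the matching placeholder class would live in `utils/dummy_*.py`:

```python
from diffusers.utils import is_scipy_available

# Guard the optional import so `import diffusers` never fails when scipy is absent;
# the dummy object in utils/dummy_*.py raises a helpful error at use time instead.
if is_scipy_available():
    import scipy
```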
## Code formatting
- `make style` and `make fix-copies` should be run as the final step before opening a PR
### Copied Code
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
- Remove the header to intentionally break the link
### Models
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
### Pipelines & Schedulers
- Pipelines inherit from `DiffusionPipeline`
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
## Skills
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).

.ai/models.md (new file, 76 lines)

@@ -0,0 +1,76 @@
# Model conventions and rules
Shared reference for model-related conventions, patterns, and gotchas.
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
## Coding style
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- No new mandatory dependency without discussion (e.g. `einops`). Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`.
## Common model conventions
- Models use `ModelMixin` with `register_to_config` for config serialization
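As a hedged sketch of that convention (class and parameter names are illustrative): arguments of an `__init__` decorated with `register_to_config` are written to `config.json` and restored by `from_pretrained`.

```python
import torch.nn as nn

from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class MyTinyModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_dim: int = 64, num_layers: int = 2):
        super().__init__()
        # hidden_dim and num_layers are captured in the config automatically
        self.blocks = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers))

    def forward(self, hidden_states):
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states
```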
## Attention pattern
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
```python
# transformer_mymodel.py

class MyModelAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, attention_mask=None, ...):
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # reshape, apply rope, etc.
        hidden_states = dispatch_attention_fn(
            query, key, value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3)
        return attn.to_out[0](hidden_states)


class MyModelAttention(nn.Module, AttentionModuleMixin):
    _default_processor_cls = MyModelAttnProcessor
    _available_processors = [MyModelAttnProcessor]

    def __init__(self, query_dim, heads=8, dim_head=64, ...):
        super().__init__()
        self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
        self.set_processor(MyModelAttnProcessor())

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
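Two of these gotchas are easiest to see in code. The sketch below is illustrative only (hypothetical function names): gotcha 2, rewriting an `einops.rearrange` with native PyTorch, and gotcha 8, deriving the compute dtype from the input instead of hardcoding it.

```python
import torch

# Gotcha 2: einops `rearrange(x, "b s (h d) -> b h s d", h=heads)` rewritten natively.
def split_heads(hidden_states: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
    return hidden_states.unflatten(-1, (heads, -1)).permute(0, 2, 1, 3)

# Gotcha 8: upcast for the norm math, then return to the *input* dtype (never a hardcoded one).
def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return normed.to(hidden_states.dtype) * weight
```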

.ai/review-rules.md (new file, 11 lines)

@@ -0,0 +1,11 @@
# PR Review Rules
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
Before reviewing, read and apply the guidelines in:
- [AGENTS.md](AGENTS.md) — coding style, copied code
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
## Common mistakes (add new rules below this line)


@@ -65,89 +65,19 @@ docs/source/en/api/
- [ ] Run `make style` and `make quality`
- [ ] Test parity with reference implementation (see `parity-testing` skill)
### Attention pattern
### Model conventions, attention pattern, and implementation rules
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
See [../../models.md](../../models.md) for the attention pattern, implementation rules, common conventions, dependencies, and gotchas. These apply to all model work.
```python
# transformer_mymodel.py
### Model integration specific rules
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
### Implementation rules
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
**Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
### Test setup
- Slow tests gated with `@slow` and `RUN_SLOW=1`
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
### Common diffusers conventions
- Pipelines inherit from `DiffusionPipeline`
- Models use `ModelMixin` with `register_to_config` for config serialization
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
---
## Modular Pipeline Conversion

.github/workflows/claude_review.yml (new file, 42 lines)

@@ -0,0 +1,42 @@
name: Claude PR Review

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]

permissions:
  contents: write
  pull-requests: write
  issues: read
  id-token: write

jobs:
  claude-review:
    if: |
      (
        github.event_name == 'issue_comment' &&
        github.event.issue.pull_request &&
        github.event.issue.state == 'open' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'MEMBER' ||
         github.event.comment.author_association == 'OWNER' ||
         github.event.comment.author_association == 'COLLABORATOR')
      ) || (
        github.event_name == 'pull_request_review_comment' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'MEMBER' ||
         github.event.comment.author_association == 'OWNER' ||
         github.event.comment.author_association == 'COLLABORATOR')
      )
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 1
      - uses: anthropics/claude-code-action@v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: |
            --append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."


@@ -161,6 +161,8 @@
- local: training/ddpo
title: Reinforcement learning training with DDPO
title: Methods
- local: training/nemo_automodel
title: NeMo Automodel
title: Training
- isExpanded: false
sections:
@@ -670,6 +672,10 @@
- local: api/pipelines/z_image
title: Z-Image
title: Image
- sections:
- local: api/pipelines/llada2
title: LLaDA2
title: Text
- sections:
- local: api/pipelines/allegro
title: Allegro
@@ -718,6 +724,8 @@
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/block_refinement
title: BlockRefinementScheduler
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
- local: api/schedulers/ddim_cogvideox


@@ -41,16 +41,15 @@ The quantized CogVideoX 5B model below requires ~16GB of VRAM.
```py
import torch
from diffusers import CogVideoXPipeline, AutoModel
from diffusers import CogVideoXPipeline, AutoModel, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from torchao.quantization import Int8WeightOnlyConfig
# quantize weights to int8 with torchao
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="torchao",
quant_kwargs={"quant_type": "int8wo"},
components_to_quantize="transformer"
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
)
# fp8 layerwise weight-casting


@@ -0,0 +1,90 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
steps.
## Usage
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import BlockRefinementScheduler, LLaDA2Pipeline

model_id = "inclusionAI/LLaDA2.1-mini"
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)

output = pipe(
    prompt="Write a short poem about the ocean.",
    gen_length=256,
    block_length=32,
    num_inference_steps=32,
    threshold=0.7,
    editing_threshold=0.5,
    max_post_steps=16,
    temperature=0.0,
)
print(output.texts[0])
```
## Callbacks
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
window.
```py
def on_step_end(pipe, step, timestep, callback_kwargs):
    block_x = callback_kwargs["block_x"]
    # Inspect or modify `block_x` here.
    return {"block_x": block_x}

out = pipe(
    prompt="Write a short poem.",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["block_x"],
)
```
## Recommended parameters
LLaDA2.1 models support two modes:
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|------|-------------|---------------------|------------------|
| Quality | 0.7 | 0.5 | 16 |
| Speed | 0.5 | `None` | 16 |
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
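For example, switching the usage example above from the quality preset to the speed preset only changes the two thresholds (same `pipe` object as above):

```py
output = pipe(
    prompt="Write a short poem about the ocean.",
    gen_length=256,
    block_length=32,
    num_inference_steps=32,
    threshold=0.5,           # speed preset
    editing_threshold=None,  # post-mask editing off
    max_post_steps=16,
    temperature=0.0,
)
print(output.texts[0])
```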
## LLaDA2Pipeline
[[autodoc]] LLaDA2Pipeline
- all
- __call__
## LLaDA2PipelineOutput
[[autodoc]] pipelines.LLaDA2PipelineOutput


@@ -18,7 +18,7 @@
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
[LTX-2](https://hf.co/papers/2601.03233) is a DiT-based foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
@@ -293,6 +293,7 @@ import torch
from diffusers import LTX2ConditionPipeline
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
from diffusers.utils import load_image, load_video
device = "cuda"
@@ -315,19 +316,6 @@ prompt = (
"landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the "
"solitude and beauty of a winter drive through a mountainous region."
)
negative_prompt = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
cond_video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
@@ -343,7 +331,7 @@ frame_rate = 24.0
video, audio = pipe(
conditions=conditions,
prompt=prompt,
negative_prompt=negative_prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
@@ -366,6 +354,154 @@ encode_video(
Because the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.
## Multimodal Guidance
LTX-2.X pipelines support multimodal guidance. It is composed of three terms, all using a CFG-style update rule:
1. Classifier-Free Guidance (CFG): standard [CFG](https://huggingface.co/papers/2207.12598) where the perturbed ("weaker") output is generated using the negative prompt.
2. Spatio-Temporal Guidance (STG): [STG](https://huggingface.co/papers/2411.18664) moves away from a perturbed output created from short-cutting self-attention operations and substitutes in the attention values instead. The idea is that this creates sharper videos and better spatiotemporal consistency.
3. Modality Isolation Guidance: moves away from a perturbed output created from disabling cross-modality (audio-to-video and video-to-audio) cross attention. This guidance is more specific to [LTX-2.X](https://huggingface.co/papers/2601.03233) models, with the idea that this produces better consistency between the generated audio and video.
These are controlled by the `guidance_scale`, `stg_scale`, and `modality_scale` arguments and can be set separately for video and audio. Additionally, for STG the transformer block indices where self-attention is skipped need to be specified via the `spatio_temporal_guidance_blocks` argument. The LTX-2.X pipelines also support [guidance rescaling](https://huggingface.co/papers/2305.08891) to help reduce over-exposure, which can be a problem when the guidance scales are set to high values.
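Each of these terms applies the same CFG-style update. The sketch below is illustrative only (hypothetical tensor names, not the LTX-2 code path); the rescale helper follows the linked guidance-rescaling paper.

```py
import torch

# Illustrative only: a generic CFG-style update and the guidance-rescale rule.
def cfg_style_update(pred_weak, pred_cond, scale):
    # Move from the "weaker" (perturbed or unconditional) prediction toward the
    # conditional one.
    return pred_weak + scale * (pred_cond - pred_weak)

def rescale_guidance(pred_guided, pred_cond, guidance_rescale):
    # https://huggingface.co/papers/2305.08891: match the guided prediction's std to
    # the conditional prediction's std to reduce over-exposure at high scales.
    std_cond = pred_cond.std(dim=list(range(1, pred_cond.ndim)), keepdim=True)
    std_guided = pred_guided.std(dim=list(range(1, pred_guided.ndim)), keepdim=True)
    rescaled = pred_guided * (std_cond / std_guided)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * pred_guided
```

The full image-to-video example below shows how the corresponding `*_scale`, `*_rescale`, and `spatio_temporal_guidance_blocks` arguments are passed to the pipeline.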
```py
import torch
from diffusers import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT
from diffusers.utils import load_image
device = "cuda"
width = 768
height = 512
random_seed = 42
frame_rate = 24.0
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "dg845/LTX-2.3-Diffusers"
pipe = LTX2ImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload(device=device)
pipe.vae.enable_tiling()
prompt = (
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
"breath-taking, movie-like shot."
)
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg",
)
video, audio = pipe(
image=image,
prompt=prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=30,
guidance_scale=3.0, # Recommended LTX-2.3 guidance parameters
stg_scale=1.0, # Note that 0.0 (not 1.0) means that STG is disabled (all other guidance is disabled at 1.0)
modality_scale=3.0,
guidance_rescale=0.7,
audio_guidance_scale=7.0, # Note that a higher CFG guidance scale is recommended for audio
audio_stg_scale=1.0,
audio_modality_scale=3.0,
audio_guidance_rescale=0.7,
spatio_temporal_guidance_blocks=[28],
use_cross_timestep=True,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_3_i2v_stage_1.mp4",
)
```
## Prompt Enhancement
The LTX-2.X models are sensitive to prompting style. Refer to the [official prompting guide](https://ltx.io/model/model-blog/prompting-guide-for-ltx-2) for recommendations on how to write a good prompt. Using prompt enhancement, where the supplied prompts are enhanced using the pipeline's text encoder (by default a [Gemma 3](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized) model) given a system prompt, can also improve sample quality. The optional `processor` pipeline component needs to be present to use prompt enhancement. Enable prompt enhancement by supplying a `system_prompt` argument:
```py
import torch
from transformers import Gemma3Processor
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DEFAULT_NEGATIVE_PROMPT, T2V_DEFAULT_SYSTEM_PROMPT
device = "cuda"
width = 768
height = 512
random_seed = 42
frame_rate = 24.0
generator = torch.Generator(device).manual_seed(random_seed)
model_path = "dg845/LTX-2.3-Diffusers"
pipe = LTX2Pipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()
if getattr(pipe, "processor", None) is None:
    processor = Gemma3Processor.from_pretrained("google/gemma-3-12b-it-qat-q4_0-unquantized")
    pipe.processor = processor
prompt = (
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
"gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
"before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
"fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
"shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
"smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
"distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
"breath-taking, movie-like shot."
)
video, audio = pipe(
prompt=prompt,
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
width=width,
height=height,
num_frames=121,
frame_rate=frame_rate,
num_inference_steps=30,
guidance_scale=3.0,
stg_scale=1.0,
modality_scale=3.0,
guidance_rescale=0.7,
audio_guidance_scale=7.0,
audio_stg_scale=1.0,
audio_modality_scale=3.0,
audio_guidance_rescale=0.7,
spatio_temporal_guidance_blocks=[28],
use_cross_timestep=True,
system_prompt=T2V_DEFAULT_SYSTEM_PROMPT,
generator=generator,
output_type="np",
return_dict=False,
)
encode_video(
video[0],
fps=frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
output_path="ltx2_3_t2v_stage_1.mp4",
)
```
## LTX2Pipeline
[[autodoc]] LTX2Pipeline


@@ -63,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [LLaDA2](llada2) | text2text |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |


@@ -0,0 +1,25 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BlockRefinementScheduler
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
token with high confidence.
This scheduler is used by [`LLaDA2Pipeline`].
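As a rough illustration of commit-by-confidence (not the scheduler's actual `step` API, and omitting the optional editing of already-committed tokens), one refinement step conceptually looks like this:

```py
import torch

def commit_most_confident(logits, tokens, still_masked, num_to_commit):
    # Illustration only: commit the highest-confidence predictions among the
    # positions that are still masked.
    probs = logits.softmax(dim=-1)
    confidence, candidates = probs.max(dim=-1)                 # (batch, seq)
    confidence = confidence.masked_fill(~still_masked, -1.0)   # ignore committed positions
    commit_idx = confidence.topk(num_to_commit, dim=-1).indices
    tokens = tokens.scatter(-1, commit_idx, candidates.gather(-1, commit_idx))
    still_masked = still_masked.clone()
    still_masked.scatter_(-1, commit_idx, False)               # now committed
    return tokens, still_masked
```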
## BlockRefinementScheduler
[[autodoc]] BlockRefinementScheduler
## BlockRefinementSchedulerOutput
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput


@@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be
The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
## Kernels
[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.
> [!TIP]
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.
For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.
<iframe
src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
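As a small, hedged example (following the kernels README; the repository id is the community activation kernel), loading and calling a Hub-hosted kernel looks roughly like this:

```py
import torch
from kernels import get_kernel

# Downloads a pre-built kernel from the Hub and loads the binary matching the
# current device and PyTorch build.
activation = get_kernel("kernels-community/activation")

x = torch.randn((16, 16), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
activation.gelu_fast(y, x)  # assumption: call signature as shown in the kernels README
```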
## Dynamic quantization
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.


@@ -29,24 +29,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
from torchao.quantization import Int8WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```
For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
```py
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig("int8wo")}
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128, version=2))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -91,18 +74,15 @@ Weight-only quantization stores the model weights in a specific low-bit data typ
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
The quantization methods supported are as follows:
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods. An exhaustive list of configuration options is available [here](https://docs.pytorch.org/ao/main/workflows/inference.html#inference-workflows).
| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some example popular quantization configurations are as follows:
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
| **Category** | **Configuration Classes** |
|---|---|
| **Integer quantization** | [`Int4WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int4WeightOnlyConfig.html), [`Int8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8WeightOnlyConfig.html), [`Int8DynamicActivationInt8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Int8DynamicActivationInt8WeightConfig.html) |
| **Floating point 8-bit quantization** | [`Float8WeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8WeightOnlyConfig.html), [`Float8DynamicActivationFloat8WeightConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.Float8DynamicActivationFloat8WeightConfig.html) |
| **Unsigned integer quantization** | [`IntxWeightOnlyConfig`](https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.IntxWeightOnlyConfig.html) |
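For example, dynamic fp8 activation plus fp8 weight quantization follows the same pattern as the weight-only examples in this guide, just with a different configuration class:

```python
import torch
from diffusers import AutoModel, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

quantization_config = TorchAoConfig(Float8DynamicActivationFloat8WeightConfig())
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```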
## Serializing and Deserializing quantized models
@@ -111,8 +91,9 @@ To serialize a quantized model in a given dtype, first load the model with the d
```python
import torch
from diffusers import AutoModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
@@ -137,18 +118,19 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4` weight-only, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, AutoModel, TorchAoConfig
from torchao.quantization import IntxWeightOnlyConfig
# Serialize the model
transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
quantization_config=TorchAoConfig(IntxWeightOnlyConfig(dtype=torch.uint4)),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")


@@ -0,0 +1,378 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# NeMo Automodel
[NeMo Automodel](https://github.com/NVIDIA-NeMo/Automodel) is a PyTorch DTensor-native training library from NVIDIA for fine-tuning and pretraining diffusion models at scale. It is Hugging Face native — train any Diffusers-format model from the Hub with no checkpoint conversion. The same YAML recipe and hackable training script runs on any scale from 1 GPU to hundreds of nodes, with [FSDP2](https://pytorch.org/docs/stable/fsdp.html) distributed training, multiresolution bucketed dataloading, and pre-encoded latent space training for maximum GPU utilization. It uses [flow matching](https://huggingface.co/papers/2210.02747) for training and is fully open source (Apache 2.0), NVIDIA-supported, and actively maintained.
NeMo Automodel integrates directly with Diffusers. It loads pretrained models from the Hugging Face Hub using Diffusers model classes and generates outputs with the [`DiffusionPipeline`].
The typical workflow is to install NeMo Automodel (pip or Docker), prepare your data by encoding it into `.meta` files, configure a YAML recipe, launch training with `torchrun`, and run inference with the resulting checkpoint.
## Supported models
| Model | Hugging Face ID | Task | Parameters | Use case |
|-------|----------------|------|------------|----------|
| Wan 2.1 T2V 1.3B | [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) | Text-to-Video | 1.3B | video generation on limited hardware (fits on single 40GB A100) |
| FLUX.1-dev | [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) | Text-to-Image | 12B | high-quality image generation |
| HunyuanVideo 1.5 | [hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v) | Text-to-Video | 13B | high-quality video generation |
## Installation
### Hardware requirements
| Component | Minimum | Recommended |
|-----------|---------|-------------|
| GPU | A100 40GB | A100 80GB / H100 |
| GPUs | 4 | 8+ |
| RAM | 128 GB | 256 GB+ |
| Storage | 500 GB SSD | 2 TB NVMe |
Install NeMo Automodel with pip. For the full set of installation methods (including from source), see the [NeMo Automodel installation guide](https://docs.nvidia.com/nemo/automodel/latest/guides/installation.html).
```bash
pip3 install nemo-automodel
```
Alternatively, use the pre-built Docker container which includes all dependencies.
```bash
docker pull nvcr.io/nvidia/nemo-automodel:26.02.00
docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/nemo-automodel:26.02.00
```
> [!WARNING]
> Checkpoints are lost when the container exits unless you bind-mount the checkpoint directory to the host. For example, add `-v /host/path/checkpoints:/workspace/checkpoints` to the `docker run` command.
## Data preparation
NeMo Automodel trains diffusion models in latent space. Raw images or videos must be preprocessed into `.meta` files containing VAE latents and text embeddings before training. This avoids re-encoding on every training step.
Use the built-in preprocessing tool to encode your data. The tool automatically distributes work across all available GPUs.
<hfoptions id="data-prep">
<hfoption id="video preprocessing">
Video preprocessing uses the same entry point for both Wan 2.1 and HunyuanVideo, but the flags differ. Wan 2.1 uses `--processor wan` with `--resolution_preset` and `--caption_format sidecar`, while HunyuanVideo uses `--processor hunyuan` with `--target_frames` to set the frame count and `--caption_format meta_json`.
**Wan 2.1:**
```bash
python -m tools.diffusion.preprocessing_multiprocess video \
--video_dir /data/videos \
--output_dir /cache \
--processor wan \
--resolution_preset 512p \
--caption_format sidecar
```
**HunyuanVideo:**
```bash
python -m tools.diffusion.preprocessing_multiprocess video \
--video_dir /data/videos \
--output_dir /cache \
--processor hunyuan \
--target_frames 121 \
--caption_format meta_json
```
</hfoption>
<hfoption id="image preprocessing">
```bash
python -m tools.diffusion.preprocessing_multiprocess image \
--image_dir /data/images \
--output_dir /cache \
--processor flux \
--resolution_preset 512p
```
</hfoption>
</hfoptions>
### Output format
Preprocessing produces a cache directory organized by resolution bucket. NeMo Automodel supports multi-resolution training through bucketed sampling. Samples are grouped by spatial resolution so each batch contains same-size samples, avoiding padding waste.
```
/cache/
├── 512x512/ # Resolution bucket
│ ├── <hash1>.meta # VAE latents + text embeddings
│ ├── <hash2>.meta
│ └── ...
├── 832x480/ # Another resolution bucket
│ └── ...
├── metadata.json # Global config (processor, model, total items)
└── metadata_shard_0000.json # Per-sample metadata (paths, resolutions, captions)
```
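A rough, self-contained sketch of what resolution-bucketed batching amounts to (not the NeMo Automodel dataloader itself, and the function name is made up): group files by the bucket directory they were written to, then batch within each group so every batch shares a spatial size.
```python
# Rough sketch of resolution-bucketed batching: every batch is drawn from a single
# bucket directory, so all latents in it share a spatial size and no padding is needed.
import random
from collections import defaultdict
from pathlib import Path

def iter_bucketed_batches(cache_dir: str, batch_size: int, shuffle: bool = True):
    buckets = defaultdict(list)
    for meta_file in Path(cache_dir).glob("*/*.meta"):
        buckets[meta_file.parent.name].append(meta_file)  # key: "512x512", "832x480", ...
    for resolution, files in buckets.items():
        if shuffle:
            random.shuffle(files)
        for i in range(0, len(files), batch_size):
            yield resolution, files[i : i + batch_size]

for resolution, batch in iter_bucketed_batches("/cache", batch_size=2):
    print(resolution, [f.name for f in batch])
```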
> [!TIP]
> See the [Diffusion Dataset Preparation](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/dataset.html) guide for caption formats, input data requirements, and all available preprocessing arguments.
## Training configuration
Fine-tuning is driven by two components:
1. A recipe script ([finetune.py](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/diffusion/finetune/finetune.py)) is a Python entry point that contains the training loop: loading the model, building the dataloader, running forward/backward passes, computing the flow matching loss, checkpointing, and logging.
2. A YAML configuration file specifies all settings the recipe uses: which model to fine-tune, where the data lives, optimizer hyperparameters, parallelism strategy, and more. You customize training by editing this file rather than modifying code, allowing you to scale from 1 to hundreds of GPUs.
Any YAML field can also be overridden from the CLI:
```bash
torchrun --nproc-per-node=8 examples/diffusion/finetune/finetune.py \
-c examples/diffusion/finetune/wan2_1_t2v_flow.yaml \
--optim.learning_rate 1e-5 \
--step_scheduler.num_epochs 50
```
Below is the annotated config for fine-tuning Wan 2.1 T2V 1.3B, with each section explained.
```yaml
seed: 42
# ── Experiment tracking (optional) ──────────────────────────────────────────
# Weights & Biases integration for logging metrics, losses, and learning rates.
# Set mode: "disabled" to turn off.
wandb:
project: wan-t2v-flow-matching
mode: online
name: wan2_1_t2v_fm
# ── Model ───────────────────────────────────────────────────────────────────
# pretrained_model_name_or_path: any Hugging Face model ID or local path.
# mode: "finetune" loads pretrained weights; "pretrain" trains from scratch.
model:
pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
mode: finetune
# ── Training schedule ───────────────────────────────────────────────────────
# global_batch_size: effective batch across all GPUs.
# Gradient accumulation is computed automatically: global / (local × num_gpus).
step_scheduler:
global_batch_size: 8
local_batch_size: 1
ckpt_every_steps: 1000 # Save a checkpoint every N steps
num_epochs: 100
log_every: 2 # Log metrics every N steps
# ── Data ────────────────────────────────────────────────────────────────────
# _target_: the dataloader factory function.
# Use build_video_multiresolution_dataloader for video models (Wan, HunyuanVideo).
# Use build_text_to_image_multiresolution_dataloader for image models (FLUX).
# model_type: "wan" or "hunyuan" (selects the correct latent format).
# base_resolution: target resolution for multiresolution bucketing.
data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
model_type: wan
base_resolution: [512, 512]
dynamic_batch_size: false # When true, adjusts batch per bucket to maintain constant memory
shuffle: true
drop_last: false
num_workers: 0
# ── Optimizer ───────────────────────────────────────────────────────────────
# learning_rate: 5e-6 is a good starting point for fine-tuning.
# Adjust weight_decay and betas for your dataset.
optim:
learning_rate: 5e-6
optimizer:
weight_decay: 0.01
betas: [0.9, 0.999]
# ── Learning rate scheduler ─────────────────────────────────────────────────
# Supports cosine, linear, and constant schedules.
lr_scheduler:
lr_decay_style: cosine
lr_warmup_steps: 0
min_lr: 1e-6
# ── Flow matching ───────────────────────────────────────────────────────────
# adapter_type: model-specific adapter — must match the model:
# "simple" for Wan 2.1, "flux" for FLUX.1-dev, "hunyuan" for HunyuanVideo.
# timestep_sampling: "uniform" for Wan, "logit_normal" for FLUX and HunyuanVideo.
# flow_shift: shifts the flow schedule (model-dependent).
# i2v_prob: probability of image-to-video conditioning during training (video models).
flow_matching:
adapter_type: "simple"
adapter_kwargs: {}
timestep_sampling: "uniform"
logit_mean: 0.0
logit_std: 1.0
flow_shift: 3.0
num_train_timesteps: 1000
i2v_prob: 0.3
use_loss_weighting: true
# ── FSDP2 distributed training ──────────────────────────────────────────────
# dp_size: number of GPUs for data parallelism (typically = total GPUs on node).
# tp_size, cp_size, pp_size: tensor, context, and pipeline parallelism.
# For most fine-tuning, dp_size is all you need; leave others at 1.
fsdp:
tp_size: 1
cp_size: 1
pp_size: 1
dp_replicate_size: 1
dp_size: 8
# ── Checkpointing ──────────────────────────────────────────────────────────
# checkpoint_dir: where to save checkpoints (use a persistent path with Docker).
# restore_from: path to resume training from a previous checkpoint.
checkpoint:
enabled: true
checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
model_save_format: torch_save
save_consolidated: false
restore_from: null
```
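The `flow_matching` section controls how training targets are built. As a rough illustration of what a single flow matching training step computes, assuming the common rectified-flow formulation with uniform timestep sampling rather than NeMo Automodel's exact implementation:
```python
# Simplified flow matching step (illustrative; no flow_shift or loss weighting):
# interpolate between data and noise and regress the velocity pointing from data to noise.
import torch
import torch.nn.functional as F

def flow_matching_loss(model, latents):
    noise = torch.randn_like(latents)
    t = torch.rand(latents.shape[0], device=latents.device)   # uniform timesteps in [0, 1]
    t_b = t.view(-1, *([1] * (latents.ndim - 1)))              # broadcast over latent dims
    x_t = (1.0 - t_b) * latents + t_b * noise                  # noisy sample
    target = noise - latents                                    # velocity target
    pred = model(x_t, t)                                        # model predicts velocity
    return F.mse_loss(pred, target)
```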
### Config field reference
The table below lists the minimal required config sections. See the [NeMo Automodel examples](https://github.com/NVIDIA-NeMo/Automodel/tree/main/examples/diffusion/finetune) for full example configs for all models.
| Section | Required? | What to Change |
|---------|-----------|----------------|
| `model` | Yes | Set `pretrained_model_name_or_path` to the Hugging Face model ID. Set `mode: finetune` or `mode: pretrain`. |
| `step_scheduler` | Yes | `global_batch_size` is the effective batch size across all GPUs. `ckpt_every_steps` controls checkpoint frequency. Gradient accumulation is computed automatically. |
| `data` | Yes | Set `cache_dir` to the path containing your preprocessed `.meta` files. Change `_target_` and `model_type` for different models. |
| `optim` | Yes | `learning_rate: 5e-6` is a good default for fine-tuning. Adjust for your dataset and model. |
| `lr_scheduler` | Yes | Choose `cosine`, `linear`, or `constant` for `lr_decay_style`. Set `lr_warmup_steps` for gradual warmup. |
| `flow_matching` | Yes | `adapter_type` must match the model (`simple` for Wan, `flux` for FLUX, `hunyuan` for HunyuanVideo). See model-specific configs for `adapter_kwargs`. |
| `fsdp` | Yes | Set `dp_size` to the number of GPUs. For multi-node, set to total GPUs across all nodes. |
| `checkpoint` | Recommended | Set `checkpoint_dir` to a persistent path, especially in Docker. Use `restore_from` to resume from a previous checkpoint. |
| `wandb` | Optional | Configure to enable Weights & Biases experiment tracking. Set `mode: disabled` to turn off. |
## Launch training
<hfoptions id="launch-training">
<hfoption id="single-node">
```bash
torchrun --nproc-per-node=8 \
examples/diffusion/finetune/finetune.py \
-c examples/diffusion/finetune/wan2_1_t2v_flow.yaml
```
</hfoption>
<hfoption id="multi-node">
Run the following on each node, setting `NODE_RANK` accordingly:
```bash
export MASTER_ADDR=node0.hostname
export MASTER_PORT=29500
export NODE_RANK=0 # 0 on master, 1 on second node, etc.
torchrun \
--nnodes=2 \
--nproc-per-node=8 \
--node_rank=${NODE_RANK} \
--rdzv_backend=c10d \
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
examples/diffusion/finetune/finetune.py \
-c examples/diffusion/finetune/wan2_1_t2v_flow_multinode.yaml
```
> [!NOTE]
> For multi-node training, set `fsdp.dp_size` in the YAML to the **total** number of GPUs across all nodes (e.g., 16 for 2 nodes with 8 GPUs each).
</hfoption>
</hfoptions>
## Generation
After training, generate videos or images from text prompts using the fine-tuned checkpoint.
<hfoptions id="generation">
<hfoption id="Wan 2.1">
```bash
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_wan.yaml
```
With a fine-tuned checkpoint:
```bash
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_wan.yaml \
--model.checkpoint ./checkpoints/step_1000 \
--inference.prompts '["A dog running on a beach"]'
```
</hfoption>
<hfoption id="FLUX">
```bash
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_flux.yaml
```
With a fine-tuned checkpoint:
```bash
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_flux.yaml \
--model.checkpoint ./checkpoints/step_1000 \
--inference.prompts '["A dog running on a beach"]'
```
</hfoption>
<hfoption id="HunyuanVideo">
```bash
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_hunyuan.yaml
```
With a fine-tuned checkpoint:
```bash
python examples/diffusion/generate/generate.py \
-c examples/diffusion/generate/configs/generate_hunyuan.yaml \
--model.checkpoint ./checkpoints/step_1000 \
--inference.prompts '["A dog running on a beach"]'
```
</hfoption>
</hfoptions>
## Diffusers integration
NeMo Automodel is built on top of Diffusers and uses it as the backbone for model loading and inference. It loads models directly from the Hugging Face Hub using Diffusers model classes such as [`WanTransformer3DModel`], [`FluxTransformer2DModel`], and [`HunyuanVideoTransformer3DModel`], and generates outputs via Diffusers pipelines like [`WanPipeline`] and [`FluxPipeline`].
This integration provides several benefits for Diffusers users:
- **No checkpoint conversion**: pretrained weights from the Hub work out of the box. Point `pretrained_model_name_or_path` at any Diffusers-format model ID and start training immediately.
- **Day-0 model support**: when a new diffusion model is added to Diffusers and uploaded to the Hub, it can be fine-tuned with NeMo Automodel without waiting for a dedicated training script.
- **Pipeline-compatible outputs**: fine-tuned checkpoints are saved in a format that can be loaded directly back into Diffusers pipelines for inference, sharing on the Hub, or further optimization with tools like quantization and compilation (see the sketch after this list).
- **Scalable training for Diffusers models**: NeMo Automodel adds distributed training capabilities (FSDP2, multi-node, multiresolution bucketing) that go beyond what the built-in Diffusers training scripts provide, while keeping the same model and pipeline interfaces.
- **Shared ecosystem**: any model, LoRA adapter, or pipeline component from the Diffusers ecosystem remains compatible throughout the training and inference workflow.
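As a hedged sketch of the pipeline-compatible outputs point above, a fine-tuned transformer can be dropped into a stock pipeline. The checkpoint path is a placeholder, and whether the saved weights load directly with `from_pretrained` depends on your `checkpoint` settings (`model_save_format`, `save_consolidated`).
```python
# Illustrative: swap a fine-tuned transformer into a stock Diffusers pipeline.
# "./checkpoints/step_1000/transformer" is a placeholder path.
import torch
from diffusers import WanPipeline, WanTransformer3DModel
from diffusers.utils import export_to_video

transformer = WanTransformer3DModel.from_pretrained(
    "./checkpoints/step_1000/transformer", torch_dtype=torch.bfloat16
)
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

video = pipe(prompt="A dog running on a beach", num_frames=33).frames[0]
export_to_video(video, "output.mp4", fps=16)
```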
## NVIDIA Team
- Pranav Prashant Thombre, pthombre@nvidia.com
- Linnan Wang, linnanw@nvidia.com
- Alexandros Koumparoulis, akoumparouli@nvidia.com
## Resources
- [NeMo Automodel GitHub](https://github.com/NVIDIA-NeMo/Automodel)
- [Diffusion Fine-Tuning Guide](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/finetune.html)
- [Diffusion Dataset Preparation](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/dataset.html)
- [Diffusion Model Coverage](https://docs.nvidia.com/nemo/automodel/latest/model-coverage/diffusion.html)
- [NeMo Automodel for Transformers (LLM/VLM fine-tuning)](https://huggingface.co/docs/transformers/en/community_integrations/nemo_automodel_finetuning)
View File
@@ -0,0 +1,50 @@
# Discrete Token Diffusion (Experimental)
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
## LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
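The sketch below shows the idea of one block's refinement loop. It is conceptual only and not the `LLaDA2Pipeline` implementation, which additionally handles multiple blocks, chat templates, editing of already-committed tokens, and EOS early stopping.
```python
# Conceptual single-block refinement: start fully masked and commit the most
# confident predictions each step until no mask tokens remain.
import torch

def refine_block(logits_fn, block: torch.LongTensor, mask_id: int,
                 steps: int = 32, threshold: float = 0.7) -> torch.LongTensor:
    for _ in range(steps):
        masked = block == mask_id
        if not masked.any():
            break
        probs = logits_fn(block).softmax(dim=-1)   # [seq_len, vocab]
        conf, pred = probs.max(dim=-1)
        commit = masked & (conf >= threshold)
        if not commit.any():                       # always commit at least one token
            best = torch.where(masked, conf, torch.full_like(conf, -1.0)).argmax()
            commit = torch.zeros_like(masked)
            commit[best] = True
        block = torch.where(commit, pred, block)
    return block
```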
### Train
The training script uses a confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--text_column text \
--output_dir llada2-output \
--max_train_steps 1000 \
--prompt_length 32 \
--block_length 32 \
--lambda_conf 2.0 \
--conf_temperature 0.5
```
If you don't want to download a dataset, you can use random-token data:
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--output_dir llada2-output \
--use_dummy_data \
--num_dummy_samples 2048
```
### Sample
```bash
python examples/discrete_diffusion/sample_llada2.py \
--model_id inclusionAI/LLaDA2.1-mini \
--prompt "Write a short poem about the ocean." \
--gen_length 256 \
--num_inference_steps 32 \
--threshold 0.7 \
--editing_threshold 0.5 \
--max_post_steps 16 \
--use_chat_template \
--add_generation_prompt
```
View File
@@ -0,0 +1,263 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample script for LLaDA2-style discrete diffusion text generation.
This script demonstrates how to use the LLaDA2Pipeline for text generation
using block-wise iterative refinement.
Example usage:
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
from diffusers.hooks import apply_group_offloading
def main():
parser = argparse.ArgumentParser(
description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
)
parser.add_argument(
"--model_id",
type=str,
default="inclusionAI/LLaDA2.0-mini",
help="HuggingFace model ID or path to local model.",
)
parser.add_argument(
"--prompt",
type=str,
default="Why does Camus think that Sisyphus is happy?",
help="Text prompt to generate from.",
)
parser.add_argument(
"--gen_length",
type=int,
default=2048,
help="Number of tokens to generate.",
)
parser.add_argument(
"--block_length",
type=int,
default=32,
help="Size of each generation block.",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=32,
help="Number of refinement steps per block.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Sampling temperature (0.0 for greedy).",
)
parser.add_argument(
"--top_p",
type=float,
default=None,
help="Nucleus sampling probability threshold.",
)
parser.add_argument(
"--top_k",
type=int,
default=None,
help="Top-k sampling parameter.",
)
parser.add_argument(
"--threshold",
type=float,
default=0.95,
help="Confidence threshold for committing tokens.",
)
parser.add_argument(
"--editing_threshold",
type=float,
default=None,
help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
)
parser.add_argument(
"--max_post_steps",
type=int,
default=0,
help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
)
parser.add_argument(
"--sampling_method",
type=str,
default="multinomial",
choices=["auto", "greedy", "multinomial"],
help="Sampling method for block refinement.",
)
parser.add_argument(
"--eos_early_stop",
action="store_true",
help="Stop generation early when EOS token is generated.",
)
parser.add_argument(
"--use_chat_template",
action="store_true",
help="Use the tokenizer chat template for the prompt.",
)
parser.add_argument(
"--add_generation_prompt",
action="store_true",
help="Add the generation prompt when using the chat template.",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run inference on.",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
help="Model dtype.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility.",
)
parser.add_argument(
"--offload",
type=str,
default=None,
choices=["group", "sequential"],
help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Model revision (branch, tag, or commit hash) to load from the Hub.",
)
args = parser.parse_args()
# Parse dtype
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
torch_dtype = dtype_map[args.dtype]
print(f"Loading model: {args.model_id}")
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
# Load model with appropriate memory settings based on offload strategy
if args.offload == "group":
# For group offloading, load to CPU first then apply hooks
print("Using group offloading for memory efficiency...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Apply group offloading with CUDA streams for better performance
onload_device = torch.device(args.device)
offload_device = torch.device("cpu")
apply_group_offloading(
model,
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True,
)
elif args.offload == "sequential":
# For sequential offloading, load to CPU first
print("Using sequential CPU offloading (slower but lower memory)...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Sequential offloading will be applied via pipeline
else:
# Default: use device_map="auto" for automatic memory management
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
device_map="auto",
low_cpu_mem_usage=True,
revision=args.revision,
)
model.eval()
# Create pipeline
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
# Apply sequential CPU offload if requested
if args.offload == "sequential":
pipe.enable_sequential_cpu_offload()
# Set up generator for reproducibility
generator = None
if args.seed is not None:
generator = torch.Generator(device=args.device).manual_seed(args.seed)
print(f"\nPrompt: {args.prompt}")
print(
f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
)
print("-" * 50)
# Generate
output = pipe(
prompt=args.prompt,
use_chat_template=args.use_chat_template,
add_generation_prompt=args.add_generation_prompt,
gen_length=args.gen_length,
block_length=args.block_length,
num_inference_steps=args.num_inference_steps,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
threshold=args.threshold,
editing_threshold=args.editing_threshold,
max_post_steps=args.max_post_steps,
sampling_method=args.sampling_method,
eos_early_stop=args.eos_early_stop,
generator=generator,
)
print("\nGenerated text:")
print(output.texts[0])
print(f"\nGenerated {output.sequences.shape[1]} tokens")
if __name__ == "__main__":
main()
View File
@@ -0,0 +1,321 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
from dataclasses import asdict, dataclass
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
from diffusers import BlockRefinementScheduler
from diffusers.training_utils import compute_confidence_aware_loss
logger = get_logger(__name__)
@dataclass
class TrainConfig:
model_name_or_path: str
dataset_name: str
dataset_config_name: Optional[str]
text_column: str
cache_dir: Optional[str]
use_dummy_data: bool
num_dummy_samples: int
output_dir: str
seed: int
max_train_steps: int
checkpointing_steps: int
logging_steps: int
per_device_train_batch_size: int
gradient_accumulation_steps: int
learning_rate: float
weight_decay: float
lr_scheduler: str
lr_warmup_steps: int
max_length: int
prompt_length: int
block_length: int
lambda_conf: float
conf_temperature: float
def parse_args() -> TrainConfig:
parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
parser.add_argument("--dataset_name", type=str, default="wikitext")
parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
parser.add_argument("--text_column", type=str, default="text")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
parser.add_argument("--num_dummy_samples", type=int, default=2048)
parser.add_argument("--output_dir", type=str, default="block-refinement-output")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_train_steps", type=int, default=1000)
parser.add_argument("--checkpointing_steps", type=int, default=500)
parser.add_argument("--logging_steps", type=int, default=50)
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument(
"--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
)
parser.add_argument("--lr_warmup_steps", type=int, default=100)
parser.add_argument("--max_length", type=int, default=256)
parser.add_argument("--prompt_length", type=int, default=32)
parser.add_argument("--block_length", type=int, default=32)
parser.add_argument("--lambda_conf", type=float, default=2.0)
parser.add_argument("--conf_temperature", type=float, default=0.5)
args = parser.parse_args()
return TrainConfig(**vars(args))
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
texts = examples[text_column]
texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
class RandomTokenDataset(torch.utils.data.Dataset):
def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
self.num_samples = int(num_samples)
self.seq_len = int(seq_len)
self.vocab_size = int(vocab_size)
self.pad_token_id = int(pad_token_id)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
del idx
input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attention_mask}
def main():
cfg = parse_args()
if cfg.prompt_length >= cfg.max_length:
raise ValueError("`prompt_length` must be < `max_length`.")
if cfg.block_length <= 0:
raise ValueError("`block_length` must be > 0.")
project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
accelerator = Accelerator(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
project_config=project_config,
)
if accelerator.is_main_process:
os.makedirs(cfg.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
set_seed(cfg.seed)
logger.info("Training configuration: %s", asdict(cfg))
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.mask_token_id is None:
tokenizer.add_special_tokens({"mask_token": "[MASK]"})
load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
model.resize_token_embeddings(len(tokenizer))
if load_dtype == torch.float32:
model.to(dtype=torch.float32)
mask_token_id = int(tokenizer.mask_token_id)
if cfg.use_dummy_data:
dataset = RandomTokenDataset(
num_samples=cfg.num_dummy_samples,
seq_len=cfg.max_length,
vocab_size=len(tokenizer),
pad_token_id=int(tokenizer.pad_token_id),
)
train_dataloader = DataLoader(
dataset,
shuffle=True,
batch_size=cfg.per_device_train_batch_size,
drop_last=True,
)
else:
raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
if "train" not in raw_datasets:
raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")
with accelerator.main_process_first():
tokenized = raw_datasets["train"].map(
lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
batched=True,
remove_columns=raw_datasets["train"].column_names,
desc="Tokenizing",
)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
train_dataloader = DataLoader(
tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=cfg.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.max_train_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)
global_step = 0
model.train()
for _epoch in range(num_train_epochs):
for batch in train_dataloader:
with accelerator.accumulate(model):
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))
gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=cfg.prompt_length,
block_length=cfg.block_length,
mask_token_id=mask_token_id,
generator=gen,
)
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
)
logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
logits_rev = model(
input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
).logits
logits = logits.clone()
logits[..., mask_token_id] = torch.finfo(logits.dtype).min
logits_rev = logits_rev.clone()
logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min
valid = attention_mask.to(dtype=torch.bool)
masked = masked & valid
masked_rev = masked_rev & valid
labels = input_ids.clone()
labels[~masked] = -100
labels_rev = input_ids.clone()
labels_rev[~masked_rev] = -100
weights = masked.to(dtype=logits.dtype)
weights_rev = masked_rev.to(dtype=logits.dtype)
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits,
labels,
lambda_conf=cfg.lambda_conf,
temperature=cfg.conf_temperature,
per_token_weights=weights,
)
loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
logits_rev,
labels_rev,
lambda_conf=cfg.lambda_conf,
temperature=cfg.conf_temperature,
per_token_weights=weights_rev,
)
total_loss = loss + loss_rev
accelerator.backward(total_loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if accelerator.sync_gradients:
global_step += 1
if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
logger.info(
"step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
global_step,
total_loss.item(),
(loss_sft + loss_sft_rev).item(),
(loss_conf + loss_conf_rev).item(),
lr_scheduler.get_last_lr()[0],
)
print(
f"step={global_step} loss={total_loss.item():.4f} "
f"sft={(loss_sft + loss_sft_rev).item():.4f} "
f"conf={(loss_conf + loss_conf_rev).item():.4f} "
f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
)
if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
os.makedirs(save_dir, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
tokenizer.save_pretrained(save_dir)
if global_step >= cfg.max_train_steps:
break
if global_step >= cfg.max_train_steps:
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
final_dir = os.path.join(cfg.output_dir, "final")
os.makedirs(final_dir, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
tokenizer.save_pretrained(final_dir)
logger.info("Done.")
if __name__ == "__main__":
main()
View File
@@ -347,16 +347,17 @@ When LoRA was first adapted from language models to diffusion models, it was app
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore
applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers`, in which you can specify
the exact modules for LoRA training as a comma-separated string. Here are some examples of target modules you can provide:
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj"`
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out"`
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out,norm_out.linear,norm_out.proj_out"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
> [!NOTE]
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
> [!NOTE]
> In FLUX2, the q, k, and v projections are fused into a single linear layer named `attn.to_qkv_mlp_proj` within the single transformer block. Also, the attention output is just `attn.to_out`, not `attn.to_out.0`, since it is no longer a ModuleList as in the transformer blocks.
## Training Image-to-Image
View File
@@ -1256,7 +1256,13 @@ def main(args):
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
# train transformer_blocks and single_transformer_blocks
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
"to_qkv_mlp_proj",
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
]
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
View File
@@ -1206,7 +1206,13 @@ def main(args):
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
# train transformer_blocks and single_transformer_blocks
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
"to_qkv_mlp_proj",
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
]
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
View File
@@ -1249,7 +1249,13 @@ def main(args):
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
# train transformer_blocks and single_transformer_blocks
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
"to_qkv_mlp_proj",
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
]
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
View File
@@ -1200,7 +1200,13 @@ def main(args):
if args.lora_layers is not None:
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
else:
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
# train transformer_blocks and single_transformer_blocks
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
"to_qkv_mlp_proj",
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
]
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
View File
@@ -1105,7 +1105,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None
View File
@@ -1251,7 +1251,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None
View File
@@ -344,6 +344,8 @@ else:
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
"BlockRefinementScheduler",
"BlockRefinementSchedulerOutput",
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
@@ -580,6 +582,8 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LLaDA2Pipeline",
"LLaDA2PipelineOutput",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
@@ -1124,6 +1128,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
BlockRefinementScheduler,
BlockRefinementSchedulerOutput,
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
@@ -1339,6 +1345,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LLaDA2Pipeline,
LLaDA2PipelineOutput,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ConditionPipeline,
View File
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
key.requires_grad_(True)
value.requires_grad_(True)
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
with torch.enable_grad():
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
grad_out_t = grad_out.permute(0, 2, 1, 3)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
)
grad_query = grad_query_t.permute(0, 2, 1, 3)
grad_key = grad_key_t.permute(0, 2, 1, 3)
View File
@@ -285,6 +285,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
_import_structure["ltx"] = [
"LTXPipeline",
"LTXImageToVideoPipeline",
@@ -728,6 +729,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
from .ltx import (
LTXConditionPipeline,
View File
@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
View File
@@ -0,0 +1,491 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import torch
from tqdm.auto import tqdm
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...schedulers import BlockRefinementScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
>>> model_id = "inclusionAI/LLaDA2.1-mini"
>>> model = AutoModelForCausalLM.from_pretrained(
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
... )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
>>> scheduler = BlockRefinementScheduler()
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
>>> print(output.texts[0])
```
"""
@dataclass
class LLaDA2PipelineOutput(BaseOutput):
sequences: torch.LongTensor
texts: list[str] | None = None
class LLaDA2Pipeline(DiffusionPipeline):
r"""
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
vocab_size]`.
"""
model: Any
scheduler: BlockRefinementScheduler
tokenizer: Any
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
def __init__(
self,
model: Any,
scheduler: BlockRefinementScheduler,
tokenizer: Any | None = None,
):
super().__init__()
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
@property
def num_timesteps(self):
return self._num_timesteps
# --- Prompt encoding ---
def _prepare_input_ids(
self,
*,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
use_chat_template: bool,
add_generation_prompt: bool,
chat_template_kwargs: dict[str, Any] | None,
) -> torch.LongTensor:
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
if input_ids is not None:
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if input_ids.ndim != 2:
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
return input_ids
if self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and prompt is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if messages is None and prompt is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
chat_template_kwargs = chat_template_kwargs or {}
if messages is not None:
encoded = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
if isinstance(prompt, list):
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
encoded = self.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]
def check_inputs(
self,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
gen_length: int,
block_length: int,
num_inference_steps: int,
minimal_topk: int,
threshold: float,
sampling_method: str,
output_type: str,
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
callback_on_step_end_tensor_inputs: list[str] | None,
):
# Input source validation
if prompt is None and messages is None and input_ids is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
if prompt is not None and messages is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if input_ids is not None:
if input_ids.ndim not in (1, 2):
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
if prompt is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
# Generation parameter validation
if gen_length <= 0:
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
if block_length <= 0:
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
if minimal_topk <= 0:
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0):
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
if sampling_method not in {"auto", "greedy", "multinomial"}:
raise ValueError(
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
)
if output_type not in {"seq", "text"}:
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
# Callback validation
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str | list[str] | None = None,
messages: list[dict[str, str]] | None = None,
input_ids: torch.LongTensor | None = None,
use_chat_template: bool = True,
add_generation_prompt: bool = True,
gen_length: int = 2048,
block_length: int = 32,
num_inference_steps: int = 32,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "multinomial",
threshold: float = 0.7,
editing_threshold: float | None = 0.5,
max_post_steps: int = 16,
minimal_topk: int = 1,
eos_early_stop: bool = True,
eos_token_id: int | None = None,
mask_token_id: int | None = None,
generator: torch.Generator | None = None,
output_type: str = "text",
return_dict: bool = True,
callback_on_step_end: Callable[[int, int, dict], None]
| PipelineCallback
| MultiPipelineCallbacks
| None = None,
callback_on_step_end_tensor_inputs: list[str] | None = None,
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
"""
Generate text with block-wise refinement.
Args:
prompt (`str` or `List[str]`, *optional*):
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
available, the prompt is wrapped in a chat message before tokenization.
messages (`List[Dict[str, str]]`, *optional*):
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
when provided. Requires a tokenizer with `apply_chat_template`.
input_ids (`torch.LongTensor`, *optional*):
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
use_chat_template (`bool`, defaults to `True`):
Whether to wrap the prompt in a chat template.
add_generation_prompt (`bool`, defaults to `True`):
Whether to add the generation prompt when using chat templates.
gen_length (`int`):
Number of tokens to generate.
block_length (`int`):
Block size for refinement.
num_inference_steps (`int`):
Number of refinement steps per block.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`):
Confidence threshold for committing tokens.
editing_threshold (`float`, *optional*):
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
negative value to disable editing. Defaults to `0.5`.
max_post_steps (`int`):
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
used when `editing_threshold` is enabled. Defaults to `16`.
minimal_topk (`int`):
Minimum number of tokens to commit per step.
eos_early_stop (`bool`):
Whether to stop after committing EOS in a block.
eos_token_id (`int`, *optional*):
EOS token ID to use for early stopping.
mask_token_id (`int`, *optional*):
Mask token ID to use for the template.
generator (`torch.Generator`, *optional*):
RNG for sampling.
output_type (`str`, defaults to `"text"`):
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
token ID sequences only.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
timestep: int, callback_kwargs: Dict)`.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
`confidence`, `active_block`.
Examples:
"""
# 1. Check inputs early
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["block_x"]
self.check_inputs(
prompt=prompt,
messages=messages,
input_ids=input_ids,
gen_length=gen_length,
block_length=block_length,
num_inference_steps=num_inference_steps,
minimal_topk=minimal_topk,
threshold=threshold,
sampling_method=sampling_method,
output_type=output_type,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Prepare input IDs from prompt/messages/input_ids
prompt_ids = self._prepare_input_ids(
prompt=prompt,
messages=messages,
input_ids=input_ids,
use_chat_template=use_chat_template,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=None,
)
device = self._execution_device
if prompt_ids.ndim == 1:
prompt_ids = prompt_ids.unsqueeze(0)
prompt_ids = prompt_ids.to(device=device)
batch_size, prompt_length = prompt_ids.shape
if eos_token_id is None:
eos_token_id = self.eos_token_id
if mask_token_id is None:
mask_token_id = self.mask_token_id
if mask_token_id is None:
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 3. Build attention mask and position IDs
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
total_length = num_blocks * block_length
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
# 4. Prepare latents (fully masked sequence)
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
if prompt_length > 0:
x[:, :prompt_length] = prompt_ids
prefill_blocks = prompt_length // block_length
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
global_step = 0
# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
current_window_end = (num_block + 1) * block_length
block_x = x[:, :current_window_end]
block_attn_mask = attn_mask[:, :current_window_end]
block_position_ids = position_ids[:, :current_window_end]
# Identify which positions in the block are prompt (non-editable).
block_start_pos = num_block * block_length
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
if block_start_pos < prompt_length:
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
prompt_mask_in_block[:prompt_end_in_block] = True
post_steps = 0
step_idx = 0
should_continue = True
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = self.progress_bar(total=num_inference_steps)
while should_continue:
block_tokens = block_x[:, -block_length:]
masks_remaining = (block_tokens == mask_token_id).any()
if not masks_remaining:
post_steps += 1
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
block_logits = logits[:, -block_length:, :]
scheduler_output = self.scheduler.step(
model_output=block_logits,
timestep=step_idx,
sample=block_tokens,
mask_token_id=mask_token_id,
temperature=temperature,
top_p=top_p,
top_k=top_k,
sampling_method=sampling_method,
threshold=threshold,
editing_threshold=editing_threshold,
minimal_topk=minimal_topk,
prompt_mask=prompt_mask_in_block,
generator=generator,
return_dict=True,
)
transfer_index = scheduler_output.transfer_index
editing_transfer_index = scheduler_output.editing_transfer_index
final_transfer = transfer_index | editing_transfer_index
if final_transfer.any():
block_x[:, -block_length:] = scheduler_output.prev_sample
if eos_early_stop and eos_token_id is not None:
finished = self.scheduler.check_eos_finished(
cur_x=block_x,
sampled_tokens=scheduler_output.sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_token_id,
mask_token_id=mask_token_id,
prompt_length=prompt_length,
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
block_x = callback_outputs.pop("block_x", block_x)
global_step += 1
if masks_remaining:
step_idx += 1
progress_bar.update(1)
should_continue = self.scheduler.check_block_should_continue(
step_idx=step_idx,
masks_remaining=masks_remaining,
editing_enabled=editing_enabled,
editing_transfer_index=editing_transfer_index,
post_steps=post_steps,
max_post_steps=max_post_steps,
finished=finished,
)
progress_bar.close()
x[:, :current_window_end] = block_x
if eos_early_stop and finished.all():
break
# 6. Post-process output
generated = x[:, : prompt_length + gen_length]
sequences = generated[:, prompt_length:]
if eos_token_id is not None and batch_size == 1:
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
if len(eos_positions) > 0:
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
texts = None
if output_type == "text" and self.tokenizer is not None:
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
if not return_dict:
return sequences.to(device=device), texts
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]


@@ -1,6 +1,155 @@
# Copyright 2026 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Pre-trained sigma values for distilled model are taken from
# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
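# Sketch of how a fixed sigma schedule like the above is typically consumed (assumption:
# the LTX-2 pipelines hand it to a flow-matching scheduler; that wiring is not shown in this diff):
#
#   from diffusers import FlowMatchEulerDiscreteScheduler
#   scheduler = FlowMatchEulerDiscreteScheduler()
#   scheduler.set_timesteps(sigmas=DISTILLED_SIGMA_VALUES, device="cuda")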
# Default negative prompt from
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py#L131-L143
DEFAULT_NEGATIVE_PROMPT = (
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)
# System prompts for prompt enhancement
# https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt#L1
# Disable line-too-long rule in ruff to keep the prompts exactly the same (e.g. in terms of newlines)
# Supported in ruff>=0.15.0
# ruff: disable[E501]
T2V_DEFAULT_SYSTEM_PROMPT = """
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed
video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions,
actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested).
Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g.,
"ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with
voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if
unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is
requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological
scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits
or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video
generation.
#### Example Input: "A woman at a coffee shop talking on the phone" Output: Style: realistic with cinematic lighting.
In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the
window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone
to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle
clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a
soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd
love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her
chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully,
lowering the phone.
"""
# ruff: enable[E501]
# ruff: disable[E501]
I2V_DEFAULT_SYSTEM_PROMPT = """
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and
user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in
conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image
to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause
scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural
movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio
intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when
requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The
tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation
mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should
include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just
saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room
underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion
only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style
(optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio
descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output: Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a
cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the
counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine
hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
"""
# ruff: enable[E501]


@@ -23,20 +23,17 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
from __future__ import annotations
import copy
import dataclasses
import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import dataclass, is_dataclass
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable
from packaging import version
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import deprecate, is_torch_available, is_torchao_version, logging
if is_torch_available():
@@ -53,16 +50,6 @@ class QuantizationMethod(str, Enum):
MODELOPT = "modelopt"
if is_torchao_available():
from torchao.quantization.quant_primitives import MappingType
class TorchAoJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, MappingType):
return obj.name
return super().default(obj)
@dataclass
class QuantizationConfigMixin:
"""
@@ -446,49 +433,21 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`str` | AOBaseConfig):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
`int8_weight_only`, `int8_dynamic_activation_int8_weight`
- Shorthands: `int4wo`, `int4dq`, `int8wo`, `int8dq`
- **Floating point 8-bit quantization:**
- Full function names: `float8_weight_only`, `float8_dynamic_activation_float8_weight`,
`float8_static_activation_float8_weight`
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
be satisfied for a given shorthand notation.
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
- An AOBaseConfig instance: for more advanced configuration options.
quant_type (`AOBaseConfig`):
An `AOBaseConfig` subclass instance specifying the quantization type. See the [torchao
documentation](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) for
available config classes (e.g. `Int4WeightOnlyConfig`, `Int8WeightOnlyConfig`, `Float8WeightOnlyConfig`,
`Float8DynamicActivationFloat8WeightConfig`, etc.).
modules_to_not_convert (`list[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
kwargs (`dict[str, Any]`, *optional*):
The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization
supports two keyword arguments `group_size` and `inner_k_tiles` currently. More API examples and
documentation of arguments can be found in
https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques
Example:
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig
# AOBaseConfig-based configuration
from torchao.quantization import Int8WeightOnlyConfig
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
# String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
@@ -500,7 +459,7 @@ class TorchAoConfig(QuantizationConfigMixin):
def __init__(
self,
quant_type: str | "AOBaseConfig", # noqa: F821
quant_type: "AOBaseConfig", # noqa: F821
modules_to_not_convert: list[str] | None = None,
**kwargs,
) -> None:
@@ -508,102 +467,39 @@ class TorchAoConfig(QuantizationConfigMixin):
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
# When we load from serialized config, "quant_type_kwargs" will be the key
if "quant_type_kwargs" in kwargs:
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
else:
self.quant_type_kwargs = kwargs
self.post_init()
def post_init(self):
if not isinstance(self.quant_type, str):
if is_torchao_version("<=", "0.9.0"):
raise ValueError(
f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
)
if is_torchao_version("<", "0.15.0"):
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
from torchao.quantization.quant_api import AOBaseConfig
from torchao.quantization.quant_api import AOBaseConfig
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
elif isinstance(self.quant_type, str):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floatx_quant_type = self.quant_type.startswith("fp")
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
raise ValueError(
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
)
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
signature = inspect.signature(method)
all_kwargs = {
param.name
for param in signature.parameters.values()
if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
}
unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
if len(unsupported_kwargs) > 0:
raise ValueError(
f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
)
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()
if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
if is_dataclass(d["quant_type_kwargs"]["layout"]):
d["quant_type_kwargs"]["layout"] = [
d["quant_type_kwargs"]["layout"].__class__.__name__,
dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
]
if isinstance(d["quant_type_kwargs"]["layout"], list):
assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
else:
raise ValueError("layout must be a list")
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# For now we assume there is 1 config per Transformer, however in the future
# We may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
# For now we assume there is 1 config per Transformer, however in the future
# we may want to support a config per fqn.
# See: https://docs.pytorch.org/ao/stable/api_reference/generated/torchao.quantization.quantize_.html
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
return d
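# Illustrative serialization sketch for the AOBaseConfig-only `to_dict` path above
# (assumes torchao >= 0.15.0 is installed):
#
#   from torchao.quantization import Int8WeightOnlyConfig
#   cfg = TorchAoConfig(Int8WeightOnlyConfig())
#   payload = cfg.to_dict()  # payload["quant_type"] == {"default": <serialized Int8WeightOnlyConfig>}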
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
if not is_torchao_version(">=", "0.15.0"):
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
if isinstance(quant_type, str):
return cls(quant_type=quant_type, **config_dict)
# Check if we only have one key which is "default"
# In the future we may update this
assert len(quant_type) == 1 and "default" in quant_type, (
@@ -618,210 +514,13 @@ class TorchAoConfig(QuantizationConfigMixin):
return cls(quant_type=quant_type, **config_dict)
@classmethod
def _get_torchao_quant_type_to_method(cls):
r"""
Returns supported torchao quantization types with all commonly used notations.
"""
if is_torchao_available():
# TODO(aryan): Support sparsify
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
if is_torchao_version("<=", "0.14.1"):
from torchao.quantization import fpx_weight_only
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
def generate_float8dq_types(dtype: torch.dtype):
name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3"
types = {}
for granularity_cls in [PerTensor, PerRow]:
# Note: Activation and Weights cannot have different granularities
granularity_name = "tensor" if granularity_cls is PerTensor else "row"
types[f"float8dq_{name}_{granularity_name}"] = partial(
float8_dynamic_activation_float8_weight,
activation_dtype=dtype,
weight_dtype=dtype,
granularity=(granularity_cls(), granularity_cls()),
)
return types
def generate_fpx_quantization_types(bits: int):
if is_torchao_version("<=", "0.14.1"):
types = {}
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
return types
else:
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
"int4wo": int4_weight_only,
"int4_weight_only": int4_weight_only,
# int4 weight + int8 activation
"int4dq": int8_dynamic_activation_int4_weight,
"int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight,
}
INT8_QUANTIZATION_TYPES = {
# int8 weight + bfloat16/float16 activation
"int8wo": int8_weight_only,
"int8_weight_only": int8_weight_only,
# int8 weight + int8 activation
"int8dq": int8_dynamic_activation_int8_weight,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
}
# TODO(aryan): handle torch 2.2/2.3
FLOATX_QUANTIZATION_TYPES = {
# float8_e5m2 weight + bfloat16/float16 activation
"float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
"float8_weight_only": float8_weight_only,
"float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2),
# float8_e4m3 weight + bfloat16/float16 activation
"float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn),
# float8_e5m2 weight + float8 activation (dynamic)
"float8dq": float8_dynamic_activation_float8_weight,
"float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight,
# ===== Matrix multiplication is not supported in float8_e5m2 so the following errors out.
# However, changing activation_dtype=torch.float8_e4m3 might work here =====
# "float8dq_e5m2": partial(
# float8_dynamic_activation_float8_weight,
# activation_dtype=torch.float8_e5m2,
# weight_dtype=torch.float8_e5m2,
# ),
# **generate_float8dq_types(torch.float8_e5m2),
# ===== =====
# float8_e4m3 weight + float8 activation (dynamic)
"float8dq_e4m3": partial(
float8_dynamic_activation_float8_weight,
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
),
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
}
if is_torchao_version("<=", "0.14.1"):
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
"uint2wo": partial(uintx_weight_only, dtype=torch.uint2),
"uint3wo": partial(uintx_weight_only, dtype=torch.uint3),
"uint4wo": partial(uintx_weight_only, dtype=torch.uint4),
"uint5wo": partial(uintx_weight_only, dtype=torch.uint5),
"uint6wo": partial(uintx_weight_only, dtype=torch.uint6),
"uint7wo": partial(uintx_weight_only, dtype=torch.uint7),
# "uint8wo": partial(uintx_weight_only, dtype=torch.uint8), # uint8 quantization is not supported
}
QUANTIZATION_TYPES = {}
QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
if cls._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
return QUANTIZATION_TYPES
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)
@staticmethod
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
elif torch.xpu.is_available():
return True
else:
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
"""Create the appropriate quantization method based on configuration."""
if not isinstance(self.quant_type, str):
return self.quant_type
else:
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
and quant_type_kwargs.get("layout", None) is None
):
if torch.xpu.is_available():
if version.parse(importlib.metadata.version("torchao")) >= version.parse(
"0.11.0"
) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain
quant_type_kwargs["layout"] = Int4XPULayout()
quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
else:
raise ValueError(
"TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
)
else:
from torchao.dtypes import Int4CPULayout
quant_type_kwargs["layout"] = Int4CPULayout()
return methods[self.quant_type](**quant_type_kwargs)
return self.quant_type
def __repr__(self):
r"""
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
```
TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "uint4wo",
"quant_type_kwargs": {
"group_size": 32
}
}
```
"""
config_dict = self.to_dict()
return (
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
)
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
@dataclass


@@ -20,7 +20,6 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
import importlib
import re
import types
from fnmatch import fnmatch
from typing import TYPE_CHECKING, Any
from packaging import version
@@ -114,7 +113,7 @@ if (
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.7.0")
and is_torchao_version(">=", "0.15.0")
):
_update_torch_safe_globals()
@@ -169,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
torchao_version = version.parse(importlib.metadata.version("torchao"))
if torchao_version < version.parse("0.15.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)
self.offload = False
@@ -199,13 +198,13 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
)
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
config_name = self.quantization_config.quant_type.__class__.__name__
is_int_quant = config_name.startswith("Int") or config_name.startswith("Uint")
if is_int_quant and torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for integer quantization, but "
f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
)
if torch_dtype is None:
# We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
@@ -219,45 +218,16 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
return torch_dtype
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
from accelerate.utils import CustomDtype
if isinstance(quant_type, str):
if quant_type.startswith("int8"):
# Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
return torch.int8
elif quant_type.startswith("int4"):
return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16
quant_type = self.quantization_config.quant_type
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
elif is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
if size_digit == "4":
return CustomDtype.INT4
else:
return torch.int8
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
@@ -337,29 +307,14 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
- Use a division factor of 8 for int4 weights
- Use a division factor of 4 for int8 weights
"""
# Original mapping for non-AOBaseConfig types
# For the uint types, this is a best guess. Once these types become more used
# we can look into their nuances.
if is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
if size_digit == "4":
return 8
else:
return 4
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
quant_type = self.quantization_config.quant_type
for pattern, target_dtype in map_to_target_dtype.items():
if fnmatch(quant_type, pattern):
return target_dtype
raise ValueError(f"Unsupported quant_type: {quant_type!r}")
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
if size_digit == "4":
return 8
else:
return 4
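# Illustrative mapping implied by the size-digit logic above:
#   Int4WeightOnlyConfig                      -> division factor 8
#   Int8WeightOnlyConfig                      -> division factor 4
#   Float8DynamicActivationFloat8WeightConfig -> division factor 4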
def _process_model_before_weight_loading(
self,
@@ -415,9 +370,17 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
return _is_torchao_serializable
_TRAINABLE_QUANTIZATION_CONFIGS = (
"Int8WeightOnlyConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8StaticActivationInt8WeightConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
)
@property
def is_trainable(self):
return self.quantization_config.quant_type.startswith("int8")
return self.quantization_config.quant_type.__class__.__name__ in self._TRAINABLE_QUANTIZATION_CONFIGS
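# For example, TorchAoConfig(Int8WeightOnlyConfig()) is trainable, while an
# Int4WeightOnlyConfig-based config is not (its class name is not in the allowlist above).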
@property
def is_compileable(self) -> bool:


@@ -40,6 +40,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -145,6 +146,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
from .scheduling_amused import AmusedScheduler
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler


@@ -0,0 +1,460 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class BlockRefinementSchedulerOutput(BaseOutput):
"""
Output class for block refinement scheduling.
Args:
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Updated block tokens after the current refinement step.
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were committed (mask-filling).
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were edited (non-mask replacement).
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Sampled token IDs from the model logits.
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
Probabilities of the sampled tokens.
"""
prev_sample: torch.LongTensor
transfer_index: torch.BoolTensor
editing_transfer_index: torch.BoolTensor
sampled_tokens: torch.LongTensor
sampled_probs: torch.Tensor
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
"""
Scheduler for block-wise iterative refinement (commit-by-confidence).
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
the number of refinement steps.
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
"""
order = 1
@register_to_config
def __init__(
self,
block_length: int = 32,
num_inference_steps: int = 32,
threshold: float = 0.95,
editing_threshold: float | None = None,
minimal_topk: int = 1,
):
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
self._transfer_schedule: torch.LongTensor | None = None
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
device=device if device is not None else "cpu"
)
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
if num_inference_steps <= 0:
return torch.zeros((0,), dtype=torch.long)
base = block_length // num_inference_steps
remainder = block_length % num_inference_steps
out = torch.full((num_inference_steps,), base, dtype=torch.long)
out[:remainder] += 1
return out
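# Example: block_length=32 and num_inference_steps=10 gives base=3 with remainder=2,
# so the per-step commit schedule is [4, 4, 3, 3, 3, 3, 3, 3, 3, 3] (sums to 32).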
# --- SAR sampling utilities ---
@staticmethod
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
"""Nucleus (top-p) logit filtering."""
if top_p is None or top_p >= 1.0:
return logits
if not (0.0 < top_p <= 1.0):
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumulative_probs = sorted_probs.cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > float(top_p)
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
return filtered
@staticmethod
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
"""Top-k logit filtering."""
if top_k is None or top_k <= 0:
return logits
if top_k >= logits.shape[-1]:
return logits
values, _ = torch.topk(logits, k=top_k, dim=-1)
min_keep = values[..., -1, None]
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
@staticmethod
def _sample_from_logits(
logits: torch.Tensor,
*,
temperature: float,
top_k: int | None,
top_p: float | None,
generator: torch.Generator | None,
use_multinomial: bool,
) -> tuple[torch.LongTensor, torch.Tensor]:
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
if temperature < 0:
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
vocab_size = logits.shape[-1]
flat_logits = logits.reshape(-1, vocab_size)
if temperature == 0.0 or not use_multinomial:
probs = torch.softmax(flat_logits.float(), dim=-1)
token = flat_logits.argmax(dim=-1, keepdim=True)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
scaled = flat_logits
if temperature != 1.0:
scaled = flat_logits / temperature
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
probs = torch.softmax(filtered.float(), dim=-1)
token = torch.multinomial(probs, num_samples=1, generator=generator)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
def step(
self,
model_output: torch.Tensor,
timestep: int | torch.Tensor,
sample: torch.LongTensor,
*,
mask_token_id: int,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "auto",
threshold: float | None = None,
editing_threshold: float | None = None,
minimal_topk: int | None = None,
prompt_mask: torch.BoolTensor | None = None,
generator: torch.Generator | None = None,
return_dict: bool = True,
) -> (
BlockRefinementSchedulerOutput
| tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
):
"""
Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
ones.
Args:
model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
Raw logits from the model for the current block.
timestep (`int` or `torch.Tensor`):
Current step index within the block's refinement schedule.
sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Current block token IDs (contains mask tokens for uncommitted positions).
mask_token_id (`int`):
Token ID used for masked positions.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`, *optional*):
Confidence threshold for committing tokens. Defaults to config value.
editing_threshold (`float`, *optional*):
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
config value.
minimal_topk (`int`, *optional*):
Minimum tokens to commit per step. Defaults to config value.
prompt_mask (`torch.BoolTensor`, *optional*):
Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
generator (`torch.Generator`, *optional*):
RNG for sampling.
return_dict (`bool`):
Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
"""
if threshold is None:
threshold = float(self.config.threshold)
if editing_threshold is None:
editing_threshold = self.config.editing_threshold
if minimal_topk is None:
minimal_topk = self.config.minimal_topk
# Sample from logits
use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
sampled_tokens, sampled_probs = self._sample_from_logits(
model_output,
temperature=temperature,
top_k=top_k,
top_p=top_p,
generator=generator,
use_multinomial=use_multinomial,
)
batch_size, block_length = sample.shape
active_block = sample == mask_token_id
masks_remaining = active_block.any()
if isinstance(timestep, torch.Tensor):
step_index = int(timestep.item())
else:
step_index = int(timestep)
# --- Mask-filling transfer ---
transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if masks_remaining and self._transfer_schedule is not None:
clamped_step = min(step_index, len(self._transfer_schedule) - 1)
num_to_transfer = int(self._transfer_schedule[clamped_step].item())
confidence = torch.where(
active_block,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
for b in range(batch_size):
high_conf = confidence[b] > threshold
if high_conf.sum().item() >= num_to_transfer:
transfer_index[b] = high_conf
else:
k = min(num_to_transfer, int(active_block[b].sum().item()))
if k > 0:
_, idx = torch.topk(confidence[b], k=k)
transfer_index[b, idx] = True
# --- Editing transfer (non-mask, non-prompt positions) ---
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if editing_enabled:
if prompt_mask is None:
prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
editable = (~active_block) & (~prompt_mask.unsqueeze(0))
editing_conf = torch.where(
editable,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
high_conf_edit = editing_conf > float(editing_threshold)
token_changed = sampled_tokens != sample
editing_transfer_index = high_conf_edit & token_changed & editable
# Apply transfers
final_transfer = transfer_index | editing_transfer_index
prev_sample = sample.clone()
if final_transfer.any():
prev_sample[final_transfer] = sampled_tokens[final_transfer]
if not return_dict:
return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
return BlockRefinementSchedulerOutput(
prev_sample=prev_sample,
transfer_index=transfer_index,
editing_transfer_index=editing_transfer_index,
sampled_tokens=sampled_tokens,
sampled_probs=sampled_probs,
)
@staticmethod
def check_eos_finished(
cur_x: torch.LongTensor,
sampled_tokens: torch.LongTensor,
final_transfer: torch.BoolTensor,
finished: torch.BoolTensor,
eos_token_id: int,
mask_token_id: int,
prompt_length: int,
) -> torch.BoolTensor:
"""
Update per-batch finished flags when EOS tokens are committed.
Args:
cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Current full sequence including all blocks up to the current window.
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Tokens sampled by the scheduler in this step.
final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Combined mask of committed and edited positions.
finished (`torch.BoolTensor` of shape `(batch_size,)`):
Current per-batch finished flags.
eos_token_id (`int`):
EOS token ID.
mask_token_id (`int`):
Mask token ID.
prompt_length (`int`):
Number of prompt tokens at the start of the sequence.
Returns:
`torch.BoolTensor`: Updated finished flags.
"""
batch_size = cur_x.shape[0]
for b in range(batch_size):
if finished[b]:
continue
eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
if not eos_in_commits:
continue
eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
if len(eos_pos[0]) == 0:
continue
eos_pos = int(eos_pos[0][0].item())
if prompt_length >= eos_pos:
continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
finished[b] = True
return finished
def check_block_should_continue(
self,
step_idx: int,
masks_remaining: bool,
editing_enabled: bool,
editing_transfer_index: torch.BoolTensor,
post_steps: int,
max_post_steps: int,
finished: torch.BoolTensor,
) -> bool:
"""
Determine whether the inner refinement loop should continue for the current block.
Args:
step_idx (`int`):
Current refinement step index within this block.
masks_remaining (`bool`):
Whether any mask tokens remain in the block.
editing_enabled (`bool`):
Whether editing mode is active.
editing_transfer_index (`torch.BoolTensor`):
Which tokens were edited in this step.
post_steps (`int`):
Number of post-mask editing steps taken so far.
max_post_steps (`int`):
Maximum allowed post-mask editing steps.
finished (`torch.BoolTensor`):
Per-batch finished flags (from EOS detection).
Returns:
`bool`: `True` if refinement should continue, `False` to break.
"""
if finished.all():
return False
if not masks_remaining and not editing_enabled:
return False
if not masks_remaining and not editing_transfer_index.any():
return False
if masks_remaining and step_idx >= self.num_inference_steps:
return False
if not masks_remaining and post_steps > max_post_steps:
return False
return True
def add_noise(
self,
original_samples: torch.LongTensor,
attention_mask: torch.LongTensor,
*,
prompt_length: int,
block_length: int,
mask_token_id: int,
generator: torch.Generator | None = None,
) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
"""
Apply the forward (noising) process for semi-autoregressive block masking.
For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
`mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
one are the unmasked positions in the other.
Args:
original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Clean token IDs.
attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Padding mask (1 for valid, 0 for padding).
prompt_length (`int`):
Number of leading prompt tokens to keep unmasked.
block_length (`int`):
Block size for masking.
mask_token_id (`int`):
Token ID to use for masked positions.
generator (`torch.Generator`, *optional*):
RNG for reproducibility.
Returns:
`tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
`(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
corresponding boolean masks.
"""
batch_size, seq_len = original_samples.shape
device = original_samples.device
noisy = original_samples.clone()
noisy_rev = original_samples.clone()
masked = torch.zeros_like(original_samples, dtype=torch.bool)
masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)
valid = attention_mask.to(dtype=torch.bool)
for block_start in range(prompt_length, seq_len, block_length):
block_end = min(seq_len, block_start + block_length)
seg_len = block_end - block_start
if seg_len <= 0:
continue
p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
seg = seg & valid[:, block_start:block_end]
seg_rev = (~seg) & valid[:, block_start:block_end]
masked[:, block_start:block_end] = seg
masked_rev[:, block_start:block_end] = seg_rev
noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
return noisy, noisy_rev, masked, masked_rev
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]


@@ -11,6 +11,7 @@ from typing import Any, Iterable
import numpy as np
import torch
import torch.nn.functional as F
if getattr(torch, "distributed", None) is not None:
@@ -109,6 +110,92 @@ def compute_snr(noise_scheduler, timesteps):
return snr
def compute_confidence_aware_loss(
logits: torch.Tensor,
labels: torch.Tensor,
*,
lambda_conf: float = 0.0,
temperature: float = 1.0,
per_token_weights: torch.Tensor | None = None,
ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes a confidence-aware training loss for token classification-style heads.
This loss combines:
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
Args:
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
are excluded from both losses.
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
sharpen the distribution and change the strength of the confidence gradients.
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
"""
if logits.ndim < 2:
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
if labels.shape != logits.shape[:-1]:
raise ValueError(
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
)
if temperature <= 0:
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
valid = labels.ne(ignore_index)
if per_token_weights is None:
weights = torch.ones_like(labels, dtype=logits.dtype)
else:
if per_token_weights.shape != labels.shape:
raise ValueError(
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
)
weights = per_token_weights.to(dtype=logits.dtype)
# Supervised CE (optionally weighted).
vocab_size = logits.shape[-1]
per_token_nll = F.cross_entropy(
logits.reshape(-1, vocab_size),
labels.reshape(-1),
reduction="none",
ignore_index=ignore_index,
).reshape_as(labels)
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
# Confidence loss: penalize entropy only where prediction is already correct.
if lambda_conf == 0.0:
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
return loss_sft, loss_sft, loss_conf
with torch.no_grad():
pred = logits.argmax(dim=-1)
correct = valid & pred.eq(labels)
scaled_logits = logits.float()
if temperature != 1.0:
scaled_logits = scaled_logits / float(temperature)
probs = torch.softmax(scaled_logits, dim=-1)
eps = torch.finfo(probs.dtype).tiny
log_probs = torch.log(probs.clamp_min(eps))
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
loss = loss_sft + float(lambda_conf) * loss_conf
return loss, loss_sft, loss_conf
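# Illustrative call (shapes only; values are random):
#
#   logits = torch.randn(2, 16, 1000, requires_grad=True)
#   labels = torch.randint(0, 1000, (2, 16))
#   loss, loss_sft, loss_conf = compute_confidence_aware_loss(logits, labels, lambda_conf=0.1)
#   loss.backward()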
def resolve_interpolation_mode(interpolation_type: str):
"""
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The


@@ -2518,6 +2518,36 @@ class AmusedScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BlockRefinementScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CMStochasticIterativeScheduler(metaclass=DummyObject):
_backends = ["torch"]


@@ -2222,6 +2222,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2PipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LongCatImageEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -44,9 +44,9 @@ class AutoencoderTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):

View File

@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _context_parallel_backward_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
):
"""Worker function for context parallel backward pass testing."""
try:
# Set up distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Get device configuration
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]
# Initialize process group
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# Set device for this process
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")
# Create model in training mode
model = model_class(**init_dict)
model.to(device)
model.train()
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
# Run forward and backward pass
output = model(**inputs_on_device, return_dict=False)[0]
loss = output.sum()
loss.backward()
# Check that backward actually produced at least one valid gradient
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
# Only rank 0 reports results
if rank == 0:
return_dict["status"] = "success"
return_dict["has_valid_grads"] = bool(has_valid_grads)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
@@ -204,6 +262,51 @@ class ContextParallelTesterMixin:
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
# Find a free port for distributed communication
master_port = _find_free_port()
# Use multiprocessing manager for cross-process communication
manager = mp.Manager()
return_dict = manager.dict()
# Spawn worker processes
mp.spawn(
_context_parallel_backward_worker,
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
)
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_backward_batch_inputs(self, cp_type):
self.test_context_parallel_backward(cp_type, batch_size=2)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[

View File

@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import (
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -63,8 +62,7 @@ if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
import torchao.quantization as _torchao_quantization
class LoRALayer(torch.nn.Module):
@@ -806,9 +804,9 @@ class TorchAoConfigMixin:
"""
TORCHAO_QUANT_TYPES = {
"int4wo": {"quant_type": "int4_weight_only"},
"int8wo": {"quant_type": "int8_weight_only"},
"int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
"int4wo": "Int4WeightOnlyConfig",
"int8wo": "Int8WeightOnlyConfig",
"int8dq": "Int8DynamicActivationInt8WeightConfig",
}
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
@@ -817,8 +815,13 @@ class TorchAoConfigMixin:
"int8dq": 1.5,
}
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = TorchAoConfig(**config_kwargs)
@staticmethod
def _get_quant_config(config_name):
config_cls = getattr(_torchao_quantization, config_name)
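# e.g. config_name="Int8WeightOnlyConfig" (as stored in TORCHAO_QUANT_TYPES above) resolves to
# torchao.quantization.Int8WeightOnlyConfig before being wrapped in TorchAoConfig below.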
return TorchAoConfig(config_cls())
def _create_quantized_model(self, config_name, **extra_kwargs):
config = self._get_quant_config(config_name)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs["device_map"] = str(torch_device)

View File

@@ -0,0 +1,242 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ConditionalPipelineBlocks,
InputParam,
ModularPipelineBlocks,
)
class TextToImageBlock(ModularPipelineBlocks):
model_name = "text2img"
@property
def inputs(self):
return [InputParam(name="prompt")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "text-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "text2img"
self.set_block_state(state, block_state)
return components, state
class ImageToImageBlock(ModularPipelineBlocks):
model_name = "img2img"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "image-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "img2img"
self.set_block_state(state, block_state)
return components, state
class InpaintBlock(ModularPipelineBlocks):
model_name = "inpaint"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "inpaint workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "inpaint"
self.set_block_state(state, block_state)
return components, state
class ConditionalImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return "Conditional image blocks for testing"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = None # no default; block can be skipped
@property
def description(self):
return "Optional conditional blocks (skippable)"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None
class AutoImageBlocks(AutoPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image", None]
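# Trigger inputs are checked in order ("mask" takes precedence over "image"); the trailing None
# makes "text2img" the fallback when no trigger is provided (see the selection tests below).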
@property
def description(self):
return "Auto image blocks for testing"
class TestConditionalPipelineBlocksSelectBlock:
def test_select_block_with_mask(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="something") == "inpaint"
def test_select_block_with_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(image="something") == "img2img"
def test_select_block_with_mask_and_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
def test_select_block_no_triggers_returns_none(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block() is None
def test_select_block_explicit_none_values(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask=None, image=None) is None
class TestConditionalPipelineBlocksWorkflowSelection:
def test_default_workflow_when_no_triggers(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks()
assert execution is not None
assert isinstance(execution, TextToImageBlock)
def test_mask_trigger_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_image_trigger_selects_img2img(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
def test_mask_and_image_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True, image=True)
assert isinstance(execution, InpaintBlock)
def test_skippable_block_returns_none(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks()
assert execution is None
def test_skippable_block_still_selects_when_triggered(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestAutoPipelineBlocksSelectBlock:
def test_auto_select_mask(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m") == "inpaint"
def test_auto_select_image(self):
blocks = AutoImageBlocks()
assert blocks.select_block(image="i") == "img2img"
def test_auto_select_default(self):
blocks = AutoImageBlocks()
# No trigger -> returns None -> falls back to default (text2img)
assert blocks.select_block() is None
def test_auto_select_priority_order(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
class TestAutoPipelineBlocksWorkflowSelection:
def test_auto_default_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks()
assert isinstance(execution, TextToImageBlock)
def test_auto_mask_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_auto_image_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestConditionalPipelineBlocksStructure:
def test_block_names_accessible(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
def test_sub_block_types(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert isinstance(sub["inpaint"], InpaintBlock)
assert isinstance(sub["img2img"], ImageToImageBlock)
assert isinstance(sub["text2img"], TextToImageBlock)
def test_description(self):
blocks = ConditionalImageBlocks()
assert "Conditional" in blocks.description

View File

@@ -10,11 +10,6 @@ from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -25,7 +20,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -498,117 +492,6 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:

View File

@@ -24,14 +24,18 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
ConditionalPipelineBlocks,
InputParam,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
OutputParam,
PipelineState,
SequentialPipelineBlocks,
WanModularPipeline,
)
from diffusers.utils import logging
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device
def _create_tiny_model_dir(model_dir):
@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
assert output_prompt.startswith("Modular diffusers + ")
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
@slow
@nightly
@require_torch

View File

@@ -18,7 +18,7 @@ import unittest
import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.training_utils import set_seed
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
from ..testing_utils import slow
@@ -85,3 +85,47 @@ class TrainingTests(unittest.TestCase):
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
def test_confidence_aware_loss(self):
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
labels = torch.tensor([[0, 0]])
weights = torch.tensor([[1.0, 2.0]])
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=0.0, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss, loss_sft))
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
lambda_conf = 0.25
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
)
# Manual expected values for the small 2-class case.
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
1, 2
)
expected_sft = (per_token_nll * weights).sum() / weights.sum()
pred = logits.argmax(dim=-1)
correct = pred.eq(labels)
log_probs = torch.log_softmax(logits.float(), dim=-1)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
weights * correct.to(weights.dtype)
).sum().clamp_min(1)
expected = expected_sft + lambda_conf * expected_conf
self.assertTrue(torch.allclose(loss_sft, expected_sft))
self.assertTrue(torch.allclose(loss_conf, expected_conf))
self.assertTrue(torch.allclose(loss, expected))
# Temperature affects only the confidence term.
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))

View File

@@ -0,0 +1,245 @@
import unittest
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
class _DummyModelOutput:
def __init__(self, logits):
self.logits = logits
class _DummyCausalLM(torch.nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.vocab_size = int(vocab_size)
self.register_buffer("_device_anchor", torch.empty(0))
@property
def dtype(self):
return torch.float32
@property
def device(self):
return self._device_anchor.device
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
batch_size, seq_len = input_ids.shape
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
# Make confidence vary with token position so top-k commits are deterministic.
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
return _DummyModelOutput(logits=logits)
def _make_pipeline(tokenizer=None):
model = _DummyCausalLM(vocab_size=32)
scheduler = BlockRefinementScheduler()
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
class LLaDA2PipelineTest(unittest.TestCase):
def test_pipeline_runs(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
out = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=24,
block_length=8,
num_inference_steps=8,
temperature=0.0,
threshold=2.0, # force top-k commits
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
eos_token_id=None,
output_type="seq",
)
self.assertEqual(out.sequences.shape, (2, 24))
self.assertFalse((out.sequences == 31).any().item())
def test_pipeline_return_tuple(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
sequences, texts = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
return_dict=False,
)
self.assertEqual(sequences.shape, (1, 16))
self.assertIsNone(texts)
def test_output_type_seq(self):
"""output_type='seq' should return sequences but no texts."""
pipe = _make_pipeline().to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
)
self.assertIsNotNone(out.sequences)
self.assertEqual(out.sequences.shape, (1, 16))
self.assertIsNone(out.texts)
def test_output_type_text_without_tokenizer(self):
"""output_type='text' without a tokenizer should return texts=None."""
pipe = _make_pipeline(tokenizer=None).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNone(out.texts)
def test_output_type_text_with_tokenizer(self):
"""output_type='text' with a tokenizer should return decoded texts."""
tok = type(
"Tok",
(),
{
"eos_token_id": None,
"mask_token_id": 31,
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
},
)()
pipe = _make_pipeline(tokenizer=tok).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNotNone(out.texts)
self.assertEqual(len(out.texts), 1)
self.assertTrue(out.texts[0].startswith("decoded_"))
def test_output_type_invalid_raises(self):
"""Invalid output_type should raise ValueError."""
pipe = _make_pipeline().to("cpu")
with self.assertRaises(ValueError):
pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
mask_token_id=31,
output_type="invalid",
)
def test_prepare_input_ids_from_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertTrue(torch.equal(result, ids))
def test_prepare_input_ids_from_1d_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([1, 2, 3], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertEqual(result.shape, (1, 3))
def test_prepare_input_ids_no_tokenizer_raises(self):
pipe = _make_pipeline(tokenizer=None)
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
pipe = _make_pipeline()
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=[{"role": "user", "content": "hi"}],
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_neither_raises(self):
pipe = _make_pipeline()
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to(torch_device)
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@ class PipelineTesterMixin:
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

View File

@@ -14,13 +14,11 @@
# limitations under the License.
import gc
import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -55,6 +53,20 @@ from ..test_torch_compile_utils import QuantCompileTests
enable_full_determinism()
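# Helper used below to gate float8 quantization cases, which require an XPU device or a CUDA
# device with compute capability >= 8.9.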
def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
if is_torch_available():
import torch
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major == 8:
return minor >= 9
return major >= 9
elif torch.xpu.is_available():
return True
return False
if is_torch_available():
import torch
import torch.nn as nn
@@ -64,75 +76,56 @@ if is_torch_available():
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import (
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
)
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
from torchao.quantization import Int8WeightOnlyConfig
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Makes sure the config format is properly set
"""
quantization_config = TorchAoConfig("int4_weight_only")
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(version=2))
torchao_orig_config = quantization_config.to_dict()
for key in torchao_orig_config:
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key])
self.assertIn("quant_type", torchao_orig_config)
self.assertIn("quant_method", torchao_orig_config)
def test_post_init_check(self):
"""
Test kwargs validations in TorchAoConfig
Test that non-AOBaseConfig types are rejected
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "is not supported"):
_ = TorchAoConfig("uint8")
_ = TorchAoConfig(Int4WeightOnlyConfig())
with self.assertRaises(TypeError):
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
_ = TorchAoConfig("int4_weight_only", group_size1=32)
with self.assertRaises(TypeError):
_ = TorchAoConfig(42)
def test_repr(self):
"""
Check that there is no error in the repr
"""
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": [
"conv"
],
"quant_method": "torchao",
"quant_type": "int4_weight_only",
"quant_type_kwargs": {
"group_size": 8
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC)
expected_repr = """TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "int4dq",
"quant_type_kwargs": {
"act_mapping_type": "SYMMETRIC",
"group_size": 64
}
}""".replace(" ", "").replace("\n", "")
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "")
self.assertEqual(quantization_repr, expected_repr)
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(version=2), modules_to_not_convert=["conv"])
quantization_repr = repr(quantization_config)
self.assertIn("TorchAoConfig", quantization_repr)
self.assertIn("torchao", quantization_repr)
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
@@ -234,79 +227,30 @@ class TorchAoTest(unittest.TestCase):
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
(Int4WeightOnlyConfig(version=2), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])),
(Int8DynamicActivationIntxWeightConfig(version=2), np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])),
(Int8WeightOnlyConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
(Int8DynamicActivationInt8WeightConfig(version=2), np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
(IntxWeightOnlyConfig(dtype=torch.uint4, group_size=16, version=2), np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
(IntxWeightOnlyConfig(dtype=torch.uint7, group_size=16, version=2), np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if _is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
# =====
# The following lead to an internal torch error:
# RuntimeError: mat2 shape (32x4 must be divisible by 16
# Skip these for now; TODO(aryan): investigate later
# ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
# Cutlass fails to initialize for below
# ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])),
# =====
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e5m2), np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
# The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
quant_kwargs.update({"group_size": 16})
quantization_config = TorchAoConfig(
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
)
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice, model_id)
@unittest.skip("Skipping floatx quantization tests")
def test_floatx_quantization(self):
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
self._test_quant_type(
quantization_config,
np.array(
[
0.4648,
0.5195,
0.5547,
0.4180,
0.4434,
0.6445,
0.4316,
0.4531,
0.5625,
]
),
model_id,
)
else:
# Make sure the correct error is thrown
with self.assertRaisesRegex(ValueError, "Please downgrade"):
quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"])
def test_int4wo_quant_bfloat16_conversion(self):
"""
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization.
"""
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -361,7 +305,7 @@ class TorchAoTest(unittest.TestCase):
else:
expected_slice = expected_slice_offload
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -385,7 +329,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=64))
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
@@ -406,7 +350,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
quantization_config = TorchAoConfig(Int8WeightOnlyConfig(), modules_to_not_convert=["transformer_blocks.0"])
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -422,7 +366,7 @@ class TorchAoTest(unittest.TestCase):
quantized_layer = quantized_model_with_not_convert.proj_out
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
quantization_config = TorchAoConfig("int8_weight_only")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -436,7 +380,7 @@ class TorchAoTest(unittest.TestCase):
self.assertTrue(size_quantized < size_quantized_with_not_convert)
def test_training(self):
quantization_config = TorchAoConfig("int8_weight_only")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
@@ -470,7 +414,7 @@ class TorchAoTest(unittest.TestCase):
def test_torch_compile(self):
r"""Test that verifies if torch.compile works with torchao quantization."""
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
quantization_config = TorchAoConfig("int8_weight_only")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config, model_id=model_id)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device)
@@ -491,11 +435,15 @@ class TorchAoTest(unittest.TestCase):
memory footprint of the converted model and the class type of the linear layers of the converted models
"""
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
transformer_int4wo = self.get_dummy_components(TorchAoConfig(Int4WeightOnlyConfig()), model_id=model_id)[
"transformer"
]
transformer_int4wo_gs32 = self.get_dummy_components(
TorchAoConfig("int4wo", group_size=32), model_id=model_id
TorchAoConfig(Int4WeightOnlyConfig(group_size=32)), model_id=model_id
)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
"transformer"
]
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
# Will not quantize all the layers by default because the model weight shapes are not divisible by group_size=64
@@ -553,20 +501,22 @@ class TorchAoTest(unittest.TestCase):
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
del transformer_bf16
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig(Int8WeightOnlyConfig()), model_id=model_id)[
"transformer"
]
transformer_int8wo.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
def test_wrong_config(self):
with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
self.get_dummy_components(TorchAoConfig("int42"))
def test_sequential_cpu_offload(self):
r"""
A test that checks if inference runs as expected when sequential cpu offloading is enabled.
"""
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.enable_sequential_cpu_offload()
@@ -574,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
@require_torchao_version_greater_or_equal("0.9.0")
@require_torchao_version_greater_or_equal("0.15.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
@@ -587,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -595,8 +545,8 @@ class TorchAoSerializationTest(unittest.TestCase):
gc.collect()
backend_empty_cache(torch_device)
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
def get_dummy_model(self, quant_type, device=None):
quantization_config = TorchAoConfig(quant_type)
quantized_model = FluxTransformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
@@ -632,8 +582,8 @@ class TorchAoSerializationTest(unittest.TestCase):
"timestep": timestep,
}
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device)
def _test_original_model_expected_slice(self, quant_type, expected_slice):
quantized_model = self.get_dummy_model(quant_type, torch_device)
inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
@@ -641,8 +591,8 @@ class TorchAoSerializationTest(unittest.TestCase):
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device)
def _check_serialization_expected_slice(self, quant_type, expected_slice, device):
quantized_model = self.get_dummy_model(quant_type, device)
with tempfile.TemporaryDirectory() as tmp_dir:
quantized_model.save_pretrained(tmp_dir, safe_serialization=False)
@@ -662,43 +612,42 @@ class TorchAoSerializationTest(unittest.TestCase):
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
def test_int_a8w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
quant_type = Int8DynamicActivationInt8WeightConfig()
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
def test_int_a16w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
quant_type = Int8WeightOnlyConfig()
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
def test_int_a8w8_cpu(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
quant_type = Int8DynamicActivationInt8WeightConfig()
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
def test_int_a16w8_cpu(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
quant_type = Int8WeightOnlyConfig()
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = "cpu"
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
@require_torchao_version_greater_or_equal("0.9.0")
def test_aobase_config(self):
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
quant_type = Int8WeightOnlyConfig()
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
self._test_original_model_expected_slice(quant_type, expected_slice)
self._check_serialization_expected_slice(quant_type, expected_slice, device)
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
@@ -744,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -817,29 +766,25 @@ class SlowTorchAoTests(unittest.TestCase):
def test_quantization(self):
# fmt: off
QUANTIZATION_TYPES_TO_TEST = [
("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
(Int8WeightOnlyConfig(), np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])),
(Int8DynamicActivationInt8WeightConfig(), np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
]
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
if _is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
(Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn), np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
])
if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"):
QUANTIZATION_TYPES_TO_TEST.extend([
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
])
# fmt: on
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
for quant_config, expected_slice in QUANTIZATION_TYPES_TO_TEST:
quantization_config = TorchAoConfig(quant_type=quant_config, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice)
gc.collect()
backend_empty_cache(torch_device)
backend_synchronize(torch_device)
def test_serialization_int8wo(self):
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()
@@ -876,7 +821,7 @@ class SlowTorchAoTests(unittest.TestCase):
def test_memory_footprint_int4wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 6.0
quantization_config = TorchAoConfig("int4wo")
quantization_config = TorchAoConfig(Int4WeightOnlyConfig())
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -891,7 +836,7 @@ class SlowTorchAoTests(unittest.TestCase):
def test_memory_footprint_int8wo(self):
# The original checkpoints are in bf16 and about 24 GB
expected_memory_in_gb = 12.0
quantization_config = TorchAoConfig("int8wo")
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
cache_dir = None
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -906,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):

View File

@@ -0,0 +1,470 @@
import tempfile
import unittest
import torch
from diffusers import BlockRefinementScheduler
class BlockRefinementSchedulerTest(unittest.TestCase):
def get_scheduler(self, **kwargs):
config = {
"block_length": 32,
"num_inference_steps": 8,
"threshold": 0.95,
"editing_threshold": None,
"minimal_topk": 1,
}
config.update(kwargs)
return BlockRefinementScheduler(**config)
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
"""Create logits where softmax of the target token has approximately the given probability."""
batch_size, block_length = target_probs.shape
logits = torch.zeros(batch_size, block_length, vocab_size)
# Give position t's "predicted" token (index t % (vocab_size - 1)) a logit proportional to the desired probability
for b in range(batch_size):
for t in range(block_length):
p = target_probs[b, t].item()
if p > 0:
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
return logits
def test_set_timesteps(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
self.assertEqual(scheduler.num_inference_steps, 8)
self.assertEqual(len(scheduler.timesteps), 8)
self.assertEqual(scheduler.timesteps[0].item(), 7)
self.assertEqual(scheduler.timesteps[-1].item(), 0)
def test_set_timesteps_invalid(self):
scheduler = self.get_scheduler()
with self.assertRaises(ValueError):
scheduler.set_timesteps(0)
def test_get_num_transfer_tokens_even(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
self.assertEqual(schedule.sum().item(), 32)
self.assertEqual(len(schedule), 8)
self.assertTrue((schedule == 4).all().item())
def test_get_num_transfer_tokens_remainder(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
self.assertEqual(schedule.sum().item(), 10)
self.assertEqual(len(schedule), 3)
self.assertEqual(schedule[0].item(), 4)
self.assertEqual(schedule[1].item(), 3)
self.assertEqual(schedule[2].item(), 3)
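# (Illustrative sketch, not part of the diffed test file.) The two schedule tests
# above pin down the expected split: divide block_length across num_inference_steps
# as evenly as possible and front-load the remainder. A hypothetical reference
# helper consistent with those assertions:
def _reference_transfer_schedule(block_length: int, num_inference_steps: int) -> torch.Tensor:
    base, remainder = divmod(block_length, num_inference_steps)
    schedule = torch.full((num_inference_steps,), base, dtype=torch.long)
    schedule[:remainder] += 1  # spread the leftover tokens over the earliest steps
    return schedule
# _reference_transfer_schedule(32, 8) -> tensor([4, 4, 4, 4, 4, 4, 4, 4])
# _reference_transfer_schedule(10, 3) -> tensor([4, 3, 3])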
def test_transfer_schedule_created_on_set_timesteps(self):
scheduler = self.get_scheduler(block_length=16)
scheduler.set_timesteps(4)
self.assertIsNotNone(scheduler._transfer_schedule)
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
def test_save_load_config_round_trip(self):
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
with tempfile.TemporaryDirectory() as tmpdir:
scheduler.save_config(tmpdir)
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
self.assertEqual(loaded.config.block_length, 64)
self.assertEqual(loaded.config.threshold, 0.8)
self.assertEqual(loaded.config.editing_threshold, 0.5)
self.assertEqual(loaded.config.minimal_topk, 2)
def test_from_config(self):
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
self.assertEqual(new_scheduler.config.block_length, 16)
self.assertEqual(new_scheduler.config.threshold, 0.7)
def test_step_commits_tokens(self):
"""Verify that step() commits mask tokens based on confidence."""
scheduler = self.get_scheduler(block_length=8)
scheduler.set_timesteps(2)
batch_size, block_length, vocab_size = 1, 8, 32
mask_id = 31
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
# Create logits where confidence decreases with position
logits = torch.zeros(batch_size, block_length, vocab_size)
for i in range(block_length):
logits[0, i, i] = 10.0 - i # decreasing confidence
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
threshold=0.95,
return_dict=True,
)
# With 8 tokens and 2 steps, first step should commit 4 tokens
committed = out.transfer_index[0].sum().item()
self.assertEqual(committed, 4)
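# (Illustrative sketch, not part of the diffed test file.) Reproducing the commit
# count asserted above with hypothetical names: the most confident masked positions
# are committed first, and 8 masked tokens over 2 steps gives 4 commits per step.
_logits = torch.zeros(1, 8, 32)
for _i in range(8):
    _logits[0, _i, _i] = 10.0 - _i                     # same decreasing-confidence setup as the test
_confidence = torch.softmax(_logits, dim=-1).max(dim=-1).values
_top_positions = _confidence.topk(4, dim=-1).indices   # positions 0..3, i.e. the 4 committed tokens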
def test_step_no_editing_by_default(self):
"""Without editing_threshold, no non-mask tokens should be changed."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 10.0 # predict token 15 for all positions
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=None,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index.any().item())
self.assertFalse(out.transfer_index[0, 0].item())
self.assertFalse(out.transfer_index[0, 1].item())
def test_step_editing_replaces_tokens(self):
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
# Token 0: predict token 15 (different from the current token 10) with a very high logit
logits[0, 0, 15] = 20.0
# Token 1: predict 20 (same as current)
logits[0, 1, 20] = 20.0
# Mask tokens
logits[0, 2, 5] = 5.0
logits[0, 3, 6] = 5.0
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
return_dict=True,
)
# Token 0 should be edited (different prediction, high confidence)
self.assertTrue(out.editing_transfer_index[0, 0].item())
# Token 1 should NOT be edited (same prediction)
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_prompt_mask_prevents_editing(self):
"""Prompt positions should never be edited even with editing enabled."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 20.0
prompt_mask = torch.tensor([True, True, False, False])
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
prompt_mask=prompt_mask,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index[0, 0].item())
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_return_tuple(self):
"""Verify tuple output when return_dict=False."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.full((1, 4), 31, dtype=torch.long)
logits = torch.randn(1, 4, vocab_size)
result = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
return_dict=False,
)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 5)
def test_step_batched(self):
"""Verify step works with batch_size > 1."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
batch_size, vocab_size = 3, 32
mask_id = 31
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
logits = torch.randn(batch_size, 4, vocab_size)
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
return_dict=True,
)
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
def test_check_block_should_continue_finished(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([True, True])
result = scheduler.check_block_should_continue(
step_idx=0,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_no_masks_no_edits(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=5,
masks_remaining=False,
editing_enabled=True,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=1,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_steps_exhausted(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=8,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_eos_finished_marks_batch(self):
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, token, eos, mask, mask]
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
final_transfer = torch.tensor([[False, False, False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertTrue(finished[0].item())
def test_check_eos_finished_ignores_when_masks_before_eos(self):
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertFalse(finished[0].item())
def test_check_eos_finished_already_finished(self):
"""Already-finished batches should stay finished."""
mask_id, eos_id = 99, 2
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, False]])
finished = torch.tensor([True])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=2,
)
self.assertTrue(finished[0].item())
def test_add_noise(self):
scheduler = self.get_scheduler(block_length=4)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
mask_token_id = 99
gen = torch.Generator().manual_seed(42)
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=2,
block_length=4,
mask_token_id=mask_token_id,
generator=gen,
)
# Prompt positions should never be masked
self.assertFalse(masked[0, 0].item())
self.assertFalse(masked[0, 1].item())
self.assertFalse(masked_rev[0, 0].item())
self.assertFalse(masked_rev[0, 1].item())
# Noisy should have mask_token_id where masked is True
self.assertTrue((noisy[masked] == mask_token_id).all().item())
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
# masked and masked_rev should be complementary within valid non-prompt positions
non_prompt = torch.zeros_like(masked)
non_prompt[0, 2:] = True
combined = masked | masked_rev
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
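# (Illustrative sketch, not part of the diffed test file.) The property checked by
# test_add_noise, with hypothetical names and a random split assumed for illustration:
# every non-prompt position is masked in at least one of the two noisy copies (here,
# exactly one), and the prompt is never masked in either copy.
_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
_prompt_length, _mask_token_id = 2, 99
_gen = torch.Generator().manual_seed(0)
_masked = torch.zeros_like(_input_ids, dtype=torch.bool)
_masked[:, _prompt_length:] = torch.rand(1, _input_ids.shape[1] - _prompt_length, generator=_gen) < 0.5
_masked_rev = ~_masked
_masked_rev[:, :_prompt_length] = False                 # prompt positions stay unmasked in both copies
_noisy = _input_ids.masked_fill(_masked, _mask_token_id)
_noisy_rev = _input_ids.masked_fill(_masked_rev, _mask_token_id)
assert (_masked | _masked_rev)[:, _prompt_length:].all()  # non-prompt positions covered by the two masks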
class TestTopPFiltering(unittest.TestCase):
def test_top_p_filtering(self):
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
def test_top_p_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
self.assertTrue(torch.equal(result, logits))
def test_top_p_filtering_one(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
self.assertTrue(torch.equal(result, logits))
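# (Illustrative sketch, not part of the diffed test file.) One common nucleus (top-p)
# filter that is consistent with the assertions above; the helper name is
# hypothetical, not the scheduler's actual implementation.
def _reference_top_p_filter(logits: torch.Tensor, top_p) -> torch.Tensor:
    if top_p is None or top_p >= 1.0:
        return logits                                   # None / 1.0 pass the logits through unchanged
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    remove_sorted = cumulative > top_p
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False                       # always keep the single most likely token
    remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
    return logits.masked_fill(remove, torch.finfo(logits.dtype).min)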
class TestTopKFiltering(unittest.TestCase):
def test_top_k_filtering(self):
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
def test_top_k_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_zero(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_large_k(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
self.assertTrue(torch.equal(result, logits))
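# (Illustrative sketch, not part of the diffed test file.) A top-k filter that is
# consistent with the assertions above; hypothetical helper name.
def _reference_top_k_filter(logits: torch.Tensor, top_k) -> torch.Tensor:
    if not top_k or top_k >= logits.size(-1):
        return logits                                   # None, 0, or k >= vocab size: passthrough
    kth_value = logits.topk(top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, torch.finfo(logits.dtype).min)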
class TestSampleFromLogits(unittest.TestCase):
def test_greedy_sampling(self):
logits = torch.tensor([[1.0, 5.0, 2.0]])
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 1)
self.assertEqual(tokens.shape, (1,))
self.assertEqual(probs.shape, (1,))
def test_multinomial_sampling(self):
logits = torch.tensor([[0.0, 100.0, -100.0]])
gen = torch.Generator().manual_seed(42)
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=1.0,
top_k=None,
top_p=None,
generator=gen,
use_multinomial=True,
)
self.assertEqual(tokens.item(), 1)
def test_temperature_scaling(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
tokens, _ = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.01,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 2)
def test_negative_temperature_raises(self):
logits = torch.tensor([[1.0, 2.0]])
with self.assertRaises(ValueError):
BlockRefinementScheduler._sample_from_logits(
logits,
temperature=-1.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
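# (Illustrative sketch, not part of the diffed test file.) The sampling behavior the
# tests above pin down, with hypothetical names: temperature 0 means greedy argmax,
# negative temperature is rejected, and multinomial sampling draws from the
# temperature-scaled softmax.
def _reference_sample(logits, temperature, use_multinomial=False, generator=None):
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    scaled = logits if temperature == 0.0 else logits / temperature
    probs = scaled.softmax(dim=-1)
    if use_multinomial and temperature > 0.0:
        tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
    else:
        tokens = probs.argmax(dim=-1)
    return tokens, probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)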
if __name__ == "__main__":
unittest.main()