mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-17 05:55:59 +08:00
Compare commits
7 Commits
tests-cond
...
add-more-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55a7c29305 | ||
|
|
4fac905586 | ||
|
|
e5aa719241 | ||
|
|
4bc1c59a67 | ||
|
|
764f7ede33 | ||
|
|
8d0f3e1ba8 | ||
|
|
094caf398f |
33
.ai/AGENTS.md
Normal file
33
.ai/AGENTS.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Diffusers — Agent Guide
|
||||
|
||||
## Coding style
|
||||
|
||||
Strive to write code as simple and explicit as possible.
|
||||
|
||||
- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.
|
||||
- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
|
||||
- Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.
|
||||
|
||||
---
|
||||
|
||||
### Dependencies
|
||||
- No new mandatory dependency without discussion (e.g. `einops`)
|
||||
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
|
||||
|
||||
## Code formatting
|
||||
- `make style` and `make fix-copies` should be run as the final step before opening a PR
|
||||
|
||||
### Copied Code
|
||||
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
|
||||
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
|
||||
- Remove the header to intentionally break the link
|
||||
|
||||
### Models
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations.
|
||||
- See the **model-integration** skill for the attention pattern, pipeline rules, and test setup details.
|
||||
|
||||
## Skills
|
||||
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
|
||||
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
|
||||
272
.ai/skills/model-integration/SKILL.md
Normal file
272
.ai/skills/model-integration/SKILL.md
Normal file
@@ -0,0 +1,272 @@
|
||||
---
|
||||
name: Model Integration
|
||||
description: >
|
||||
Patterns for integrating a new model into diffusers: standard pipeline setup,
|
||||
modular pipeline conversion, file structure templates, checklists, and conventions.
|
||||
Trigger: adding a new model, converting to modular pipeline, setting up file structure.
|
||||
---
|
||||
|
||||
## Standard Pipeline Integration
|
||||
|
||||
### File structure for a new model
|
||||
|
||||
```
|
||||
src/diffusers/
|
||||
models/transformers/transformer_<model>.py # The core model
|
||||
schedulers/scheduling_<model>.py # If model needs a custom scheduler
|
||||
pipelines/<model>/
|
||||
__init__.py
|
||||
pipeline_<model>.py # Main pipeline
|
||||
pipeline_<model>_<variant>.py # Variant pipelines (e.g. pyramid, distilled)
|
||||
pipeline_output.py # Output dataclass
|
||||
loaders/lora_pipeline.py # LoRA mixin (add to existing file)
|
||||
|
||||
tests/
|
||||
models/transformers/test_models_transformer_<model>.py
|
||||
pipelines/<model>/test_<model>.py
|
||||
lora/test_lora_layers_<model>.py
|
||||
|
||||
docs/source/en/api/
|
||||
pipelines/<model>.md
|
||||
models/<model>_transformer3d.md # or appropriate name
|
||||
```
|
||||
|
||||
### Integration checklist
|
||||
|
||||
- [ ] Implement transformer model with `from_pretrained` support
|
||||
- [ ] Implement or reuse scheduler
|
||||
- [ ] Implement pipeline(s) with `__call__` method
|
||||
- [ ] Add LoRA support if applicable
|
||||
- [ ] Register all classes in `__init__.py` files (lazy imports)
|
||||
- [ ] Write unit tests (model, pipeline, LoRA)
|
||||
- [ ] Write docs
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test parity with reference implementation (see `parity-testing` skill)
|
||||
|
||||
### Attention pattern
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Pipeline rules
|
||||
|
||||
- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline` which will be a part of the core codebase (`src`).
|
||||
|
||||
### Test setup
|
||||
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
|
||||
### Common diffusers conventions
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
|
||||
---
|
||||
|
||||
## Modular Pipeline Conversion
|
||||
|
||||
### When to use
|
||||
|
||||
Modular pipelines break a monolithic `__call__` into composable blocks. Convert when:
|
||||
- The model supports multiple workflows (T2V, I2V, V2V, etc.)
|
||||
- Users need to swap guidance strategies (CFG, CFG-Zero*, PAG)
|
||||
- You want to share blocks across pipeline variants
|
||||
|
||||
### File structure
|
||||
|
||||
```
|
||||
src/diffusers/modular_pipelines/<model>/
|
||||
__init__.py # Lazy imports
|
||||
modular_pipeline.py # Pipeline class (tiny, mostly config)
|
||||
encoders.py # Text encoder + image/video VAE encoder blocks
|
||||
before_denoise.py # Pre-denoise setup blocks
|
||||
denoise.py # The denoising loop blocks
|
||||
decoders.py # VAE decode block
|
||||
modular_blocks_<model>.py # Block assembly (AutoBlocks)
|
||||
```
|
||||
|
||||
### Block types decision tree
|
||||
|
||||
```
|
||||
Is this a single operation?
|
||||
YES -> ModularPipelineBlocks (leaf block)
|
||||
|
||||
Does it run multiple blocks in sequence?
|
||||
YES -> SequentialPipelineBlocks
|
||||
Does it iterate (e.g. chunk loop)?
|
||||
YES -> LoopSequentialPipelineBlocks
|
||||
|
||||
Does it choose ONE block based on which input is present?
|
||||
Is the selection 1:1 with trigger inputs?
|
||||
YES -> AutoPipelineBlocks (simple trigger mapping)
|
||||
NO -> ConditionalPipelineBlocks (custom select_block method)
|
||||
```
|
||||
|
||||
### Build order (easiest first)
|
||||
|
||||
1. `decoders.py` -- Takes latents, runs VAE decode, returns images/videos
|
||||
2. `encoders.py` -- Takes prompt, returns prompt_embeds. Add image/video VAE encoder if needed
|
||||
3. `before_denoise.py` -- Timesteps, latent prep, noise setup. Each logical operation = one block
|
||||
4. `denoise.py` -- The hardest. Convert guidance to guider abstraction
|
||||
|
||||
### Key pattern: Guider abstraction
|
||||
|
||||
Original pipeline has guidance baked in:
|
||||
```python
|
||||
for i, t in enumerate(timesteps):
|
||||
noise_pred = self.transformer(latents, prompt_embeds, ...)
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(latents, negative_prompt_embeds, ...)
|
||||
noise_pred = noise_uncond + scale * (noise_pred - noise_uncond)
|
||||
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
||||
```
|
||||
|
||||
Modular pipeline separates concerns:
|
||||
```python
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
||||
}
|
||||
|
||||
for i, t in enumerate(timesteps):
|
||||
components.guider.set_state(step=i, num_inference_steps=num_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = {k: getattr(batch, k) for k in guider_inputs}
|
||||
context_name = getattr(batch, components.guider._identifier_key)
|
||||
with components.transformer.cache_context(context_name):
|
||||
batch.noise_pred = components.transformer(
|
||||
hidden_states=latents, timestep=timestep,
|
||||
return_dict=False, **cond_kwargs, **shared_kwargs,
|
||||
)[0]
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
noise_pred = components.guider(guider_state)[0]
|
||||
latents = components.scheduler.step(noise_pred, t, latents, generator=generator)[0]
|
||||
```
|
||||
|
||||
### Key pattern: Chunk loops for video models
|
||||
|
||||
Use `LoopSequentialPipelineBlocks` for outer loop:
|
||||
```python
|
||||
class ChunkDenoiseStep(LoopSequentialPipelineBlocks):
|
||||
block_classes = [PrepareChunkStep, NoiseGenStep, DenoiseInnerStep, UpdateStep]
|
||||
```
|
||||
|
||||
Note: blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, k)` where `k` is the loop iteration index.
|
||||
|
||||
### Key pattern: Workflow selection
|
||||
|
||||
```python
|
||||
class AutoDenoise(ConditionalPipelineBlocks):
|
||||
block_classes = [V2VDenoiseStep, I2VDenoiseStep, T2VDenoiseStep]
|
||||
block_trigger_inputs = ["video_latents", "image_latents"]
|
||||
default_block_name = "text2video"
|
||||
```
|
||||
|
||||
### Standard InputParam/OutputParam templates
|
||||
|
||||
```python
|
||||
# Inputs
|
||||
InputParam.template("prompt") # str, required
|
||||
InputParam.template("negative_prompt") # str, optional
|
||||
InputParam.template("image") # PIL.Image, optional
|
||||
InputParam.template("generator") # torch.Generator, optional
|
||||
InputParam.template("num_inference_steps") # int, default=50
|
||||
InputParam.template("latents") # torch.Tensor, optional
|
||||
|
||||
# Outputs
|
||||
OutputParam.template("prompt_embeds")
|
||||
OutputParam.template("negative_prompt_embeds")
|
||||
OutputParam.template("image_latents")
|
||||
OutputParam.template("latents")
|
||||
OutputParam.template("videos")
|
||||
OutputParam.template("images")
|
||||
```
|
||||
|
||||
### ComponentSpec patterns
|
||||
|
||||
```python
|
||||
# Heavy models - loaded from pretrained
|
||||
ComponentSpec("transformer", YourTransformerModel)
|
||||
ComponentSpec("vae", AutoencoderKL)
|
||||
|
||||
# Lightweight objects - created inline from config
|
||||
ComponentSpec("guider", ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 7.5}),
|
||||
default_creation_method="from_config")
|
||||
```
|
||||
|
||||
### Conversion checklist
|
||||
|
||||
- [ ] Read original pipeline's `__call__` end-to-end, map stages
|
||||
- [ ] Write test scripts (reference + target) with identical seeds
|
||||
- [ ] Create file structure under `modular_pipelines/<model>/`
|
||||
- [ ] Write decoder block (simplest)
|
||||
- [ ] Write encoder blocks (text, image, video)
|
||||
- [ ] Write before_denoise blocks (timesteps, latent prep, noise)
|
||||
- [ ] Write denoise block with guider abstraction (hardest)
|
||||
- [ ] Create pipeline class with `default_blocks_name`
|
||||
- [ ] Assemble blocks in `modular_blocks_<model>.py`
|
||||
- [ ] Wire up `__init__.py` with lazy imports
|
||||
- [ ] Run `make style`
|
||||
- [ ] Test all workflows for parity with reference
|
||||
|
||||
---
|
||||
|
||||
## Weight Conversion Tips
|
||||
|
||||
<!-- TODO: Add concrete examples as we encounter them. Common patterns to watch for:
|
||||
- Fused QKV weights that need splitting into separate Q, K, V
|
||||
- Scale/shift ordering differences (reference stores [shift, scale], diffusers expects [scale, shift])
|
||||
- Weight transpositions (linear stored as transposed conv, or vice versa)
|
||||
- Interleaved head dimensions that need reshaping
|
||||
- Bias terms absorbed into different layers
|
||||
Add each with a before/after code snippet showing the conversion. -->
|
||||
387
.ai/skills/parity-testing/SKILL.md
Normal file
387
.ai/skills/parity-testing/SKILL.md
Normal file
@@ -0,0 +1,387 @@
|
||||
---
|
||||
name: Parity Testing
|
||||
description: >
|
||||
Testing pipeline parity between reference and diffusers implementations:
|
||||
checkpoint mechanism, stage tests (encode/decode/denoise), injection debugging,
|
||||
visual comparison, comparison utilities, and 18 common pitfalls.
|
||||
Trigger: debugging parity, writing conversion tests, investigating divergence.
|
||||
---
|
||||
|
||||
## Testing Pipeline Parity
|
||||
|
||||
Applies to any conversion: research repo -> diffusers, standard -> modular, or research repo -> modular.
|
||||
|
||||
## Principles
|
||||
|
||||
1. **Don't combine structural changes with behavioral changes.** For research repo -> diffusers, you must restructure code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) -- that's unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks ugly. For standard -> modular, this is stricter: copy loop logic verbatim and only restructure into blocks. In both cases, clean up in a separate commit after parity is confirmed.
|
||||
|
||||
2. **Match the reference noise generation first.** The way initial noise/latents are constructed (seed handling, generator, randn call order) often differs between reference and diffusers. If the noise doesn't match, nothing downstream will match, making it impossible to isolate other bugs. Strategy: in the first implementation, replicate the reference's exact noise construction to get parity. After everything else is confirmed working, swap to diffusers-style noise generation as a final step.
|
||||
|
||||
3. **Test from the start.** Have component tests ready BEFORE writing conversion code. Test bottom-up: components first, then pipeline stages, then e2e.
|
||||
|
||||
## Test strategy
|
||||
|
||||
**Component parity (CPU/float32) -- always run, as you build.**
|
||||
Test each component before assembling the pipeline. This is the foundation -- if individual pieces are wrong, the pipeline can't be right. Each component in isolation, strict max_diff < 1e-3. Two modes:
|
||||
- **Fresh**: convert from checkpoint weights, compare against reference (catches conversion bugs)
|
||||
- **Saved**: load from saved model on disk, compare against reference (catches stale saves)
|
||||
|
||||
Keep component test scripts around -- you will need to re-run them during pipeline debugging with different inputs or config values. For example, you might initially test a transformer with random inputs, then later re-run it with actual pipeline-captured inputs to confirm it still matches. Having the test ready and easy to modify saves significant time.
|
||||
|
||||
Template -- one self-contained script per component, reference and diffusers side-by-side:
|
||||
```python
|
||||
@torch.inference_mode()
|
||||
def test_my_component(mode="fresh", model_path=None):
|
||||
# 1. Deterministic input
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
x = torch.randn(1, 3, 64, 64, generator=gen, dtype=torch.float32)
|
||||
|
||||
# 2. Reference: load from checkpoint, run, free
|
||||
ref_model = ReferenceModel.from_config(config)
|
||||
ref_model.load_state_dict(load_weights("prefix"), strict=True)
|
||||
ref_model = ref_model.float().eval()
|
||||
ref_out = ref_model(x).clone()
|
||||
del ref_model
|
||||
|
||||
# 3. Diffusers: fresh (convert weights) or saved (from_pretrained)
|
||||
if mode == "fresh":
|
||||
diff_model = convert_my_component(load_weights("prefix"))
|
||||
else:
|
||||
diff_model = DiffusersModel.from_pretrained(model_path, torch_dtype=torch.float32)
|
||||
diff_model = diff_model.float().eval()
|
||||
diff_out = diff_model(x)
|
||||
del diff_model
|
||||
|
||||
# 4. Compare in same script -- no saving to disk
|
||||
max_diff = (ref_out - diff_out).abs().max().item()
|
||||
assert max_diff < 1e-3, f"FAIL: max_diff={max_diff:.2e}"
|
||||
```
|
||||
Key points: (a) both sides in one script -- never split into separate scripts that save/load intermediates, (b) deterministic input via seeded generator, (c) load one model at a time to fit in CPU RAM, (d) `.clone()` the reference output before deleting the model.
|
||||
|
||||
**E2E visual (GPU/bfloat16) -- once the pipeline is assembled.**
|
||||
Both pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, you're done -- no need for deeper testing.
|
||||
|
||||
**Pipeline stage tests -- only if E2E fails and you need to isolate the bug.**
|
||||
For small models, run on CPU/float32 for strict comparison. For large models (e.g. 22B params), CPU/float32 is impractical -- use GPU/bfloat16 with `enable_model_cpu_offload()` and relax tolerances (max_diff < 1e-1 for bfloat16 is typical for passing tests; cosine similarity > 0.9999 is a good secondary check).
|
||||
|
||||
Test encode and decode stages first -- they're simpler and bugs there are easier to fix. Only debug the denoising loop if encode and decode both pass.
|
||||
|
||||
The challenge: pipelines are monolithic `__call__` methods -- you can't just call "the encode part". The solution is a checkpoint mechanism that lets you stop, save, or inject tensors at named locations inside the pipeline.
|
||||
|
||||
**Add a `_checkpoints` argument to both pipelines.**
|
||||
|
||||
The Checkpoint class is minimal:
|
||||
```python
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
save: bool = False # capture variables into ckpt.data
|
||||
stop: bool = False # halt pipeline after this point
|
||||
load: bool = False # inject ckpt.data into local variables
|
||||
data: dict = field(default_factory=dict)
|
||||
```
|
||||
|
||||
The pipeline accepts an optional `dict[str, Checkpoint]`. Place checkpoint calls at boundaries between pipeline stages -- after each encoder, before the denoising loop (capture all loop inputs), after each loop iteration, after the loop (capture final latents before decode). Here's a skeleton showing where they go:
|
||||
|
||||
```python
|
||||
def __call__(self, prompt, ..., _checkpoints=None):
|
||||
# --- text encoding ---
|
||||
prompt_embeds = self.text_encoder(prompt)
|
||||
_maybe_checkpoint(_checkpoints, "text_encoding", {
|
||||
"prompt_embeds": prompt_embeds,
|
||||
})
|
||||
|
||||
# --- prepare latents, sigmas, positions ---
|
||||
latents = self.prepare_latents(...)
|
||||
sigmas = self.scheduler.sigmas
|
||||
# ...
|
||||
|
||||
_maybe_checkpoint(_checkpoints, "preloop", {
|
||||
"latents": latents,
|
||||
"sigmas": sigmas,
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"prompt_attention_mask": prompt_attention_mask,
|
||||
"video_coords": video_coords,
|
||||
# capture EVERYTHING the loop needs -- every tensor the transformer
|
||||
# forward() receives. Missing even one variable here means you can't
|
||||
# tell if it's the source of divergence during denoise debugging.
|
||||
})
|
||||
|
||||
# --- denoising loop ---
|
||||
for i, t in enumerate(timesteps):
|
||||
noise_pred = self.transformer(latents, t, prompt_embeds, ...)
|
||||
latents = self.scheduler.step(noise_pred, t, latents)[0]
|
||||
|
||||
_maybe_checkpoint(_checkpoints, f"after_step_{i}", {
|
||||
"latents": latents,
|
||||
})
|
||||
|
||||
_maybe_checkpoint(_checkpoints, "post_loop", {
|
||||
"latents": latents,
|
||||
})
|
||||
|
||||
# --- decode ---
|
||||
video = self.vae.decode(latents)
|
||||
return video
|
||||
```
|
||||
|
||||
Each `_maybe_checkpoint` call does three things based on the Checkpoint's flags: `save` captures the local variables into `ckpt.data`, `load` injects pre-populated `ckpt.data` back into local variables, `stop` halts execution (raises an exception caught at the top level). The helper:
|
||||
|
||||
```python
|
||||
def _maybe_checkpoint(checkpoints, name, data):
|
||||
if not checkpoints:
|
||||
return
|
||||
ckpt = checkpoints.get(name)
|
||||
if ckpt is None:
|
||||
return
|
||||
if ckpt.save:
|
||||
ckpt.data.update(data)
|
||||
if ckpt.stop:
|
||||
raise PipelineStop # caught at __call__ level, returns None
|
||||
```
|
||||
|
||||
**Write stage tests using checkpoints.**
|
||||
|
||||
Three stages, tested in this order -- **encode, decode, then denoise**:
|
||||
|
||||
- **`encode`** (test first): Stop both pipelines at `"preloop"`. Compare **every single variable** that will be consumed by the denoising loop -- not just latents and sigmas, but also prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning inputs. If you only compare a subset, you'll miss divergent inputs and waste time debugging the loop for a bug that's actually upstream. List every argument the transformer's forward() takes and make sure each one is compared.
|
||||
- **`decode`** (test second, before denoise): Run the reference pipeline fully -- checkpoint the post-loop latents AND let it finish to get the decoded output. Then feed those same post-loop latents through the diffusers pipeline's decode path (unpacking, denormalization, VAE decode, etc). Compare the two decoded outputs. **Always test decode before spending time on denoise.** Decoder bugs (e.g. wrong config values, incorrect operation ordering) can cause severe visual artifacts (pixelation, color shifts) that look like denoising bugs but are much simpler to fix. Always visually inspect decoded output -- numerical metrics like PSNR can be misleadingly "close" (e.g. 28 dB) while hiding obvious visual defects.
|
||||
- **`denoise`** (test last): Run both pipelines with realistic `num_steps` (e.g. 30) so the scheduler computes correct sigmas/timesteps, but stop after 2 loop iterations using `after_step_1`. Don't set `num_steps=2` -- that produces unrealistic sigma schedules. Compare the latents after those 2 steps.
|
||||
|
||||
```python
|
||||
# Encode stage -- stop before the loop, compare ALL inputs:
|
||||
ref_ckpts = {"preloop": Checkpoint(save=True, stop=True)}
|
||||
run_reference_pipeline(ref_ckpts)
|
||||
ref_data = ref_ckpts["preloop"].data
|
||||
|
||||
diff_ckpts = {"preloop": Checkpoint(save=True, stop=True)}
|
||||
run_diffusers_pipeline(diff_ckpts)
|
||||
diff_data = diff_ckpts["preloop"].data
|
||||
|
||||
# Compare EVERY variable consumed by the denoise loop:
|
||||
compare_tensors("latents", ref_data["latents"], diff_data["latents"])
|
||||
compare_tensors("sigmas", ref_data["sigmas"], diff_data["sigmas"])
|
||||
compare_tensors("prompt_embeds", ref_data["prompt_embeds"], diff_data["prompt_embeds"])
|
||||
compare_tensors("prompt_attention_mask", ref_data["prompt_attention_mask"], diff_data["prompt_attention_mask"])
|
||||
compare_tensors("video_coords", ref_data["video_coords"], diff_data["video_coords"])
|
||||
# ... every single tensor the transformer forward() will receive
|
||||
|
||||
# Decode stage -- same latents through both decoders:
|
||||
ref_ckpts = {"post_loop": Checkpoint(save=True)}
|
||||
run_reference_pipeline(ref_ckpts)
|
||||
ref_latents = ref_ckpts["post_loop"].data["latents"]
|
||||
# Feed ref_latents through diffusers decode path, compare output visually AND numerically
|
||||
|
||||
# Denoise stage -- realistic steps, early stop after 2 iterations:
|
||||
ref_ckpts = {"after_step_1": Checkpoint(save=True, stop=True)}
|
||||
run_reference_pipeline(ref_ckpts) # uses default num_steps=30
|
||||
compare_tensors("latents", ref_ckpts[...].data["latents"], diff_ckpts[...].data["latents"])
|
||||
```
|
||||
|
||||
The key insight: the checkpoint dict is passed into the pipeline and mutated in-place. After the pipeline returns (or stops early), you read back `ckpt.data` to get the captured tensors. Both pipelines save under their own key names, so the test maps between them (e.g. reference `"video_state.latent"` -> diffusers `"latents"`).
|
||||
|
||||
**E2E-injected visual test**: Once you've identified a suspected root cause using stage tests, confirm it with an e2e-injected run -- inject the known-good tensor from reference and generate a full video. If the output looks identical to reference, you've confirmed the root cause. Fix it, then re-run the standard E2E test to verify.
|
||||
|
||||
## Debugging technique: Injection for root-cause isolation
|
||||
|
||||
When stage tests show divergence, you need to narrow down *which input* is causing it. The general technique: **inject a known-good tensor from one pipeline into the other** to test whether the remaining code is correct.
|
||||
|
||||
The principle is simple -- if you suspect input X is the root cause of divergence in stage S:
|
||||
1. Run the reference pipeline and capture X
|
||||
2. Run the diffusers pipeline but **replace** its X with the reference's X (via checkpoint load)
|
||||
3. Compare outputs of stage S
|
||||
|
||||
If outputs now match: X was the root cause. If they still diverge: the bug is in the stage logic itself, not in X.
|
||||
|
||||
This is the same pattern applied at different pipeline boundaries:
|
||||
|
||||
| What you're testing | What you inject | Where you inject |
|
||||
|---|---|---|
|
||||
| Is the decode stage correct? | Post-loop latents from reference | Before decode |
|
||||
| Is the denoise loop correct? | Pre-loop latents from reference | Before the loop |
|
||||
| Is step N correct? | Post-step-(N-1) latents from reference | Before step N |
|
||||
|
||||
Add `load` support at each checkpoint where you might want to inject:
|
||||
|
||||
```python
|
||||
_maybe_checkpoint(_checkpoints, "preloop", {"latents": latents, ...})
|
||||
|
||||
# Load support: replace local variables with injected data
|
||||
if _checkpoints:
|
||||
ckpt = _checkpoints.get("preloop")
|
||||
if ckpt is not None and ckpt.load:
|
||||
latents = ckpt.data["latents"].to(device=device, dtype=latents.dtype)
|
||||
```
|
||||
|
||||
For large models, free the source pipeline's GPU memory before loading the target pipeline. Clone injected tensors to CPU, delete everything else, then run the target with `enable_model_cpu_offload()`.
|
||||
|
||||
**Per-step accumulation tracing**: When injection confirms the loop is correct but you want to understand *how* a small initial difference compounds, capture `after_step_{i}` for every step and plot the max_diff curve. A healthy curve stays bounded; an exponential blowup in later steps points to an amplification mechanism (see Pitfall #13).
|
||||
|
||||
## Debugging technique: Visual comparison via frame extraction
|
||||
|
||||
For video pipelines, numerical metrics alone can be misleading (max_diff=0.25 might look identical, or max_diff=0.05 might be visibly wrong in specific regions). Extract and view individual frames programmatically:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def extract_frames(video_np, frame_indices):
|
||||
"""video_np: (frames, H, W, 3) float array in [0, 1]"""
|
||||
for idx in frame_indices:
|
||||
frame = (video_np[idx] * 255).clip(0, 255).astype(np.uint8)
|
||||
img = Image.fromarray(frame)
|
||||
img.save(f"frame_{idx}.png")
|
||||
|
||||
# Compare specific frames from both pipelines
|
||||
extract_frames(ref_video, [0, 60, 120])
|
||||
extract_frames(diff_video, [0, 60, 120])
|
||||
```
|
||||
|
||||
This is especially useful for: (a) confirming a fix works before running expensive full-pipeline tests, (b) diagnosing *what kind* of visual artifact a numerical divergence produces (washed out? color shift? spatial distortion?), (c) e2e-injected tests where you want visual proof that the loop is correct when given identical inputs.
|
||||
|
||||
## Testing rules
|
||||
|
||||
1. **Never use reference code in the diffusers test path.** Each side must use only its own code. Using reference helper functions inside the diffusers path defeats the purpose -- you're no longer testing the diffusers implementation.
|
||||
2. **Never monkey-patch model internals in tests.** Do not replace `model.forward` or patch internal methods. A passing test with a patched forward proves nothing about the actual model.
|
||||
3. **Debugging instrumentation must be non-destructive.** Checkpoint captures (e.g. a `_checkpoint` dict) for debugging are fine, but must not alter control flow or outputs.
|
||||
4. **Prefer CPU/float32 for numerical comparison when practical.** Float32 avoids bfloat16 precision noise that obscures real bugs. But for large models (22B+), GPU/bfloat16 with `enable_model_cpu_offload()` is necessary -- use relaxed tolerances and cosine similarity as a secondary metric.
|
||||
5. **Test both fresh conversion AND saved model.** Fresh catches conversion logic bugs; saved catches stale/corrupted weights from previous runs.
|
||||
6. **Diff configs before debugging.** Before investigating any divergence, dump and compare all config values from both the reference checkpoint and the diffusers model. Reference configs can often be read from checkpoint metadata without loading the model. Don't trust code defaults -- the checkpoint may override them. A 30-second config diff prevents hours of debugging based on wrong assumptions.
|
||||
7. **Never modify cached/downloaded model configs directly.** If you need to test with a different config value (e.g. fixing `upsample_residual` from `true` to `false`), do NOT edit the file in `~/.cache/huggingface/`. That change is invisible -- no git tracking, no diff, easy to forget. Instead, either (a) save the model to a local repo directory and edit the config there, or (b) open a PR on the upstream HF repo and load with `revision="refs/pr/N"`. Both approaches make the change visible and trackable.
|
||||
8. **Test decode before denoise.** Always verify the decoder works correctly before spending time on the denoising loop. Feed identical post-loop latents from the reference through both decoders and compare outputs -- both numerically AND visually. Decoder config bugs (e.g. wrong `upsample_residual`) cause severe pixelation or artifacts that are trivial to fix once found, but look like denoising bugs from the E2E output. A decoder bug found after days of denoise debugging is wasted time.
|
||||
9. **Compare ALL loop inputs in the encode test.** The preloop checkpoint must capture every single tensor the transformer forward() will receive: latents, sigmas/timesteps, prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning tensors. If you only compare latents and sigmas, you'll miss divergent conditioning inputs and waste time debugging the loop for a bug that's actually upstream.
|
||||
|
||||
## Comparison utilities
|
||||
|
||||
```python
|
||||
def compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e-3) -> bool:
|
||||
if a.shape != b.shape:
|
||||
print(f" FAIL {name}: shape mismatch {a.shape} vs {b.shape}")
|
||||
return False
|
||||
diff = (a.float() - b.float()).abs()
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)
|
||||
).item()
|
||||
passed = max_diff < tol
|
||||
print(f" {'PASS' if passed else 'FAIL'} {name}: max={max_diff:.2e}, mean={mean_diff:.2e}, cos={cos:.5f}")
|
||||
return passed
|
||||
```
|
||||
Cosine similarity is especially useful for GPU/bfloat16 tests where max_diff can be noisy -- `cos > 0.9999` is a strong signal even when max_diff exceeds tolerance.
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
### 1. Global CPU RNG
|
||||
`MultivariateNormal.sample()` uses the global CPU RNG, not `torch.Generator`. Must call `torch.manual_seed(seed)` before each pipeline run. A `generator=` kwarg won't help.
|
||||
|
||||
### 2. Timestep dtype
|
||||
Many transformers expect `int64` timesteps. `get_timestep_embedding` casts to float, so `745.3` and `745` produce different embeddings. Match the reference's casting.
|
||||
|
||||
### 3. Guidance parameter mapping
|
||||
Parameter names may differ: reference `zero_steps=1` (meaning `i <= 1`, 2 steps) vs target `zero_init_steps=2` (meaning `step < 2`, same thing). Check exact semantics.
|
||||
|
||||
### 4. `patch_size` in noise generation
|
||||
If noise generation depends on `patch_size` (e.g. `sample_block_noise`), it must be passed through. Missing it changes noise spatial structure.
|
||||
|
||||
### 5. Variable shadowing in nested loops
|
||||
Nested loops (stages -> chunks -> timesteps) can shadow variable names. If outer loop uses `latents` and inner loop also assigns to `latents`, scoping must match the reference.
|
||||
|
||||
### 6. Float precision differences -- don't dismiss them
|
||||
Target may compute in float32 where reference used bfloat16. Small per-element diffs (1e-3 to 1e-2) *look* harmless but can compound catastrophically over iterative processes like denoising loops (see Pitfalls #11 and #13). Before dismissing a precision difference: (a) check whether it feeds into an iterative process, (b) if so, trace the accumulation curve over all iterations to see if it stays bounded or grows exponentially. Only truly non-iterative precision diffs (e.g. in a single-pass encoder) are safe to accept.
|
||||
|
||||
### 7. Scheduler state reset between stages
|
||||
Some schedulers accumulate state (e.g. `model_outputs` in UniPC) that must be cleared between stages.
|
||||
|
||||
### 8. Component access
|
||||
Standard: `self.transformer`. Modular: `components.transformer`. Missing this causes AttributeError.
|
||||
|
||||
### 9. Guider state across stages
|
||||
In multi-stage denoising, the guider's internal state (e.g. `zero_init_steps`) may need save/restore between stages.
|
||||
|
||||
### 10. Model storage location
|
||||
NEVER store converted models in `/tmp/` -- temporary directories get wiped on restart. Always save converted checkpoints under a persistent path in the project repo (e.g. `models/ltx23-diffusers/`).
|
||||
|
||||
### 11. Noise dtype mismatch (causes washed-out output)
|
||||
|
||||
Reference code often generates noise in float32 then casts to model dtype (bfloat16) before storing:
|
||||
|
||||
```python
|
||||
noise = torch.randn(..., dtype=torch.float32, generator=gen)
|
||||
noise = noise.to(dtype=model_dtype) # bfloat16 -- values get quantized
|
||||
```
|
||||
|
||||
Diffusers pipelines may keep latents in float32 throughout the loop. The per-element difference is only ~1.5e-02, but this compounds over 30 denoising steps via 1/sigma amplification (Pitfall #13) and produces completely washed-out output.
|
||||
|
||||
**Fix**: Match the reference -- generate noise in the model's working dtype:
|
||||
```python
|
||||
latent_dtype = self.transformer.dtype # e.g. bfloat16
|
||||
latents = self.prepare_latents(..., dtype=latent_dtype, ...)
|
||||
```
|
||||
|
||||
**Detection**: Encode stage test shows initial latent max_diff of exactly ~1.5e-02. This specific magnitude is the signature of float32->bfloat16 quantization error.
|
||||
|
||||
### 12. RoPE position dtype
|
||||
|
||||
RoPE cosine/sine values are sensitive to position coordinate dtype. If reference uses bfloat16 positions but diffusers uses float32, the RoPE output diverges significantly (max_diff up to 2.0). Different modalities may use different position dtypes (e.g. video bfloat16, audio float32) -- check the reference carefully.
|
||||
|
||||
### 13. 1/sigma error amplification in Euler denoising
|
||||
|
||||
In Euler/flow-matching, the velocity formula divides by sigma: `v = (latents - pred_x0) / sigma`. As sigma shrinks from ~1.0 (step 0) to ~0.001 (step 29), errors are amplified up to 1000x. A 1.5e-02 init difference grows linearly through mid-steps, then exponentially in final steps, reaching max_diff ~6.0. This is why dtype mismatches (Pitfalls #11, #12) that seem tiny at init produce visually broken output. Use per-step accumulation tracing to diagnose.
|
||||
|
||||
### 14. Config value assumptions -- always diff, never assume
|
||||
|
||||
When debugging parity, don't assume config values match code defaults. The published model checkpoint may override defaults with different values. A wrong assumption about a single config field can send you down hours of debugging in the wrong direction.
|
||||
|
||||
**The pattern that goes wrong:**
|
||||
1. You see `param_x` has default `1` in the code
|
||||
2. The reference code also uses `param_x` with a default of `1`
|
||||
3. You assume both sides use `1` and apply a "fix" based on that
|
||||
4. But the actual checkpoint config has `param_x: 1000`, and so does the published diffusers config
|
||||
5. Your "fix" now *creates* divergence instead of fixing it
|
||||
|
||||
**Prevention -- config diff first:**
|
||||
```python
|
||||
# Reference: read from checkpoint metadata (no model loading needed)
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
ref_config = json.loads(safe_open(checkpoint_path, framework="pt").metadata()["config"])
|
||||
|
||||
# Diffusers: read from model config
|
||||
from diffusers import MyModel
|
||||
diff_model = MyModel.from_pretrained(model_path, subfolder="transformer")
|
||||
diff_config = dict(diff_model.config)
|
||||
|
||||
# Compare all values
|
||||
for key in sorted(set(list(ref_config.get("transformer", {}).keys()) + list(diff_config.keys()))):
|
||||
ref_val = ref_config.get("transformer", {}).get(key, "MISSING")
|
||||
diff_val = diff_config.get(key, "MISSING")
|
||||
if ref_val != diff_val:
|
||||
print(f" DIFF {key}: ref={ref_val}, diff={diff_val}")
|
||||
```
|
||||
|
||||
Run this **before** writing any hooks, analysis code, or fixes. It takes 30 seconds and catches wrong assumptions immediately.
|
||||
|
||||
**When debugging divergence -- trace values, don't reason about them:**
|
||||
If two implementations diverge, hook the actual intermediate values at the point of divergence rather than reading code to figure out what the values "should" be. Code analysis builds on assumptions; value tracing reveals facts.
|
||||
|
||||
### 15. Decoder config mismatch (causes pixelated artifacts)
|
||||
|
||||
The upstream model config may have wrong values for decoder-specific parameters (e.g. `upsample_residual`, `upsample_type`). These control whether the decoder uses skip connections in upsampling -- getting them wrong produces severe pixelation or blocky artifacts.
|
||||
|
||||
**Detection**: Feed identical post-loop latents through both decoders. If max pixel diff is large (PSNR < 40 dB) on CPU/float32, it's a real bug, not precision noise. Trace through decoder blocks (conv_in -> mid_block -> up_blocks) to find where divergence starts.
|
||||
|
||||
**Fix**: Correct the config value. Don't edit cached files in `~/.cache/huggingface/` -- either save to a local model directory or open a PR on the upstream repo (see Testing Rule #7).
|
||||
|
||||
### 16. Incomplete injection tests -- inject ALL variables or the test is invalid
|
||||
|
||||
When doing injection tests (feeding reference tensors into the diffusers pipeline), you must inject **every** divergent input, including sigmas/timesteps. A common mistake: the preloop checkpoint saves sigmas but the injection code only loads latents and embeddings. The test then runs with different sigma schedules, making it impossible to isolate the real cause.
|
||||
|
||||
**Prevention**: After writing injection code, verify by listing every variable the injected stage consumes and checking each one is either (a) injected from reference, or (b) confirmed identical between pipelines.
|
||||
|
||||
### 17. bf16 connector/encoder divergence -- don't chase it
|
||||
|
||||
When running on GPU/bfloat16, multi-layer encoders (e.g. 8-layer connector transformers) accumulate bf16 rounding noise that looks alarming (max_diff 0.3-2.7). Before investigating, re-run the component test on CPU/float32. If it passes (max_diff < 1e-4), the divergence is pure precision noise, not a code bug. Don't spend hours tracing through layers -- confirm on CPU/f32 and move on.
|
||||
|
||||
### 18. Stale test fixtures
|
||||
|
||||
When using saved tensors for cross-pipeline comparison, always ensure both sets of tensors were captured from the same run configuration (same seed, same config, same code version). Mixing fixtures from different runs (e.g. reference tensors from yesterday, diffusers tensors from today after a code change) creates phantom divergence that wastes debugging time. Regenerate both sides in a single test script execution.
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -178,4 +178,10 @@ tags
|
||||
.ruff_cache
|
||||
|
||||
# wandb
|
||||
wandb
|
||||
wandb
|
||||
|
||||
# AI agent generated symlinks
|
||||
/AGENTS.md
|
||||
/CLAUDE.md
|
||||
/.agents/skills
|
||||
/.claude/skills
|
||||
20
Makefile
20
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
|
||||
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples codex claude clean-ai
|
||||
|
||||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
@@ -98,3 +98,21 @@ post-release:
|
||||
|
||||
post-patch:
|
||||
python utils/release.py --post_release --patch
|
||||
|
||||
# AI agent symlinks
|
||||
|
||||
codex:
|
||||
ln -snf .ai/AGENTS.md AGENTS.md
|
||||
mkdir -p .agents
|
||||
rm -rf .agents/skills
|
||||
ln -snf ../.ai/skills .agents/skills
|
||||
|
||||
claude:
|
||||
ln -snf .ai/AGENTS.md CLAUDE.md
|
||||
mkdir -p .claude
|
||||
rm -rf .claude/skills
|
||||
ln -snf ../.ai/skills .claude/skills
|
||||
|
||||
clean-ai:
|
||||
rm -f AGENTS.md CLAUDE.md
|
||||
rm -rf .agents/skills .claude/skills
|
||||
|
||||
@@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
|
||||
## Flux2Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput
|
||||
|
||||
@@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
|
||||
## Flux2KleinPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Flux2KleinKVPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinKVPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -565,4 +565,16 @@ $ git push --set-upstream origin your-branch-for-syncing
|
||||
|
||||
### Style guide
|
||||
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
|
||||
|
||||
## Coding with AI agents
|
||||
|
||||
The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks.
|
||||
|
||||
- **Source of truth** — edit files under `.ai/` (`AGENTS.md` for coding guidelines, `skills/` for on-demand task knowledge)
|
||||
- **Don't edit** generated root-level `AGENTS.md`, `CLAUDE.md`, or `.agents/skills`/`.claude/skills` — they are symlinks
|
||||
- Setup commands:
|
||||
- `make codex` — symlink guidelines + skills for OpenAI Codex
|
||||
- `make claude` — symlink guidelines + skills for Claude Code
|
||||
- `make clean-ai` — remove all generated symlinks
|
||||
@@ -510,6 +510,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2KleinKVPipeline",
|
||||
"Flux2KleinPipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
@@ -1266,6 +1267,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2KleinPipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
|
||||
@@ -2538,8 +2538,12 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
alpha_tensor = state_dict.pop(alpha_key, None)
|
||||
if alpha_tensor is None:
|
||||
return 1.0, 1.0
|
||||
scale = (
|
||||
alpha_tensor.item() / rank
|
||||
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -21,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -32,7 +33,6 @@ from ..embeddings import (
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
@@ -40,6 +40,216 @@ from ..normalization import AdaLayerNormContinuous
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Flux2Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Flux2Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
kv_cache: "Flux2KVCache | None" = None
|
||||
|
||||
|
||||
class Flux2KVLayerCache:
|
||||
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
|
||||
|
||||
Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step. Tensor
|
||||
format: (batch_size, num_ref_tokens, num_heads, head_dim).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.k_ref: torch.Tensor | None = None
|
||||
self.v_ref: torch.Tensor | None = None
|
||||
|
||||
def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
|
||||
"""Store reference token K/V."""
|
||||
self.k_ref = k_ref
|
||||
self.v_ref = v_ref
|
||||
|
||||
def get(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Retrieve cached reference token K/V."""
|
||||
if self.k_ref is None:
|
||||
raise RuntimeError("KV cache has not been populated yet.")
|
||||
return self.k_ref, self.v_ref
|
||||
|
||||
def clear(self):
|
||||
self.k_ref = None
|
||||
self.v_ref = None
|
||||
|
||||
|
||||
class Flux2KVCache:
|
||||
"""Container for all layers' reference-token KV caches.
|
||||
|
||||
Holds separate cache lists for double-stream and single-stream transformer blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, num_double_layers: int, num_single_layers: int):
|
||||
self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
|
||||
self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
|
||||
self.num_ref_tokens: int = 0
|
||||
|
||||
def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
|
||||
return self.double_block_caches[layer_idx]
|
||||
|
||||
def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
|
||||
return self.single_block_caches[layer_idx]
|
||||
|
||||
def clear(self):
|
||||
for cache in self.double_block_caches:
|
||||
cache.clear()
|
||||
for cache in self.single_block_caches:
|
||||
cache.clear()
|
||||
self.num_ref_tokens = 0
|
||||
|
||||
|
||||
def _flux2_kv_causal_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_txt_tokens: int,
|
||||
num_ref_tokens: int,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
backend=None,
|
||||
) -> torch.Tensor:
|
||||
"""Causal attention for KV caching where reference tokens only self-attend.
|
||||
|
||||
All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
|
||||
|
||||
Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens, ref tokens
|
||||
only attend to themselves. With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected
|
||||
between txt and img.
|
||||
"""
|
||||
# No ref tokens and no cache — standard full attention
|
||||
if num_ref_tokens == 0 and kv_cache is None:
|
||||
return dispatch_attention_fn(query, key, value, backend=backend)
|
||||
|
||||
if kv_cache is not None:
|
||||
# Cached mode: inject ref K/V between txt and img
|
||||
k_ref, v_ref = kv_cache.get()
|
||||
|
||||
k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
|
||||
v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
|
||||
|
||||
return dispatch_attention_fn(query, k_all, v_all, backend=backend)
|
||||
|
||||
# Extract mode: ref tokens self-attend, txt+img attend to all
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
|
||||
q_txt = query[:, :ref_start]
|
||||
q_ref = query[:, ref_start:ref_end]
|
||||
q_img = query[:, ref_end:]
|
||||
|
||||
k_txt = key[:, :ref_start]
|
||||
k_ref = key[:, ref_start:ref_end]
|
||||
k_img = key[:, ref_end:]
|
||||
|
||||
v_txt = value[:, :ref_start]
|
||||
v_ref = value[:, ref_start:ref_end]
|
||||
v_img = value[:, ref_end:]
|
||||
|
||||
# txt+img attend to all tokens
|
||||
q_txt_img = torch.cat([q_txt, q_img], dim=1)
|
||||
k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
|
||||
v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
|
||||
attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
|
||||
attn_txt = attn_txt_img[:, :ref_start]
|
||||
attn_img = attn_txt_img[:, ref_start:]
|
||||
|
||||
# ref tokens self-attend only
|
||||
attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
|
||||
|
||||
return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
|
||||
|
||||
|
||||
def _blend_mod_params(
|
||||
img_params: tuple[torch.Tensor, ...],
|
||||
ref_params: tuple[torch.Tensor, ...],
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return tuple(blended)
|
||||
|
||||
|
||||
def _blend_double_block_mods(
|
||||
img_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend double-block image-stream modulations for a [ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is compatible
|
||||
with `Flux2Modulation.split(mod, 2)`.
|
||||
"""
|
||||
if img_mod.ndim == 2:
|
||||
img_mod = img_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_chunks = torch.chunk(img_mod, 6, dim=-1)
|
||||
ref_chunks = torch.chunk(ref_mod, 6, dim=-1)
|
||||
img_mods = (img_chunks[0:3], img_chunks[3:6])
|
||||
ref_mods = (ref_chunks[0:3], ref_chunks[3:6])
|
||||
|
||||
all_params = []
|
||||
for img_set, ref_set in zip(img_mods, ref_mods):
|
||||
blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
|
||||
all_params.extend(blended)
|
||||
return torch.cat(all_params, dim=-1)
|
||||
|
||||
|
||||
def _blend_single_block_mods(
|
||||
single_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_txt: int,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend single-block modulations for a [txt, ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`.
|
||||
"""
|
||||
if single_mod.ndim == 2:
|
||||
single_mod = single_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_params = torch.chunk(single_mod, 3, dim=-1)
|
||||
ref_params = torch.chunk(ref_mod, 3, dim=-1)
|
||||
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
im_expanded = im.expand(B, seq_len, -1)
|
||||
rm_expanded = rm.expand(B, num_ref, -1)
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return torch.cat(blended, dim=-1)
|
||||
|
||||
|
||||
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@@ -181,9 +391,108 @@ class Flux2AttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVAttnProcessor:
|
||||
"""
|
||||
Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.
|
||||
|
||||
When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal attention is
|
||||
used (ref tokens self-attend only, txt+img attend to all). When `kv_cache_mode` is "cached", cached ref K/V are
|
||||
injected during attention. When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
|
||||
|
||||
# Extract ref K/V from the combined sequence
|
||||
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
||||
|
||||
# Dispatch attention
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
||||
)
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
hidden_states = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
||||
)
|
||||
else:
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = Flux2AttnProcessor
|
||||
_available_processors = [Flux2AttnProcessor]
|
||||
_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -312,6 +621,90 @@ class Flux2ParallelSelfAttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVParallelSelfAttnProcessor:
|
||||
"""
|
||||
Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
|
||||
|
||||
When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used. When
|
||||
`kv_cache_mode` is "cached", cached ref K/V are injected during attention. When no KV args are provided, behaves
|
||||
identically to `Flux2ParallelSelfAttnProcessor`.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2ParallelSelfAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_txt_tokens: int = 0,
|
||||
num_ref_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
# Parallel in (QKV + MLP in) projection
|
||||
hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
|
||||
qkv, mlp_hidden_states = torch.split(
|
||||
hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
||||
)
|
||||
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# Extract ref K/V from the combined sequence
|
||||
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
||||
|
||||
# Dispatch attention
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
attn_output = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
||||
)
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
attn_output = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
||||
)
|
||||
else:
|
||||
attn_output = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
attn_output = attn_output.flatten(2, 3)
|
||||
attn_output = attn_output.to(query.dtype)
|
||||
|
||||
# Handle the feedforward (FF) logic
|
||||
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
||||
|
||||
# Concatenate and parallel output projection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
||||
@@ -322,7 +715,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
|
||||
# Does not support QKV fusion as the QKV projections are always fused
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
@@ -780,6 +1173,8 @@ class Flux2Transformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
_skip_keys = ["kv_cache"]
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
@@ -791,19 +1186,21 @@ class Flux2Transformer2DModel(
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
kv_cache: "Flux2KVCache | None" = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
ref_fixed_timestep: float = 0.0,
|
||||
) -> torch.Tensor | Flux2Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
The [`Flux2Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -811,13 +1208,23 @@ class Flux2Transformer2DModel(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and
|
||||
returned. When "cached", the provided cache is used to inject ref K/V during attention.
|
||||
kv_cache_mode (`str`, *optional*):
|
||||
One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When
|
||||
`None`, standard forward pass without KV caching.
|
||||
num_ref_tokens (`int`, defaults to `0`):
|
||||
Number of reference image tokens prepended to `hidden_states` (only used when
|
||||
`kv_cache_mode="extract"`).
|
||||
ref_fixed_timestep (`float`, defaults to `0.0`):
|
||||
Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
`tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
|
||||
populated `Flux2KVCache`.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
@@ -832,13 +1239,33 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)
|
||||
|
||||
# KV extract mode: create cache and blend modulations for ref tokens
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
num_img_tokens = hidden_states.shape[1] # includes ref tokens
|
||||
|
||||
kv_cache = Flux2KVCache(
|
||||
num_double_layers=len(self.transformer_blocks),
|
||||
num_single_layers=len(self.single_transformer_blocks),
|
||||
)
|
||||
kv_cache.num_ref_tokens = num_ref_tokens
|
||||
|
||||
# Ref tokens use a fixed timestep for modulation
|
||||
ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
|
||||
ref_temb = self.time_guidance_embed(ref_timestep, guidance)
|
||||
|
||||
ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
|
||||
ref_single_mod = self.single_stream_modulation(ref_temb)
|
||||
|
||||
# Blend double block img modulation: [ref_mod, img_mod]
|
||||
double_stream_mod_img = _blend_double_block_mods(
|
||||
double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
|
||||
)
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# 3. Calculate RoPE embeddings from image and text tokens
|
||||
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
||||
# text prompts of differents lengths. Is this a use case we want to support?
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 3:
|
||||
@@ -851,8 +1278,29 @@ class Flux2Transformer2DModel(
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
)
|
||||
|
||||
# 4. Double Stream Transformer Blocks
|
||||
# 4. Build joint_attention_kwargs with KV cache info
|
||||
if kv_cache_mode == "extract":
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "extract",
|
||||
"num_ref_tokens": num_ref_tokens,
|
||||
}
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "cached",
|
||||
"num_ref_tokens": kv_cache.num_ref_tokens,
|
||||
}
|
||||
else:
|
||||
kv_attn_kwargs = joint_attention_kwargs
|
||||
|
||||
# 5. Double Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -861,7 +1309,7 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_img,
|
||||
double_stream_mod_txt,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
kv_attn_kwargs,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
@@ -870,13 +1318,30 @@ class Flux2Transformer2DModel(
|
||||
temb_mod_img=double_stream_mod_img,
|
||||
temb_mod_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=kv_attn_kwargs,
|
||||
)
|
||||
|
||||
# Concatenate text and image streams for single-block inference
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 5. Single Stream Transformer Blocks
|
||||
# Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
total_single_len = hidden_states.shape[1]
|
||||
single_stream_mod = _blend_single_block_mods(
|
||||
single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
|
||||
)
|
||||
|
||||
# Build single-block KV kwargs (single blocks need num_txt_tokens)
|
||||
if kv_cache_mode is not None:
|
||||
kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
|
||||
else:
|
||||
kv_attn_kwargs_single = kv_attn_kwargs
|
||||
|
||||
# 6. Single Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -884,7 +1349,7 @@ class Flux2Transformer2DModel(
|
||||
None,
|
||||
single_stream_mod,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
kv_attn_kwargs_single,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
@@ -892,16 +1357,25 @@ class Flux2Transformer2DModel(
|
||||
encoder_hidden_states=None,
|
||||
temb_mod=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=kv_attn_kwargs_single,
|
||||
)
|
||||
# Remove text tokens from concatenated stream
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 6. Output layers
|
||||
# Remove text tokens (and ref tokens in extract mode) from concatenated stream
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
|
||||
else:
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 7. Output layers
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if kv_cache_mode == "extract":
|
||||
if not return_dict:
|
||||
return (output, kv_cache)
|
||||
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
return Flux2Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -129,7 +129,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
|
||||
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_flux2 import Flux2Pipeline
|
||||
from .pipeline_flux2_klein import Flux2KleinPipeline
|
||||
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -744,7 +744,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: list[PIL.Image.Image, PIL.Image.Image] | None = None,
|
||||
image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
|
||||
prompt: str | list[str] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
|
||||
886
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Normal file
886
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Normal file
@@ -0,0 +1,886 @@
|
||||
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
||||
from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import Flux2KleinKVPipeline
|
||||
|
||||
>>> pipe = Flux2KleinKVPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> ref_image = Image.open("reference.png")
|
||||
>>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
|
||||
>>> image.save("flux2_kv_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`list[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
r"""
|
||||
The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
|
||||
|
||||
On the first denoising step, reference image tokens are included in the forward pass and their attention K/V
|
||||
projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster
|
||||
inference when using reference images.
|
||||
|
||||
Reference:
|
||||
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
|
||||
|
||||
Args:
|
||||
transformer ([`Flux2Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLFlux2`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3ForCausalLM`]):
|
||||
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
|
||||
tokenizer (`Qwen2TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLFlux2,
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
transformer: Flux2Transformer2DModel,
|
||||
is_distilled: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 512
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Set KV-cache-aware attention processors
|
||||
self._set_kv_attn_processors()
|
||||
|
||||
@staticmethod
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: str | list[str],
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: list[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
|
||||
def _prepare_text_ids(
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: torch.Tensor | None = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
|
||||
def _prepare_latent_ids(
|
||||
latents: torch.Tensor, # (B, C, H, W)
|
||||
):
|
||||
r"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
||||
H=[0..H-1], W=[0..W-1], L=0
|
||||
"""
|
||||
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1) # [0] - time dimension
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1) # [0] - layer dimension
|
||||
|
||||
# Create position IDs: (H*W, 4)
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
|
||||
# Expand to batch: (B, H*W, 4)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
|
||||
def _prepare_image_ids(
|
||||
image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
||||
scale: int = 10,
|
||||
):
|
||||
r"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
This function creates a unique coordinate for every pixel/patch across all input latent with different
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
image_latents (list[torch.Tensor]):
|
||||
A list of image latent feature tensors, typically of shape (C, H, W).
|
||||
scale (int, optional):
|
||||
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
||||
latent is: 'scale + scale * i'. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
||||
input latents.
|
||||
|
||||
Coordinate Components (Dimension 4):
|
||||
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
||||
- H (Height): The row index within that latent image.
|
||||
- W (Width): The column index within that latent image.
|
||||
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
||||
"""
|
||||
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
# create time offset for each reference image
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
|
||||
def _patchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
|
||||
def _unpatchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
|
||||
def _pack_latents(latents):
|
||||
"""
|
||||
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
def _set_kv_attn_processors(self):
|
||||
"""Replace default attention processors with KV-cache-aware variants."""
|
||||
for block in self.transformer.transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVAttnProcessor())
|
||||
for block in self.transformer.single_transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
hidden_states_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self._prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
||||
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
image_latents = self._patchify_latents(image_latents)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
||||
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
||||
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_latents_channels,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator: torch.Generator,
|
||||
latents: torch.Tensor | None = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(device)
|
||||
|
||||
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
||||
return latents, latent_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
batch_size,
|
||||
generator: torch.Generator,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
imagge_latent = self._encode_vae_image(image=image, generator=generator)
|
||||
image_latents.append(imagge_latent) # (1, 128, 32, 32)
|
||||
|
||||
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||
|
||||
# Pack each latent and concatenate
|
||||
packed_latents = []
|
||||
for latent in image_latents:
|
||||
# latent: (1, 128, 32, 32)
|
||||
packed = self._pack_latents(latent) # (1, 1024, 128)
|
||||
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
||||
packed_latents.append(packed)
|
||||
|
||||
# Concatenate all reference tokens along sequence dimension
|
||||
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
||||
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
||||
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.to(device)
|
||||
|
||||
return image_latents, image_latent_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
|
||||
prompt: str | list[str] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 4,
|
||||
sigmas: list[float] | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
|
||||
Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the
|
||||
forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without
|
||||
recomputing.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 4):
|
||||
The number of denoising steps.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas for the denoising schedule.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
Generator(s) for deterministic generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
Output format: `"pil"` or `"np"`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a `Flux2PipelineOutput` or a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Extra kwargs passed to attention processors.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Callback function called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
Tensor inputs for the callback function.
|
||||
max_sequence_length (`int`, defaults to 512):
|
||||
Maximum sequence length for the prompt.
|
||||
text_encoder_out_layers (`tuple[int]`):
|
||||
Layer indices for text encoder hidden state extraction.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
|
||||
"""
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare text embeddings
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
text_encoder_out_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
# 4. process images
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
condition_images = None
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
|
||||
condition_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 5. prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_ids = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_latents_channels=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
image_latents = None
|
||||
image_latent_ids = None
|
||||
if condition_images is not None:
|
||||
image_latents, image_latent_ids = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop with KV caching
|
||||
# Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
|
||||
# Steps 1+: forward_kv_cached (reuse cached ref K/V)
|
||||
# No ref images: standard forward
|
||||
self.scheduler.set_begin_index(0)
|
||||
kv_cache = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if i == 0 and image_latents is not None:
|
||||
# Step 0: include ref tokens, extract KV cache
|
||||
latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
|
||||
latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
|
||||
|
||||
noise_pred, kv_cache = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=image_latents.shape[1],
|
||||
)
|
||||
|
||||
elif kv_cache is not None:
|
||||
# Steps 1+: use cached ref KV, no ref tokens in input
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents.to(self.transformer.dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache_mode="cached",
|
||||
)[0]
|
||||
|
||||
else:
|
||||
# No reference images: standard forward
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents.to(self.transformer.dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# Clean up KV cache
|
||||
if kv_cache is not None:
|
||||
kv_cache.clear()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpatchify_latents(latents)
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return Flux2PipelineOutput(images=image)
|
||||
@@ -1202,6 +1202,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
AutoPipelineBlocks,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
ModularPipelineBlocks,
|
||||
)
|
||||
|
||||
|
||||
class TextToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "text2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "text-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "text2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ImageToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "img2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "image-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "img2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class InpaintBlock(ModularPipelineBlocks):
|
||||
model_name = "inpaint"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "inpaint workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "inpaint"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ConditionalImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Conditional image blocks for testing"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name
|
||||
|
||||
|
||||
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = None # no default; block can be skipped
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Optional conditional blocks (skippable)"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None
|
||||
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto image blocks for testing"
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksSelectBlock:
|
||||
def test_select_block_with_mask(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="something") == "inpaint"
|
||||
|
||||
def test_select_block_with_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(image="something") == "img2img"
|
||||
|
||||
def test_select_block_with_mask_and_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
def test_select_block_no_triggers_returns_none(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_select_block_explicit_none_values(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask=None, image=None) is None
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksWorkflowSelection:
|
||||
def test_default_workflow_when_no_triggers(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is not None
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_mask_trigger_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_image_trigger_selects_img2img(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
def test_mask_and_image_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True, image=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_skippable_block_returns_none(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is None
|
||||
|
||||
def test_skippable_block_still_selects_when_triggered(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksSelectBlock:
|
||||
def test_auto_select_mask(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m") == "inpaint"
|
||||
|
||||
def test_auto_select_image(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(image="i") == "img2img"
|
||||
|
||||
def test_auto_select_default(self):
|
||||
blocks = AutoImageBlocks()
|
||||
# No trigger -> returns None -> falls back to default (text2img)
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_auto_select_priority_order(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksWorkflowSelection:
|
||||
def test_auto_default_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_auto_mask_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_auto_image_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksStructure:
|
||||
def test_block_names_accessible(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
|
||||
|
||||
def test_sub_block_types(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert isinstance(sub["inpaint"], InpaintBlock)
|
||||
assert isinstance(sub["img2img"], ImageToImageBlock)
|
||||
assert isinstance(sub["text2img"], TextToImageBlock)
|
||||
|
||||
def test_description(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert "Conditional" in blocks.description
|
||||
@@ -9,6 +9,11 @@ import torch
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import (
|
||||
ConditionalPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -19,6 +24,7 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
@@ -431,6 +437,117 @@ class ModularGuiderTesterMixin:
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
|
||||
@@ -24,18 +24,14 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
OutputParam,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import CaptureLogger, nightly, require_torch, slow
|
||||
from ..testing_utils import nightly, require_torch, slow
|
||||
|
||||
|
||||
class DummyCustomBlockSimple(ModularPipelineBlocks):
|
||||
@@ -358,117 +354,6 @@ class TestModularCustomBlocks:
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch
|
||||
|
||||
174
tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py
Normal file
174
tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLFlux2,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2Transformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Flux2KleinKVPipeline
|
||||
params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = Flux2Transformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=16,
|
||||
timestep_guidance_channels=256,
|
||||
axes_dims_rope=[4, 4, 4, 4],
|
||||
guidance_embeds=False,
|
||||
)
|
||||
|
||||
# Create minimal Qwen3 config
|
||||
config = Qwen3Config(
|
||||
intermediate_size=16,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = Qwen3ForCausalLM(config)
|
||||
|
||||
# Use a simple tokenizer for testing
|
||||
tokenizer = Qwen2TokenizerFast.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLFlux2(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "a dog is dancing",
|
||||
"image": Image.new("RGB", (64, 64)),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 64,
|
||||
"output_type": "np",
|
||||
"text_encoder_out_layers": (1,),
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
|
||||
("Fusion of QKV projections shouldn't affect the outputs."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
|
||||
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
|
||||
("Original outputs should match when fused QKV projections are disabled."),
|
||||
)
|
||||
|
||||
def test_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
self.assertEqual(
|
||||
(output_height, output_width),
|
||||
(expected_height, expected_width),
|
||||
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
|
||||
)
|
||||
|
||||
def test_without_image(self):
|
||||
device = "cpu"
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["image"]
|
||||
image = pipe(**inputs).images
|
||||
self.assertEqual(image.shape, (1, 8, 8, 3))
|
||||
|
||||
@unittest.skip("Needs to be revisited")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user