Compare commits

..

6 Commits

Author SHA1 Message Date
Sayak Paul
faca9b90a7 Merge branch 'main' into tests-load-components 2026-03-12 20:57:38 +05:30
sayakpaul
a1f63a398c up 2026-03-10 17:55:08 +05:30
sayakpaul
bf846f722c u[ 2026-03-10 17:49:51 +05:30
sayakpaul
78a86e85cf fix 2026-03-10 17:46:55 +05:30
sayakpaul
7673ab1757 fix 2026-03-10 16:50:27 +05:30
sayakpaul
b7648557d4 test load_components. 2026-03-10 16:09:02 +05:30
37 changed files with 447 additions and 2076 deletions

View File

@@ -1,77 +0,0 @@
# Diffusers — Agent Guide
## Coding style
Strive to write code as simple and explicit as possible.
- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.
- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
- Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.
---
### Dependencies
- No new mandatory dependency without discussion (e.g. `einops`)
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
## Code formatting
- `make style` and `make fix-copies` should be run as the final step before opening a PR
### Copied Code
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
- Remove the header to intentionally break the link
### Models
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Try to not introduce graph breaks as much as possible for better compatibility with `torch.compile`. For example, DO NOT arbitrarily insert operations from NumPy in the forward implementations.
- Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
```python
# transformer_mymodel.py
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
### Pipeline
- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references.
- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
### Tests
- Slow tests gated with `@slow` and `RUN_SLOW=1`
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.

6
.gitignore vendored
View File

@@ -178,8 +178,4 @@ tags
.ruff_cache .ruff_cache
# wandb # wandb
wandb wandb
# AI agent generated symlinks
/AGENTS.md
/CLAUDE.md

View File

@@ -1,4 +1,4 @@
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples codex claude clean-ai .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src export PYTHONPATH = src
@@ -98,14 +98,3 @@ post-release:
post-patch: post-patch:
python utils/release.py --post_release --patch python utils/release.py --post_release --patch
# AI agent symlinks
codex:
ln -snf .ai/AGENTS.md AGENTS.md
claude:
ln -snf .ai/AGENTS.md CLAUDE.md
clean-ai:
rm -f AGENTS.md CLAUDE.md

View File

@@ -22,8 +22,6 @@
title: Reproducibility title: Reproducibility
- local: using-diffusers/schedulers - local: using-diffusers/schedulers
title: Schedulers title: Schedulers
- local: using-diffusers/guiders
title: Guiders
- local: using-diffusers/automodel - local: using-diffusers/automodel
title: AutoModel title: AutoModel
- local: using-diffusers/other-formats - local: using-diffusers/other-formats
@@ -112,6 +110,8 @@
title: ModularPipeline title: ModularPipeline
- local: modular_diffusers/components_manager - local: modular_diffusers/components_manager
title: ComponentsManager title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- local: modular_diffusers/custom_blocks - local: modular_diffusers/custom_blocks
title: Building Custom Blocks title: Building Custom Blocks
- local: modular_diffusers/mellon - local: modular_diffusers/mellon

View File

@@ -17,7 +17,3 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
## Flux2Transformer2DModel ## Flux2Transformer2DModel
[[autodoc]] Flux2Transformer2DModel [[autodoc]] Flux2Transformer2DModel
## Flux2Transformer2DModelOutput
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput

View File

@@ -41,11 +41,5 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
## Flux2KleinPipeline ## Flux2KleinPipeline
[[autodoc]] Flux2KleinPipeline [[autodoc]] Flux2KleinPipeline
- all
- __call__
## Flux2KleinKVPipeline
[[autodoc]] Flux2KleinKVPipeline
- all - all
- __call__ - __call__

View File

@@ -99,7 +99,7 @@ To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)`
pipe.guider = pipe.guider.new(guidance_scale=5.0) pipe.guider = pipe.guider.new(guidance_scale=5.0)
``` ```
Read more on Guider [here](../../using-diffusers/guiders). Read more on Guider [here](../../modular_diffusers/guiders).

View File

@@ -30,7 +30,7 @@ HunyuanImage-2.1 comes in the following variants:
## HunyuanImage-2.1 ## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../../using-diffusers/guiders)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead. HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
```python ```python
import torch import torch

View File

@@ -565,16 +565,4 @@ $ git push --set-upstream origin your-branch-for-syncing
### Style guide ### Style guide
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html). For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
## Coding with AI agents
The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks.
- **Source of truth** — edit `.ai/AGENTS.md` (and any future `.ai/skills/`)
- **Don't edit** generated root-level `AGENTS.md` or `CLAUDE.md` — they are symlinks
- Setup commands:
- `make codex` — symlink for OpenAI Codex
- `make claude` — symlink for Claude Code
- `make clean-ai` — remove generated symlinks

View File

@@ -338,7 +338,7 @@ guider = ClassifierFreeGuidance(guidance_scale=5.0)
pipeline.update_components(guider=guider) pipeline.update_components(guider=guider)
``` ```
See the [Guiders](../using-diffusers/guiders) guide for more details on available guiders and how to configure them. See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.
## Splitting a pipeline into stages ## Splitting a pipeline into stages

View File

@@ -39,7 +39,7 @@ The Modular Diffusers docs are organized as shown below.
- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`]. - [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines. - [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
- [Guiders](../using-diffusers/guiders) shows you how to use different guidance methods in the pipeline. - [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
## Mellon Integration ## Mellon Integration

View File

@@ -482,6 +482,144 @@ print(
) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works ) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
``` ```
## torch.jit.trace
[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flows and operations are fused together for more efficiency. The returned executable or [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) can be compiled.
```py
import time
import torch
from diffusers import StableDiffusionPipeline
import functools
# torch disable grad
torch.set_grad_enabled(False)
# set variables
n_experiments = 2
unet_runs_per_experiment = 50
# load sample inputs
def generate_inputs():
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
return sample, timestep, encoder_hidden_states
pipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
unet = pipeline.unet
unet.eval()
unet.to(memory_format=torch.channels_last) # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
# warmup
for _ in range(3):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet(*inputs)
# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")
# warmup and optimize graph
for _ in range(5):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet_traced(*inputs)
# benchmarking
with torch.inference_mode():
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet_traced(*inputs)
torch.cuda.synchronize()
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet(*inputs)
torch.cuda.synchronize()
print(f"unet inference took {time.time() - start_time:.2f} seconds")
# save the model
unet_traced.save("unet_traced.pt")
```
Replace the pipeline's UNet with the traced version.
```py
import torch
from diffusers import StableDiffusionPipeline
from dataclasses import dataclass
@dataclass
class UNet2DConditionOutput:
sample: torch.Tensor
pipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")
# del pipeline.unet
class TracedUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.in_channels = pipeline.unet.config.in_channels
        self.device = pipeline.unet.device
    def forward(self, latent_model_input, t, encoder_hidden_states):
        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)
pipeline.unet = TracedUNet()
with torch.inference_mode():
    image = pipeline([prompt] * 1, num_inference_steps=50).images[0]
```
## Memory-efficient attention ## Memory-efficient attention
Diffusers supports multiple memory-efficient attention backends (FlashAttention, xFormers, SageAttention, and more) through [`~ModelMixin.set_attention_backend`]. Refer to the [Attention backends](./attention_backends) guide to learn how to switch between them. > [!TIP]
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!
The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
By default, if PyTorch >= 2.0 is installed, [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code.
SDPA supports [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [xFormers](https://github.com/facebookresearch/xformers) as well as a native C++ PyTorch implementation. It automatically selects the most optimal implementation based on your input.
You can explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method.
```py
# pip install xformers
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_xformers_memory_efficient_attention()
```
Call [`~ModelMixin.disable_xformers_memory_efficient_attention`] to disable it.
```py
pipeline.disable_xformers_memory_efficient_attention()
```

View File

@@ -23,7 +23,7 @@ pip install xformers
> [!TIP] > [!TIP]
> The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers). > The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).
After xFormers is installed, you can use it with [`~ModelMixin.set_attention_backend`] as shown in the [Attention backends](./attention_backends) guide. After xFormers is installed, you can use `enable_xformers_memory_efficient_attention()` for faster inference and reduced memory consumption as shown in this [section](memory#memory-efficient-attention).
> [!WARNING] > [!WARNING]
> According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments. > According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.

View File

@@ -14,8 +14,6 @@
sections: sections:
- local: using-diffusers/schedulers - local: using-diffusers/schedulers
title: Load schedulers and models title: Load schedulers and models
- local: using-diffusers/guiders
title: Guiders
- title: Inference - title: Inference
isExpanded: false isExpanded: false
@@ -82,6 +80,8 @@
title: ModularPipeline title: ModularPipeline
- local: modular_diffusers/components_manager - local: modular_diffusers/components_manager
title: ComponentsManager title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- title: Training - title: Training
isExpanded: false isExpanded: false

View File

@@ -12,7 +12,6 @@ from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import ( from diffusers import (
AutoencoderKLLTX2Video,
AutoencoderKLWan, AutoencoderKLWan,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler,
@@ -25,10 +24,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [ ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
@@ -96,22 +92,12 @@ def main(args):
if args.video_size == 480: if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2) patch_size = (1, 2, 2)
in_channels = 16
out_channels = 16
elif args.video_size == 720: elif args.video_size == 720:
sample_size = 22 # DC-AE-V: 32xp1 downsample factor sample_size = 22 # Wan-VAE: 32xp1 downsample factor
patch_size = (1, 1, 1) patch_size = (1, 1, 1)
in_channels = 32
out_channels = 32
else: else:
raise ValueError(f"Video size {args.video_size} is not supported.") raise ValueError(f"Video size {args.video_size} is not supported.")
if args.vae_type == "ltx2":
sample_size = 22
patch_size = (1, 1, 1)
in_channels = 128
out_channels = 128
for depth in range(layer_num): for depth in range(layer_num):
# Transformer blocks. # Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -196,8 +182,8 @@ def main(args):
# Transformer # Transformer
with CTX(): with CTX():
transformer_kwargs = { transformer_kwargs = {
"in_channels": in_channels, "in_channels": 16,
"out_channels": out_channels, "out_channels": 16,
"num_attention_heads": 20, "num_attention_heads": 20,
"attention_head_dim": 112, "attention_head_dim": 112,
"num_layers": 20, "num_layers": 20,
@@ -249,12 +235,9 @@ def main(args):
else: else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE # VAE
if args.vae_type == "ltx2": vae = AutoencoderKLWan.from_pretrained(
vae_path = args.vae_path or "Lightricks/LTX-2" "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) )
else:
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
# Text Encoder # Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -331,23 +314,7 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"], choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.", help="Scheduler type to use.",
) )
parser.add_argument( parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
"--vae_type",
default="wan",
type=str,
choices=["wan", "ltx2"],
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
)
parser.add_argument(
"--vae_path",
default=None,
type=str,
required=False,
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
)
parser.add_argument(
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -510,7 +510,6 @@ else:
"EasyAnimateControlPipeline", "EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline", "EasyAnimateInpaintPipeline",
"EasyAnimatePipeline", "EasyAnimatePipeline",
"Flux2KleinKVPipeline",
"Flux2KleinPipeline", "Flux2KleinPipeline",
"Flux2Pipeline", "Flux2Pipeline",
"FluxControlImg2ImgPipeline", "FluxControlImg2ImgPipeline",
@@ -1267,7 +1266,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline, EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline, EasyAnimateInpaintPipeline,
EasyAnimatePipeline, EasyAnimatePipeline,
Flux2KleinKVPipeline,
Flux2KleinPipeline, Flux2KleinPipeline,
Flux2Pipeline, Flux2Pipeline,
FluxControlImg2ImgPipeline, FluxControlImg2ImgPipeline,

View File

@@ -2538,12 +2538,8 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
def get_alpha_scales(down_weight, alpha_key): def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0] rank = down_weight.shape[0]
alpha_tensor = state_dict.pop(alpha_key, None) alpha = state_dict.pop(alpha_key).item()
if alpha_tensor is None: scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
return 1.0, 1.0
scale = (
alpha_tensor.item() / rank
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale scale_down = scale
scale_up = 1.0 scale_up = 1.0
while scale_down * 2 < scale_up: while scale_down * 2 < scale_up:

View File

@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from dataclasses import dataclass
from typing import Any from typing import Any
import torch import torch
@@ -22,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, logging from ...utils import apply_lora_scale, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn from ..attention_dispatch import dispatch_attention_fn
@@ -33,6 +32,7 @@ from ..embeddings import (
apply_rotary_emb, apply_rotary_emb,
get_1d_rotary_pos_embed, get_1d_rotary_pos_embed,
) )
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous from ..normalization import AdaLayerNormContinuous
@@ -40,216 +40,6 @@ from ..normalization import AdaLayerNormContinuous
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Flux2Transformer2DModelOutput(BaseOutput):
"""
The output of [`Flux2Transformer2DModel`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on the `encoder_hidden_states` input.
kv_cache (`Flux2KVCache`, *optional*):
The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
"""
sample: "torch.Tensor" # noqa: F821
kv_cache: "Flux2KVCache | None" = None
class Flux2KVLayerCache:
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step. Tensor
format: (batch_size, num_ref_tokens, num_heads, head_dim).
"""
def __init__(self):
self.k_ref: torch.Tensor | None = None
self.v_ref: torch.Tensor | None = None
def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
"""Store reference token K/V."""
self.k_ref = k_ref
self.v_ref = v_ref
def get(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Retrieve cached reference token K/V."""
if self.k_ref is None:
raise RuntimeError("KV cache has not been populated yet.")
return self.k_ref, self.v_ref
def clear(self):
self.k_ref = None
self.v_ref = None
class Flux2KVCache:
"""Container for all layers' reference-token KV caches.
Holds separate cache lists for double-stream and single-stream transformer blocks.
"""
def __init__(self, num_double_layers: int, num_single_layers: int):
self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
self.num_ref_tokens: int = 0
def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
return self.double_block_caches[layer_idx]
def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
return self.single_block_caches[layer_idx]
def clear(self):
for cache in self.double_block_caches:
cache.clear()
for cache in self.single_block_caches:
cache.clear()
self.num_ref_tokens = 0
def _flux2_kv_causal_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_txt_tokens: int,
num_ref_tokens: int,
kv_cache: Flux2KVLayerCache | None = None,
backend=None,
) -> torch.Tensor:
"""Causal attention for KV caching where reference tokens only self-attend.
All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens, ref tokens
only attend to themselves. With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected
between txt and img.
"""
# No ref tokens and no cache — standard full attention
if num_ref_tokens == 0 and kv_cache is None:
return dispatch_attention_fn(query, key, value, backend=backend)
if kv_cache is not None:
# Cached mode: inject ref K/V between txt and img
k_ref, v_ref = kv_cache.get()
k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
return dispatch_attention_fn(query, k_all, v_all, backend=backend)
# Extract mode: ref tokens self-attend, txt+img attend to all
ref_start = num_txt_tokens
ref_end = num_txt_tokens + num_ref_tokens
q_txt = query[:, :ref_start]
q_ref = query[:, ref_start:ref_end]
q_img = query[:, ref_end:]
k_txt = key[:, :ref_start]
k_ref = key[:, ref_start:ref_end]
k_img = key[:, ref_end:]
v_txt = value[:, :ref_start]
v_ref = value[:, ref_start:ref_end]
v_img = value[:, ref_end:]
# txt+img attend to all tokens
q_txt_img = torch.cat([q_txt, q_img], dim=1)
k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
attn_txt = attn_txt_img[:, :ref_start]
attn_img = attn_txt_img[:, ref_start:]
# ref tokens self-attend only
attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
def _blend_mod_params(
img_params: tuple[torch.Tensor, ...],
ref_params: tuple[torch.Tensor, ...],
num_ref: int,
seq_len: int,
) -> tuple[torch.Tensor, ...]:
"""Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
blended = []
for im, rm in zip(img_params, ref_params):
if im.ndim == 2:
im = im.unsqueeze(1)
rm = rm.unsqueeze(1)
B = im.shape[0]
blended.append(
torch.cat(
[rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
dim=1,
)
)
return tuple(blended)
def _blend_double_block_mods(
img_mod: torch.Tensor,
ref_mod: torch.Tensor,
num_ref: int,
seq_len: int,
) -> torch.Tensor:
"""Blend double-block image-stream modulations for a [ref, img] sequence layout.
Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is compatible
with `Flux2Modulation.split(mod, 2)`.
"""
if img_mod.ndim == 2:
img_mod = img_mod.unsqueeze(1)
ref_mod = ref_mod.unsqueeze(1)
img_chunks = torch.chunk(img_mod, 6, dim=-1)
ref_chunks = torch.chunk(ref_mod, 6, dim=-1)
img_mods = (img_chunks[0:3], img_chunks[3:6])
ref_mods = (ref_chunks[0:3], ref_chunks[3:6])
all_params = []
for img_set, ref_set in zip(img_mods, ref_mods):
blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
all_params.extend(blended)
return torch.cat(all_params, dim=-1)
def _blend_single_block_mods(
single_mod: torch.Tensor,
ref_mod: torch.Tensor,
num_txt: int,
num_ref: int,
seq_len: int,
) -> torch.Tensor:
"""Blend single-block modulations for a [txt, ref, img] sequence layout.
Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`.
"""
if single_mod.ndim == 2:
single_mod = single_mod.unsqueeze(1)
ref_mod = ref_mod.unsqueeze(1)
img_params = torch.chunk(single_mod, 3, dim=-1)
ref_params = torch.chunk(ref_mod, 3, dim=-1)
blended = []
for im, rm in zip(img_params, ref_params):
if im.ndim == 2:
im = im.unsqueeze(1)
rm = rm.unsqueeze(1)
B = im.shape[0]
im_expanded = im.expand(B, seq_len, -1)
rm_expanded = rm.expand(B, num_ref, -1)
blended.append(
torch.cat(
[im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
dim=1,
)
)
return torch.cat(blended, dim=-1)
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None): def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states) key = attn.to_k(hidden_states)
@@ -391,108 +181,9 @@ class Flux2AttnProcessor:
return hidden_states return hidden_states
class Flux2KVAttnProcessor:
    """
    Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.

    When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal attention is
    used (ref tokens self-attend only, txt+img attend to all). When `kv_cache_mode` is "cached", cached ref K/V are
    injected during attention. When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
    """

    # Backend/parallelism knobs set externally by the attention dispatch machinery; None means defaults.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # Every dispatch path below ultimately relies on SDPA being available.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        kv_cache: Flux2KVLayerCache | None = None,
        kv_cache_mode: str | None = None,
        num_ref_tokens: int = 0,
    ) -> torch.Tensor:
        """Run double-stream attention, optionally extracting or injecting reference-token K/V.

        Args:
            attn: The owning `Flux2Attention` module providing projections and norms.
            hidden_states: Image-stream tokens; in "extract" mode the first `num_ref_tokens` positions of the
                image stream are reference tokens (they follow the text tokens in the joint sequence).
            encoder_hidden_states: Optional text-stream tokens, prepended to the joint sequence.
            attention_mask: Optional mask; only used on the plain (non-KV) dispatch path.
            image_rotary_emb: Rotary embeddings applied to Q/K before caching/attention.
            kv_cache: Per-layer cache that ref K/V are stored into ("extract") or read from ("cached").
            kv_cache_mode: One of "extract", "cached", or None (plain attention).
            num_ref_tokens: Number of reference tokens (only meaningful in "extract" mode).

        Returns:
            `hidden_states`, or `(hidden_states, encoder_hidden_states)` when a text stream was provided.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # Reshape to (batch, seq, heads, head_dim) and apply QK norms.
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            # Joint sequence layout after concatenation: [txt, (ref), img].
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        # RoPE must be applied before caching so the stored K/V are position-encoded.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0

        # Extract ref K/V from the combined sequence
        if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
            ref_start = num_txt_tokens
            ref_end = num_txt_tokens + num_ref_tokens
            # clone() detaches the stored slices from the full K/V tensors so they outlive this forward pass.
            kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())

        # Dispatch attention
        if kv_cache_mode == "extract" and num_ref_tokens > 0:
            hidden_states = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
            )
        elif kv_cache_mode == "cached" and kv_cache is not None:
            # num_ref_tokens=0 here: the ref K/V come from `kv_cache`, not from the current sequence.
            hidden_states = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
            )
        else:
            hidden_states = dispatch_attention_fn(
                query,
                key,
                value,
                attn_mask=attention_mask,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the joint sequence back into text and image streams.
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class Flux2Attention(torch.nn.Module, AttentionModuleMixin): class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = Flux2AttnProcessor _default_processor_cls = Flux2AttnProcessor
_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor] _available_processors = [Flux2AttnProcessor]
def __init__( def __init__(
self, self,
@@ -621,90 +312,6 @@ class Flux2ParallelSelfAttnProcessor:
return hidden_states return hidden_states
class Flux2KVParallelSelfAttnProcessor:
    """
    Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.

    When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used. When
    `kv_cache_mode` is "cached", cached ref K/V are injected during attention. When no KV args are provided, behaves
    identically to `Flux2ParallelSelfAttnProcessor`.
    """

    # Backend/parallelism knobs set externally by the attention dispatch machinery; None means defaults.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # Every dispatch path below ultimately relies on SDPA being available.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        kv_cache: Flux2KVLayerCache | None = None,
        kv_cache_mode: str | None = None,
        num_txt_tokens: int = 0,
        num_ref_tokens: int = 0,
    ) -> torch.Tensor:
        """Fused self-attention + MLP for a single-stream block, with optional ref-token KV caching.

        In "extract" mode the sequence is assumed to be laid out as [txt, ref, img]: `num_txt_tokens` marks where
        the `num_ref_tokens` reference positions begin. Returns a tensor shaped like `hidden_states`.
        """
        # Parallel in (QKV + MLP in) projection
        hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_hidden_states = torch.split(
            hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
        )
        query, key, value = qkv.chunk(3, dim=-1)

        # Reshape to (batch, seq, heads, head_dim) and apply QK norms.
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # RoPE must be applied before caching so the stored K/V are position-encoded.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        # Extract ref K/V from the combined sequence
        if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
            ref_start = num_txt_tokens
            ref_end = num_txt_tokens + num_ref_tokens
            # clone() detaches the stored slices from the full K/V tensors so they outlive this forward pass.
            kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())

        # Dispatch attention
        if kv_cache_mode == "extract" and num_ref_tokens > 0:
            attn_output = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
            )
        elif kv_cache_mode == "cached" and kv_cache is not None:
            # num_ref_tokens=0 here: the ref K/V come from `kv_cache`, not from the current sequence.
            attn_output = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
            )
        else:
            attn_output = dispatch_attention_fn(
                query,
                key,
                value,
                attn_mask=attention_mask,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )
        attn_output = attn_output.flatten(2, 3)
        attn_output = attn_output.to(query.dtype)

        # Handle the feedforward (FF) logic
        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

        # Concatenate and parallel output projection
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
        hidden_states = attn.to_out(hidden_states)

        return hidden_states
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin): class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
""" """
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks. Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
@@ -715,7 +322,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
""" """
_default_processor_cls = Flux2ParallelSelfAttnProcessor _default_processor_cls = Flux2ParallelSelfAttnProcessor
_available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor] _available_processors = [Flux2ParallelSelfAttnProcessor]
# Does not support QKV fusion as the QKV projections are always fused # Does not support QKV fusion as the QKV projections are always fused
_supports_qkv_fusion = False _supports_qkv_fusion = False
@@ -1173,8 +780,6 @@ class Flux2Transformer2DModel(
self.gradient_checkpointing = False self.gradient_checkpointing = False
_skip_keys = ["kv_cache"]
@apply_lora_scale("joint_attention_kwargs") @apply_lora_scale("joint_attention_kwargs")
def forward( def forward(
self, self,
@@ -1186,21 +791,19 @@ class Flux2Transformer2DModel(
guidance: torch.Tensor = None, guidance: torch.Tensor = None,
joint_attention_kwargs: dict[str, Any] | None = None, joint_attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True, return_dict: bool = True,
kv_cache: "Flux2KVCache | None" = None, ) -> torch.Tensor | Transformer2DModelOutput:
kv_cache_mode: str | None = None,
num_ref_tokens: int = 0,
ref_fixed_timestep: float = 0.0,
) -> torch.Tensor | Flux2Transformer2DModelOutput:
""" """
The [`Flux2Transformer2DModel`] forward method. The [`FluxTransformer2DModel`] forward method.
Args: Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`. Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep (`torch.LongTensor`): timestep ( `torch.LongTensor`):
Used to indicate denoising step. Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*): joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
@@ -1208,23 +811,13 @@ class Flux2Transformer2DModel(
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple. tuple.
kv_cache (`Flux2KVCache`, *optional*):
KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and
returned. When "cached", the provided cache is used to inject ref K/V during attention.
kv_cache_mode (`str`, *optional*):
One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When
`None`, standard forward pass without KV caching.
num_ref_tokens (`int`, defaults to `0`):
Number of reference image tokens prepended to `hidden_states` (only used when
`kv_cache_mode="extract"`).
ref_fixed_timestep (`float`, defaults to `0.0`):
Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
Returns: Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the `tuple` where the first element is the sample tensor.
populated `Flux2KVCache`.
""" """
# 0. Handle input arguments
num_txt_tokens = encoder_hidden_states.shape[1] num_txt_tokens = encoder_hidden_states.shape[1]
# 1. Calculate timestep embedding and modulation parameters # 1. Calculate timestep embedding and modulation parameters
@@ -1239,33 +832,13 @@ class Flux2Transformer2DModel(
double_stream_mod_txt = self.double_stream_modulation_txt(temb) double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb) single_stream_mod = self.single_stream_modulation(temb)
# KV extract mode: create cache and blend modulations for ref tokens
if kv_cache_mode == "extract" and num_ref_tokens > 0:
num_img_tokens = hidden_states.shape[1] # includes ref tokens
kv_cache = Flux2KVCache(
num_double_layers=len(self.transformer_blocks),
num_single_layers=len(self.single_transformer_blocks),
)
kv_cache.num_ref_tokens = num_ref_tokens
# Ref tokens use a fixed timestep for modulation
ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
ref_temb = self.time_guidance_embed(ref_timestep, guidance)
ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
ref_single_mod = self.single_stream_modulation(ref_temb)
# Blend double block img modulation: [ref_mod, img_mod]
double_stream_mod_img = _blend_double_block_mods(
double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
)
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states) # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states)
# 3. Calculate RoPE embeddings from image and text tokens # 3. Calculate RoPE embeddings from image and text tokens
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
# text prompts of differents lengths. Is this a use case we want to support?
if img_ids.ndim == 3: if img_ids.ndim == 3:
img_ids = img_ids[0] img_ids = img_ids[0]
if txt_ids.ndim == 3: if txt_ids.ndim == 3:
@@ -1278,29 +851,8 @@ class Flux2Transformer2DModel(
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
) )
# 4. Build joint_attention_kwargs with KV cache info # 4. Double Stream Transformer Blocks
if kv_cache_mode == "extract":
kv_attn_kwargs = {
**(joint_attention_kwargs or {}),
"kv_cache": None,
"kv_cache_mode": "extract",
"num_ref_tokens": num_ref_tokens,
}
elif kv_cache_mode == "cached" and kv_cache is not None:
kv_attn_kwargs = {
**(joint_attention_kwargs or {}),
"kv_cache": None,
"kv_cache_mode": "cached",
"num_ref_tokens": kv_cache.num_ref_tokens,
}
else:
kv_attn_kwargs = joint_attention_kwargs
# 5. Double Stream Transformer Blocks
for index_block, block in enumerate(self.transformer_blocks): for index_block, block in enumerate(self.transformer_blocks):
if kv_cache_mode is not None and kv_cache is not None:
kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block, block,
@@ -1309,7 +861,7 @@ class Flux2Transformer2DModel(
double_stream_mod_img, double_stream_mod_img,
double_stream_mod_txt, double_stream_mod_txt,
concat_rotary_emb, concat_rotary_emb,
kv_attn_kwargs, joint_attention_kwargs,
) )
else: else:
encoder_hidden_states, hidden_states = block( encoder_hidden_states, hidden_states = block(
@@ -1318,30 +870,13 @@ class Flux2Transformer2DModel(
temb_mod_img=double_stream_mod_img, temb_mod_img=double_stream_mod_img,
temb_mod_txt=double_stream_mod_txt, temb_mod_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb, image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=kv_attn_kwargs, joint_attention_kwargs=joint_attention_kwargs,
) )
# Concatenate text and image streams for single-block inference # Concatenate text and image streams for single-block inference
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod] # 5. Single Stream Transformer Blocks
if kv_cache_mode == "extract" and num_ref_tokens > 0:
total_single_len = hidden_states.shape[1]
single_stream_mod = _blend_single_block_mods(
single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
)
# Build single-block KV kwargs (single blocks need num_txt_tokens)
if kv_cache_mode is not None:
kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
else:
kv_attn_kwargs_single = kv_attn_kwargs
# 6. Single Stream Transformer Blocks
for index_block, block in enumerate(self.single_transformer_blocks): for index_block, block in enumerate(self.single_transformer_blocks):
if kv_cache_mode is not None and kv_cache is not None:
kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func( hidden_states = self._gradient_checkpointing_func(
block, block,
@@ -1349,7 +884,7 @@ class Flux2Transformer2DModel(
None, None,
single_stream_mod, single_stream_mod,
concat_rotary_emb, concat_rotary_emb,
kv_attn_kwargs_single, joint_attention_kwargs,
) )
else: else:
hidden_states = block( hidden_states = block(
@@ -1357,25 +892,16 @@ class Flux2Transformer2DModel(
encoder_hidden_states=None, encoder_hidden_states=None,
temb_mod=single_stream_mod, temb_mod=single_stream_mod,
image_rotary_emb=concat_rotary_emb, image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=kv_attn_kwargs_single, joint_attention_kwargs=joint_attention_kwargs,
) )
# Remove text tokens from concatenated stream
hidden_states = hidden_states[:, num_txt_tokens:, ...]
# Remove text tokens (and ref tokens in extract mode) from concatenated stream # 6. Output layers
if kv_cache_mode == "extract" and num_ref_tokens > 0:
hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
else:
hidden_states = hidden_states[:, num_txt_tokens:, ...]
# 7. Output layers
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states) output = self.proj_out(hidden_states)
if kv_cache_mode == "extract":
if not return_dict:
return (output, kv_cache)
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
if not return_dict: if not return_dict:
return (output,) return (output,)
return Flux2Transformer2DModelOutput(sample=output) return Transformer2DModelOutput(sample=output)

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
from functools import lru_cache
from typing import Any from typing import Any
import torch import torch
@@ -342,6 +343,7 @@ class HeliosRotaryPosEmbed(nn.Module):
return freqs.cos(), freqs.sin() return freqs.cos(), freqs.sin()
@torch.no_grad() @torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str): def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str) device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32) grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)

View File

@@ -309,16 +309,16 @@ class ComponentSpec:
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}" f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
) )
from diffusers import AutoModel
# `torch_dtype` is not an accepted parameter for tokenizers and processors. # `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config # As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component. # during save. This causes JSON serialization to fail when saving the component.
if self.type_hint is not None and not issubclass(self.type_hint, (torch.nn.Module, AutoModel)): if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None) kwargs.pop("torch_dtype", None)
if self.type_hint is None: if self.type_hint is None:
try: try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs) component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e: except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}") raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
@@ -332,6 +332,12 @@ class ComponentSpec:
else getattr(self.type_hint, "from_pretrained") else getattr(self.type_hint, "from_pretrained")
) )
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)
try: try:
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs) component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e: except Exception as e:

View File

@@ -129,7 +129,7 @@ else:
] ]
_import_structure["bria"] = ["BriaPipeline"] _import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"] _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
_import_structure["flux"] = [ _import_structure["flux"] = [
"FluxControlPipeline", "FluxControlPipeline",
"FluxControlInpaintPipeline", "FluxControlInpaintPipeline",
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline, FluxPriorReduxPipeline,
ReduxImageEncoder, ReduxImageEncoder,
) )
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline from .glm_image import GlmImagePipeline
from .helios import HeliosPipeline, HeliosPyramidPipeline from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hidream_image import HiDreamImagePipeline from .hidream_image import HiDreamImagePipeline

View File

@@ -24,7 +24,6 @@ except OptionalDependencyNotAvailable:
else: else:
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"] _import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
@@ -34,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .pipeline_flux2 import Flux2Pipeline from .pipeline_flux2 import Flux2Pipeline
from .pipeline_flux2_klein import Flux2KleinPipeline from .pipeline_flux2_klein import Flux2KleinPipeline
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
else: else:
import sys import sys

View File

@@ -744,7 +744,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
image: PIL.Image.Image | list[PIL.Image.Image] | None = None, image: list[PIL.Image.Image, PIL.Image.Image] | None = None,
prompt: str | list[str] = None, prompt: str | list[str] = None,
height: int | None = None, height: int | None = None,
width: int | None = None, width: int | None = None,

View File

@@ -1,886 +0,0 @@
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable
import numpy as np
import PIL
import torch
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
from ...loaders import Flux2LoraLoaderMixin
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import Flux2KleinKVPipeline
>>> pipe = Flux2KleinKVPipeline.from_pretrained(
... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> ref_image = Image.open("reference.png")
>>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
>>> image.save("flux2_kv_output.png")
```
"""
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Return the empirically fitted `mu` for the timestep-shift schedule.

    Two linear fits in `image_seq_len` are used: one calibrated at 10 inference steps and one at 200. For short
    sequences the result is linearly interpolated in `num_steps` between those two anchors; for long sequences
    (> 4300 tokens) the 200-step fit is used directly, independent of `num_steps`.
    """
    slope_10, intercept_10 = 8.73809524e-05, 1.89833333
    slope_200, intercept_200 = 0.00016927, 0.45666666

    mu_at_200 = slope_200 * image_seq_len + intercept_200
    if image_seq_len > 4300:
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + intercept_10
    # Line through the (10, mu_at_10) .. (200, mu_at_200) anchors, evaluated at num_steps.
    step_slope = (mu_at_200 - mu_at_10) / 190.0
    step_intercept = mu_at_200 - 200.0 * step_slope
    return float(step_slope * num_steps + step_intercept)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`list[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`list[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
r"""
The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
On the first denoising step, reference image tokens are included in the forward pass and their attention K/V
projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster
inference when using reference images.
Reference:
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
Args:
transformer ([`Flux2Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLFlux2`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen3ForCausalLM`]):
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
tokenizer (`Qwen2TokenizerFast`):
Tokenizer of class
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLFlux2,
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
transformer: Flux2Transformer2DModel,
is_distilled: bool = True,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 512
self.default_sample_size = 128
# Set KV-cache-aware attention processors
self._set_kv_attn_processors()
@staticmethod
def _get_qwen3_prompt_embeds(
    text_encoder: Qwen3ForCausalLM,
    tokenizer: Qwen2TokenizerFast,
    prompt: str | list[str],
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    max_sequence_length: int = 512,
    hidden_states_layers: list[int] = (9, 18, 27),
):
    """Encode `prompt` with Qwen3 and return per-token embeddings built from the
    selected intermediate hidden-state layers, shaped (B, L, len(layers) * D)."""
    if dtype is None:
        dtype = text_encoder.dtype
    if device is None:
        device = text_encoder.device
    if isinstance(prompt, str):
        prompt = [prompt]
    # Render every prompt through the chat template (generation prompt on, thinking off).
    chat_texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": single_prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        for single_prompt in prompt
    ]
    # Tokenize each rendered prompt padded/truncated to a fixed length so they batch cleanly.
    encodings = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_sequence_length,
        )
        for text in chat_texts
    ]
    input_ids = torch.cat([enc["input_ids"] for enc in encodings], dim=0).to(device)
    attention_mask = torch.cat([enc["attention_mask"] for enc in encodings], dim=0).to(device)
    # Single forward pass; hidden states from all layers are requested.
    output = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    )
    # Keep only the requested intermediate layers: (B, num_layers, L, D).
    selected = torch.stack([output.hidden_states[layer] for layer in hidden_states_layers], dim=1)
    selected = selected.to(dtype=dtype, device=device)
    batch_size, num_layers, seq_len, hidden_dim = selected.shape
    # Interleave the layer axis into the feature axis: (B, L, num_layers * D).
    return selected.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_layers * hidden_dim)
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
def _prepare_text_ids(
    x: torch.Tensor,  # (B, L, D) or (L, D)
    t_coord: torch.Tensor | None = None,
):
    """Build (T, H, W, L) position ids for text tokens: T/H/W stay 0 (or the per-sample
    `t_coord`) and only the sequence index varies, yielding a (B, L, 4) tensor."""
    # NOTE(review): the inline comment above says (L, D) is accepted, but the unpacking
    # below requires a 3-D tensor — confirm callers always pass (B, L, D).
    B, L, _ = x.shape
    out_ids = []
    for i in range(B):
        # Per-sample time coordinate; defaults to [0] when no `t_coord` is supplied.
        t = torch.arange(1) if t_coord is None else t_coord[i]
        h = torch.arange(1)
        w = torch.arange(1)
        l = torch.arange(L)
        coords = torch.cartesian_prod(t, h, w, l)
        out_ids.append(coords)
    return torch.stack(out_ids)
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
def _prepare_latent_ids(
    latents: torch.Tensor,  # (B, C, H, W)
):
    r"""
    Generates 4D position coordinates (T, H, W, L) for latent tensors.

    Args:
        latents (torch.Tensor):
            Latent tensor of shape (B, C, H, W)

    Returns:
        torch.Tensor:
            Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
            H=[0..H-1], W=[0..W-1], L=0
    """
    batch_size, _, height, width = latents.shape
    t = torch.arange(1)  # [0] - time dimension
    h = torch.arange(height)
    w = torch.arange(width)
    l = torch.arange(1)  # [0] - layer dimension
    # Create position IDs: (H*W, 4)
    latent_ids = torch.cartesian_prod(t, h, w, l)
    # Expand to batch: (B, H*W, 4) — `expand` creates a view, no copy.
    latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
    return latent_ids
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
def _prepare_image_ids(
    image_latents: list[torch.Tensor],  # [(1, C, H, W), (1, C, H, W), ...]
    scale: int = 10,
):
    r"""
    Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.

    This function creates a unique coordinate for every pixel/patch across all input latent with different
    dimensions.

    Args:
        image_latents (list[torch.Tensor]):
            A list of image latent feature tensors, typically of shape (C, H, W).
        scale (int, optional):
            A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
            latent is: 'scale + scale * i'. Defaults to 10.

    Returns:
        torch.Tensor:
            The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
            input latents.

    Coordinate Components (Dimension 4):
        - T (Time): The unique index indicating which latent image the coordinate belongs to.
        - H (Height): The row index within that latent image.
        - W (Width): The column index within that latent image.
        - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
    """
    if not isinstance(image_latents, list):
        raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
    # create time offset for each reference image: T = scale, 2*scale, 3*scale, ...
    t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
    t_coords = [t.view(-1) for t in t_coords]
    image_latent_ids = []
    for x, t in zip(image_latents, t_coords):
        # Drop the leading batch dim (each entry is (1, C, H, W)).
        x = x.squeeze(0)
        _, height, width = x.shape
        x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
        image_latent_ids.append(x_ids)
    # Concatenate all per-image ids along the sequence axis, then add a batch dim.
    image_latent_ids = torch.cat(image_latent_ids, dim=0)
    image_latent_ids = image_latent_ids.unsqueeze(0)
    return image_latent_ids
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
def _patchify_latents(latents):
    """Fold each 2x2 spatial patch into channels: (B, C, H, W) -> (B, 4C, H/2, W/2).

    Assumes H and W are even (callers guarantee divisibility by the patch size).
    """
    batch_size, num_channels_latents, height, width = latents.shape
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4)
    latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
    return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
def _unpatchify_latents(latents):
    """Inverse of `_patchify_latents`: (B, 4C, H, W) -> (B, C, 2H, 2W)."""
    batch_size, num_channels_latents, height, width = latents.shape
    latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
    latents = latents.permute(0, 1, 4, 2, 5, 3)
    latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
    return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
def _pack_latents(latents):
    """
    pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
    """
    batch_size, num_channels, height, width = latents.shape
    # Flatten the spatial grid into one token axis, channels become the feature axis.
    latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
    return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
    """
    using position ids to scatter tokens into place

    For each sample, reads the H/W coordinates out of `x_ids` (columns 1 and 2), infers the
    grid size from their maxima, and scatters tokens back into a dense (C, H, W) image.
    Returns a stacked (B, C, H, W) tensor.
    """
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape  # noqa: F841
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)
        # Grid extent is derived from the largest coordinate present.
        h = torch.max(h_ids) + 1
        w = torch.max(w_ids) + 1
        # Row-major flat index for each token.
        flat_ids = h_ids * w + w_ids
        out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
        # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
        out = out.view(h, w, ch).permute(2, 0, 1)
        x_list.append(out)
    return torch.stack(x_list, dim=0)
def _set_kv_attn_processors(self):
    """Install KV-cache-aware attention processors on every transformer block.

    Dual-stream blocks get the joint-attention variant; single-stream blocks get the
    parallel self-attention variant.
    """
    block_processor_pairs = (
        (self.transformer.transformer_blocks, Flux2KVAttnProcessor),
        (self.transformer.single_transformer_blocks, Flux2KVParallelSelfAttnProcessor),
    )
    for blocks, processor_cls in block_processor_pairs:
        for block in blocks:
            block.attn.set_processor(processor_cls())
def encode_prompt(
    self,
    prompt: str | list[str],
    device: torch.device | None = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: torch.Tensor | None = None,
    max_sequence_length: int = 512,
    text_encoder_out_layers: tuple[int] = (9, 18, 27),
):
    """Compute (or reuse precomputed) prompt embeddings and their text position ids.

    Embeddings are duplicated `num_images_per_prompt` times along the batch axis.
    Returns `(prompt_embeds, text_ids)`.
    """
    if device is None:
        device = self._execution_device
    if prompt is None:
        prompt = ""
    if isinstance(prompt, str):
        prompt = [prompt]
    # Only run the text encoder when no precomputed embeddings were supplied.
    if prompt_embeds is None:
        prompt_embeds = self._get_qwen3_prompt_embeds(
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            prompt=prompt,
            device=device,
            max_sequence_length=max_sequence_length,
            hidden_states_layers=text_encoder_out_layers,
        )
    # Duplicate for each image generated per prompt, keeping sequence length intact.
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        batch_size * num_images_per_prompt, seq_len, -1
    )
    text_ids = self._prepare_text_ids(prompt_embeds).to(device)
    return prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode a (B, C, H, W) image batch to patchified, batch-norm-whitened latents.

    Uses the distribution mode (`sample_mode="argmax"`) rather than sampling, so the
    result is deterministic given the same image.
    """
    if image.ndim != 4:
        raise ValueError(f"Expected image dims 4, got {image.ndim}.")
    image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
    image_latents = self._patchify_latents(image_latents)
    # Whiten with the VAE's batch-norm running statistics.
    latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
    # NOTE(review): unlike the mean, the std is not moved to the latents' device/dtype here —
    # presumably the VAE buffers already match; confirm for offloaded setups.
    latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
    image_latents = (image_latents - latents_bn_mean) / latents_bn_std
    return image_latents
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
def prepare_latents(
    self,
    batch_size,
    num_latents_channels,
    height,
    width,
    dtype,
    device,
    generator: torch.Generator,
    latents: torch.Tensor | None = None,
):
    """Create (or reuse) initial noise latents and their position ids.

    Returns `(latents, latent_ids)` where latents are packed to (B, H*W, C) token form.
    """
    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (self.vae_scale_factor * 2))
    width = 2 * (int(width) // (self.vae_scale_factor * 2))
    # Channels are multiplied by 4 and spatial dims halved because latents are 2x2-patchified.
    shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )
    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device=device, dtype=dtype)
    latent_ids = self._prepare_latent_ids(latents)
    latent_ids = latent_ids.to(device)
    latents = self._pack_latents(latents)  # [B, C, H, W] -> [B, H*W, C]
    return latents, latent_ids
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
def prepare_image_latents(
    self,
    images: list[torch.Tensor],
    batch_size,
    generator: torch.Generator,
    device,
    dtype,
):
    """Encode reference images to packed latent tokens plus matching position ids.

    All reference images are concatenated along the token axis and repeated to
    `batch_size`. Returns `(image_latents, image_latent_ids)`.
    """
    image_latents = []
    for image in images:
        image = image.to(device=device, dtype=dtype)
        # NOTE(review): "imagge_latent" is a typo; as a `# Copied from` block it must be
        # renamed at the copy source and propagated via `make fix-copies`.
        imagge_latent = self._encode_vae_image(image=image, generator=generator)
        image_latents.append(imagge_latent)  # (1, 128, 32, 32)
    image_latent_ids = self._prepare_image_ids(image_latents)
    # Pack each latent and concatenate
    packed_latents = []
    for latent in image_latents:
        # latent: (1, 128, 32, 32)
        packed = self._pack_latents(latent)  # (1, 1024, 128)
        packed = packed.squeeze(0)  # (1024, 128) - remove batch dim
        packed_latents.append(packed)
    # Concatenate all reference tokens along sequence dimension
    image_latents = torch.cat(packed_latents, dim=0)  # (N*1024, 128)
    image_latents = image_latents.unsqueeze(0)  # (1, N*1024, 128)
    image_latents = image_latents.repeat(batch_size, 1, 1)
    image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
    image_latent_ids = image_latent_ids.to(device)
    return image_latents, image_latent_ids
def check_inputs(
    self,
    prompt,
    height,
    width,
    prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate call arguments, raising `ValueError` for invalid combinations.

    Non-divisible dimensions only emit a warning, since the image processor resizes
    them downstream.
    """
    # Explicit parentheses make the `and`/`or` precedence obvious; short-circuiting is
    # preserved so `self.vae_scale_factor` is never touched when both dims are None.
    if (height is not None and height % (self.vae_scale_factor * 2) != 0) or (
        width is not None and width % (self.vae_scale_factor * 2) != 0
    ):
        logger.warning(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
        )
    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )
    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def attention_kwargs(self):
    # Extra kwargs forwarded to attention processors for the current call.
    return self._attention_kwargs

@property
def num_timesteps(self):
    # Number of scheduler timesteps used by the current call.
    return self._num_timesteps

@property
def current_timestep(self):
    # Timestep currently being denoised (None outside the denoising loop).
    return self._current_timestep

@property
def interrupt(self):
    # Cooperative-interrupt flag checked at the top of each denoising step.
    return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
    prompt: str | list[str] = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 4,
    sigmas: list[float] | None = None,
    num_images_per_prompt: int = 1,
    generator: torch.Generator | list[torch.Generator] | None = None,
    latents: torch.Tensor | None = None,
    prompt_embeds: torch.Tensor | None = None,
    output_type: str = "pil",
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end: Callable[[int, int, dict], None] | None = None,
    callback_on_step_end_tensor_inputs: list[str] = ["latents"],
    max_sequence_length: int = 512,
    text_encoder_out_layers: tuple[int] = (9, 18, 27),
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
            Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the
            forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without
            recomputing.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation.
        height (`int`, *optional*):
            The height in pixels of the generated image.
        width (`int`, *optional*):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 4):
            The number of denoising steps.
        sigmas (`List[float]`, *optional*):
            Custom sigmas for the denoising schedule.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic generation.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format: `"pil"` or `"np"`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `Flux2PipelineOutput` or a plain tuple.
        attention_kwargs (`dict`, *optional*):
            Extra kwargs passed to attention processors.
        callback_on_step_end (`Callable`, *optional*):
            Callback function called at the end of each denoising step.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            Tensor inputs for the callback function.
        max_sequence_length (`int`, defaults to 512):
            Maximum sequence length for the prompt.
        text_encoder_out_layers (`tuple[int]`):
            Layer indices for text encoder hidden state extraction.

    Examples:

    Returns:
        [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
    """
    # 1. Check inputs
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        prompt_embeds=prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False
    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device
    # 3. prepare text embeddings
    prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        text_encoder_out_layers=text_encoder_out_layers,
    )
    # 4. process images
    if image is not None and not isinstance(image, list):
        image = [image]
    condition_images = None
    if image is not None:
        for img in image:
            self.image_processor.check_image_input(img)
        condition_images = []
        for img in image:
            image_width, image_height = img.size
            # Cap reference images at 1 MP, then snap dims down to the packed-patch multiple.
            if image_width * image_height > 1024 * 1024:
                img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
                image_width, image_height = img.size
            multiple_of = self.vae_scale_factor * 2
            image_width = (image_width // multiple_of) * multiple_of
            image_height = (image_height // multiple_of) * multiple_of
            img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
            condition_images.append(img)
        # NOTE(review): when multiple reference images are given, the fallback dims below come
        # from the LAST processed image — confirm this is the intended default.
        height = height or image_height
        width = width or image_width
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    # 5. prepare latent variables
    # `in_channels` counts patchified channels, so divide by 4 to get raw VAE channels.
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_ids = self.prepare_latents(
        batch_size=batch_size * num_images_per_prompt,
        num_latents_channels=num_channels_latents,
        height=height,
        width=width,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
        latents=latents,
    )
    image_latents = None
    image_latent_ids = None
    if condition_images is not None:
        image_latents, image_latent_ids = self.prepare_image_latents(
            images=condition_images,
            batch_size=batch_size * num_images_per_prompt,
            generator=generator,
            device=device,
            dtype=self.vae.dtype,
        )
    # 6. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    # Schedulers that derive flow sigmas internally ignore user-provided sigmas.
    if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
        sigmas = None
    image_seq_len = latents.shape[1]
    mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)
    # 7. Denoising loop with KV caching
    # Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
    # Steps 1+: forward_kv_cached (reuse cached ref K/V)
    # No ref images: standard forward
    self.scheduler.set_begin_index(0)
    kv_cache = None
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
            self._current_timestep = t
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            if i == 0 and image_latents is not None:
                # Step 0: include ref tokens, extract KV cache
                latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
                latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
                noise_pred, kv_cache = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    kv_cache_mode="extract",
                    num_ref_tokens=image_latents.shape[1],
                )
            elif kv_cache is not None:
                # Steps 1+: use cached ref KV, no ref tokens in input
                noise_pred = self.transformer(
                    hidden_states=latents.to(self.transformer.dtype),
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    kv_cache=kv_cache,
                    kv_cache_mode="cached",
                )[0]
            else:
                # No reference images: standard forward
                noise_pred = self.transformer(
                    hidden_states=latents.to(self.transformer.dtype),
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                # MPS can silently upcast in the scheduler step; restore the original dtype.
                if torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                # The callback may replace tensors in-flight (e.g. for guidance tricks).
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
            if XLA_AVAILABLE:
                xm.mark_step()
    # Clean up KV cache
    if kv_cache is not None:
        kv_cache.clear()
    self._current_timestep = None
    # Scatter tokens back into a spatial grid, undo batch-norm whitening, then unpatchify.
    latents = self._unpack_latents_with_ids(latents, latent_ids)
    latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
    latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
        latents.device, latents.dtype
    )
    latents = latents * latents_bn_std + latents_bn_mean
    latents = self._unpatchify_latents(latents)
    if output_type == "latent":
        image = latents
    else:
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)
    # Offload all models
    self.maybe_free_model_hooks()
    if not return_dict:
        return (image,)
    return Flux2PipelineOutput(images=image)

View File

@@ -720,7 +720,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config) super().__init__(config)
self.model = LDMBertEncoder(config) self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward( def forward(
self, self,

View File

@@ -35,8 +35,6 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
# uncondition for scaling # uncondition for scaling
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
self.post_init()
def forward(self, pixel_values, return_uncond_vector=False): def forward(self, pixel_values, return_uncond_vector=False):
clip_output = self.model(pixel_values=pixel_values) clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output latent_states = clip_output.pooler_output

View File

@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import SanaLoraLoaderMixin from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import DPMSolverMultistepScheduler from ...schedulers import DPMSolverMultistepScheduler
from ...utils import ( from ...utils import (
BACKENDS_MAPPING, BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt. The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]): text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts. Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]): vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]): transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents. Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self, self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast, tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel, text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan, vae: AutoencoderDC | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel, transformer: SanaVideoTransformer3DModel,
scheduler: DPMSolverMultistepScheduler, scheduler: DPMSolverMultistepScheduler,
): ):
@@ -223,19 +223,8 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
) )
if getattr(self, "vae", None): self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
if isinstance(self.vae, AutoencoderKLLTX2Video): self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -996,21 +985,14 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0") if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError else torch_accelerator_module.OutOfMemoryError
) )
if isinstance(self.vae, AutoencoderKLLTX2Video): latents_mean = (
latents_mean = self.vae.latents_mean torch.tensor(self.vae.config.latents_mean)
latents_std = self.vae.latents_std .view(1, self.vae.config.z_dim, 1, 1, 1)
z_dim = self.vae.config.latent_channels .to(latents.device, latents.dtype)
elif isinstance(self.vae, AutoencoderKLWan): )
latents_mean = torch.tensor(self.vae.config.latents_mean) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_std = torch.tensor(self.vae.config.latents_std) latents.device, latents.dtype
z_dim = self.vae.config.z_dim )
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean latents = latents / latents_std + latents_mean
try: try:
video = self.vae.decode(latents, return_dict=False)[0] video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput from ...image_processor import PipelineImageInput
from ...loaders import SanaLoraLoaderMixin from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import ( from ...utils import (
BACKENDS_MAPPING, BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt. The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]): text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts. Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]): vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]): transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents. Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self, self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast, tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel, text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan, vae: AutoencoderDC | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel, transformer: SanaVideoTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler, scheduler: FlowMatchEulerDiscreteScheduler,
): ):
@@ -213,19 +213,8 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
) )
if getattr(self, "vae", None): self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
if isinstance(self.vae, AutoencoderKLLTX2Video): self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -698,18 +687,14 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1) image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
if isinstance(self.vae, AutoencoderKLLTX2Video): latents_mean = (
_latents_mean = self.vae.latents_mean torch.tensor(self.vae.config.latents_mean)
_latents_std = self.vae.latents_std .view(1, -1, 1, 1, 1)
elif isinstance(self.vae, AutoencoderKLWan): .to(image_latents.device, image_latents.dtype)
_latents_mean = torch.tensor(self.vae.config.latents_mean) )
_latents_std = torch.tensor(self.vae.config.latents_std) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
else: image_latents.device, image_latents.dtype
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype) )
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) * latents_std image_latents = (image_latents - latents_mean) * latents_std
latents[:, :, 0:1] = image_latents.to(dtype) latents[:, :, 0:1] = image_latents.to(dtype)
@@ -1049,21 +1034,14 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0") if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError else torch_accelerator_module.OutOfMemoryError
) )
if isinstance(self.vae, AutoencoderKLLTX2Video): latents_mean = (
latents_mean = self.vae.latents_mean torch.tensor(self.vae.config.latents_mean)
latents_std = self.vae.latents_std .view(1, self.vae.config.z_dim, 1, 1, 1)
z_dim = self.vae.config.latent_channels .to(latents.device, latents.dtype)
elif isinstance(self.vae, AutoencoderKLWan): )
latents_mean = torch.tensor(self.vae.config.latents_mean) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents_std = torch.tensor(self.vae.config.latents_std) latents.device, latents.dtype
z_dim = self.vae.config.z_dim )
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean latents = latents / latents_std + latents_mean
try: try:
video = self.vae.decode(latents, return_dict=False)[0] video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -1202,21 +1202,6 @@ class EasyAnimatePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"]) requires_backends(cls, ["torch", "transformers"])
class Flux2KleinKVPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinPipeline(metaclass=DummyObject): class Flux2KleinPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]

View File

@@ -26,17 +26,9 @@ from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import ( from ...testing_utils import (
is_context_parallel, is_context_parallel,
require_torch_multi_accelerator, require_torch_multi_accelerator,
torch_device,
) )
# Device configuration mapping
DEVICE_CONFIG = {
"cuda": {"backend": "nccl", "module": torch.cuda},
"xpu": {"backend": "xccl", "module": torch.xpu},
}
def _find_free_port(): def _find_free_port():
"""Find a free port on localhost.""" """Find a free port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -55,17 +47,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
os.environ["RANK"] = str(rank) os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
# Get device configuration
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]
# Initialize process group # Initialize process group
dist.init_process_group(backend=backend, rank=rank, world_size=world_size) dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
# Set device for this process # Set device for this process
device_module.set_device(rank) torch.cuda.set_device(rank)
device = torch.device(f"{torch_device}:{rank}") device = torch.device(f"cuda:{rank}")
# Create model # Create model
model = model_class(**init_dict) model = model_class(**init_dict)
@@ -116,16 +103,10 @@ def _custom_mesh_worker(
os.environ["RANK"] = str(rank) os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
# Get device configuration dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]
dist.init_process_group(backend=backend, rank=rank, world_size=world_size) torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
# Set device for this process
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")
model = model_class(**init_dict) model = model_class(**init_dict)
model.to(device) model.to(device)
@@ -135,7 +116,7 @@ def _custom_mesh_worker(
# DeviceMesh must be created after init_process_group, inside each worker process. # DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh( mesh = torch.distributed.device_mesh.init_device_mesh(
torch_device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
) )
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh) cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config) model.enable_parallelism(config=cp_config)

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. # Copyright 2025 HuggingFace Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,84 +13,49 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings import unittest
import torch import torch
from diffusers import QwenImageTransformer2DModel from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import ( from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism() enable_full_determinism()
class QwenImageTransformerTesterConfig(BaseModelTesterConfig): class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
@property model_class = QwenImageTransformer2DModel
def model_class(self): main_input_name = "hidden_states"
return QwenImageTransformer2DModel # We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property @property
def output_shape(self) -> tuple[int, int]: def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16) return (16, 16)
@property @property
def input_shape(self) -> tuple[int, int]: def output_shape(self):
return (16, 16) return (16, 16)
@property def prepare_dummy_input(self, height=4, width=4):
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1 batch_size = 1
num_latent_channels = embedding_dim = 16 num_latent_channels = embedding_dim = 16
height = width = 4 sequence_length = 7
sequence_length = 8
vae_scale_factor = 4 vae_scale_factor = 4
hidden_states = randn_tensor( hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor orig_height = height * 2 * vae_scale_factor
@@ -104,57 +70,89 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"img_shapes": img_shapes, "img_shapes": img_shapes,
} }
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self): def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict() """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
inputs = self.get_dummy_inputs() init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0 encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask inputs["encoder_hidden_states"], encoder_hidden_states_mask
) )
assert isinstance(rope_text_seq_len, int) # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
assert isinstance(per_sample_len, torch.Tensor) self.assertIsInstance(rope_text_seq_len, int)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad(): with torch.no_grad():
output = model(**inputs) output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0 encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2 inputs["encoder_hidden_states"], encoder_hidden_states_mask2
) )
assert int(per_sample_len2.max().item()) == 8 # Max valid position is 6 (last token), so per_sample_len should be 7
assert normalized_mask2.sum().item() == 5 self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None inputs["encoder_hidden_states"], None
) )
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1] self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
assert isinstance(rope_text_seq_len_none, int) self.assertIsInstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None self.assertIsNone(per_sample_len_none)
assert normalized_mask_none is None self.assertIsNone(normalized_mask_none)
def test_non_contiguous_attention_mask(self): def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict() """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
inputs = self.get_dummy_inputs() init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0 encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0 encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0 encoder_hidden_states_mask[:, 5:] = 0
@@ -162,85 +160,95 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask inputs["encoder_hidden_states"], encoder_hidden_states_mask
) )
assert int(per_sample_len.max().item()) == 5 self.assertEqual(int(per_sample_len.max().item()), 5)
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1] self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
assert isinstance(inferred_rope_len, int) self.assertIsInstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool self.assertTrue(normalized_mask.dtype == torch.bool)
inputs["encoder_hidden_states_mask"] = normalized_mask inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad(): with torch.no_grad():
output = model(**inputs) output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_txt_seq_lens_deprecation(self): def test_txt_seq_lens_deprecation(self):
init_dict = self.get_init_dict() """Test that passing txt_seq_lens raises a deprecation warning."""
inputs = self.get_dummy_inputs() init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy() inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask") inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
with warnings.catch_warnings(record=True) as w: # Test that deprecation warning is raised
warnings.simplefilter("always") with self.assertWarns(FutureWarning) as warning_context:
with torch.no_grad(): with torch.no_grad():
output = model(**inputs_with_deprecated) output = model(**inputs_with_deprecated)
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] # Verify the warning message mentions the deprecation
assert len(future_warnings) > 0, "Expected FutureWarning to be raised" warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
warning_message = str(future_warnings[0].message) # Verify the model still works correctly despite the deprecation
assert "txt_seq_lens" in warning_message self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert "deprecated" in warning_message
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_layered_model_with_mask(self): def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = { init_dict = {
"patch_size": 2, "patch_size": 2,
"in_channels": 16, "in_channels": 16,
"out_channels": 4, "out_channels": 4,
"num_layers": 2, "num_layers": 2,
"attention_head_dim": 16, "attention_head_dim": 16,
"num_attention_heads": 4, "num_attention_heads": 3,
"joint_attention_dim": 16, "joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4), "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, "use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, "use_additional_t_cond": True, # Enable additional time conditioning
} }
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope) self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1 batch_size = 1
text_seq_len = 8 text_seq_len = 7
img_h, img_w = 4, 4 img_h, img_w = 4, 4
layers = 4 layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0 encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
timestep = torch.tensor([1.0]).to(torch_device) timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [ img_shapes = [
[ [
(1, img_h, img_w), (1, img_h, img_w), # layer 0
(1, img_h, img_w), (1, img_h, img_w), # layer 1
(1, img_h, img_w), (1, img_h, img_w), # layer 2
(1, img_h, img_w), (1, img_h, img_w), # layer 3
(1, img_h, img_w), (1, img_h, img_w), # condition image (last one gets special treatment)
] ]
] ]
@@ -254,113 +262,37 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
additional_t_cond=addition_t_cond, additional_t_cond=addition_t_cond,
) )
assert output.sample.shape[1] == hidden_states.shape[1] self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin): class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
"""Memory optimization tests for QwenImage Transformer.""" model_class = QwenImageTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin): def prepare_dummy_input(self, height, width):
"""Training tests for QwenImage Transformer.""" return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
def test_gradient_checkpointing_is_applied(self): def test_torch_compile_recompilation_and_graph_break(self):
expected_set = {"QwenImageTransformer2DModel"} super().test_torch_compile_recompilation_and_graph_break()
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_with_and_without_mask(self): def test_torch_compile_with_and_without_mask(self):
init_dict = self.get_init_dict() """Test that torch.compile works with both None mask and padding mask."""
inputs = self.get_dummy_inputs() init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device) model = self.model_class(**init_dict).to(torch_device)
model.eval() model.eval()
model.compile(mode="default", fullgraph=True) model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy() inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad(): with torch.no_grad():
output_no_mask = model(**inputs_no_mask) output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with ( with (
torch._inductor.utils.fresh_inductor_cache(), torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True), torch._dynamo.config.patch(error_on_recompile=True),
@@ -368,15 +300,19 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
): ):
output_no_mask_2 = model(**inputs_no_mask) output_no_mask_2 = model(**inputs_no_mask)
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy() inputs_all_ones = inputs.copy()
assert inputs_all_ones["encoder_hidden_states_mask"].all().item() # Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
# First run to allow compilation
with torch.no_grad(): with torch.no_grad():
output_all_ones = model(**inputs_all_ones) output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with ( with (
torch._inductor.utils.fresh_inductor_cache(), torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True), torch._dynamo.config.patch(error_on_recompile=True),
@@ -384,18 +320,21 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
): ):
output_all_ones_2 = model(**inputs_all_ones) output_all_ones_2 = model(**inputs_all_ones)
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy() inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone() mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0 mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad(): with torch.no_grad():
output_with_padding = model(**inputs_with_padding) output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with ( with (
torch._inductor.utils.fresh_inductor_cache(), torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True), torch._dynamo.config.patch(error_on_recompile=True),
@@ -403,15 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
): ):
output_with_padding_2 = model(**inputs_with_padding) output_with_padding_2 = model(**inputs_with_padding)
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1] self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3) # Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""

View File

@@ -5,6 +5,7 @@ from typing import Callable
import pytest import pytest
import torch import torch
from huggingface_hub import hf_hub_download
import diffusers import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
@@ -32,6 +33,33 @@ from ..testing_utils import (
) )
def _get_specified_components(path_or_repo_id, cache_dir=None):
if os.path.isdir(path_or_repo_id):
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
else:
try:
config_path = hf_hub_download(
repo_id=path_or_repo_id,
filename="modular_model_index.json",
local_dir=cache_dir,
)
except Exception:
return None
with open(config_path) as f:
config = json.load(f)
components = set()
for k, v in config.items():
if isinstance(v, (str, int, float, bool)):
continue
for entry in v:
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
components.add(k)
break
return components
class ModularPipelineTesterMixin: class ModularPipelineTesterMixin:
""" """
It provides a set of common tests for each modular pipeline, It provides a set of common tests for each modular pipeline,
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_load_expected_components_from_pretrained(self, tmp_path):
pipe = self.get_pipeline()
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
if not expected:
pytest.skip("Skipping test as we couldn't fetch the expected components.")
actual = {
name
for name in pipe.components
if getattr(pipe, name, None) is not None
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
def test_load_expected_components_from_save_pretrained(self, tmp_path):
    """A saved-then-reloaded pipeline must load exactly the components its saved index specifies."""
    pipe = self.get_pipeline()
    save_dir = str(tmp_path / "saved-pipeline")
    pipe.save_pretrained(save_dir)
    # `save_pretrained` writes `modular_model_index.json`, so reading it locally cannot fail.
    expected = _get_specified_components(save_dir)

    loaded_pipe = ModularPipeline.from_pretrained(save_dir)
    loaded_pipe.load_components(torch_dtype=torch.float32)

    # A component counts as loaded when it is set on the pipeline and carries a real load id.
    actual = set()
    for name in loaded_pipe.components:
        component = getattr(loaded_pipe, name, None)
        if component is None:
            continue
        if getattr(component, "_diffusers_load_id", None) in (None, "null"):
            continue
        actual.add(name)

    assert expected == actual, (
        f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
    )
def test_modular_index_consistency(self, tmp_path): def test_modular_index_consistency(self, tmp_path):
pipe = self.get_pipeline() pipe = self.get_pipeline()
components_spec = pipe._component_specs components_spec = pipe._component_specs

View File

@@ -1,174 +0,0 @@
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
from diffusers import (
AutoencoderKLFlux2,
FlowMatchEulerDiscreteScheduler,
Flux2KleinKVPipeline,
Flux2Transformer2DModel,
)
from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for ``Flux2KleinKVPipeline`` built from tiny randomly-initialized components."""

    pipeline_class = Flux2KleinKVPipeline
    # Call arguments exercised by the shared PipelineTesterMixin tests.
    params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
    batch_params = frozenset(["prompt"])
    # Feature flags consumed by PipelineTesterMixin to enable or skip shared tests.
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True
    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        """Return a dict of tiny pipeline components keyed by their constructor argument names.

        NOTE: ``torch.manual_seed(0)`` is re-applied immediately before each randomly
        initialized module so every module gets deterministic weights — do not reorder
        the seed calls relative to the constructors.
        """
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained(
            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
        )

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return minimal deterministic call kwargs for the pipeline on *device*."""
        if str(device).startswith("mps"):
            # On mps, seed the global RNG instead of constructing a device generator.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)
        inputs = {
            "prompt": "a dog is dancing",
            "image": Image.new("RGB", (64, 64)),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }
        return inputs

    def test_fused_qkv_projections(self):
        """Fusing and later unfusing QKV projections must leave pipeline outputs unchanged."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        # Baseline output with unfused projections.
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        )
        # Looser tolerance for the round-trip comparison (fuse then unfuse).
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        )

    def test_image_output_shape(self):
        """Requested sizes are snapped down to the nearest multiple of (vae_scale_factor * 2)."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

    def test_without_image(self):
        """The pipeline must also run when the optional `image` input is omitted."""
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)
        del inputs["image"]
        image = pipe(**inputs).images
        self.assertEqual(image.shape, (1, 8, 8, 3))

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass

View File

@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case") # Override to set a more lenient max diff threshold.
def test_save_load_float16(self): def test_save_load_float16(self):
pass super().test_save_load_float16(expected_max_diff=0.03)
@unittest.skip("Test not supported") @unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self): def test_attention_slicing_forward_pass(self):

View File

@@ -139,9 +139,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
num_hidden_layers=2, num_hidden_layers=2,
image_size=224, image_size=224,
) )
llava_text_encoder_config = LlavaConfig( llava_text_encoder_config = LlavaConfig(vision_config, text_config, pad_token_id=100, image_token_index=101)
vision_config=vision_config, text_config=text_config, pad_token_id=100, image_token_index=101
)
clip_text_encoder_config = CLIPTextConfig( clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0, bos_token_id=0,