mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-09 20:35:18 +08:00)

Compare commits: cache-docs...ltx2-infer (22 commits)

f4d47b9cec, ce5a51430b, 9575e0632a, 9d68742214, f1a93c765f, 29a930a142, dad5cb55e6, faeccc557a, 96fbcd8301, 837fd85c76, d988fc34f1, 82c2e7f068, 6fbeacf53b, 9c754a46aa, b86bd99eac, 5b202111bf, 4ac2b4a521, 418313bbf6, 2120c3096f, ed6e5ecf67, d44b5f86e6, 3d78f9d17d
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate

[[autodoc]] apply_faster_cache

-### FirstBlockCacheConfig
+## FirstBlockCacheConfig

[[autodoc]] FirstBlockCacheConfig
@@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
+- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.

> [!TIP]

@@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin

+## LTX2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
+
## CogVideoXLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
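These mixins are exposed directly on their pipelines, so LoRA loading is a one-liner in practice. A minimal sketch using the Flux loader (the LoRA repo id and weight file name below are placeholders, not from this diff):

```py
import torch
from diffusers import FluxPipeline

# FluxPipeline inherits FluxLoraLoaderMixin, so the LoRA helpers are
# available directly on the pipeline object.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some-user/some-flux-lora", weight_name="pytorch_lora_weights.safetensors")

# Utilities from LoraBaseMixin: merge the LoRA into the base weights and undo it.
pipe.fuse_lora(lora_scale=1.0)
pipe.unfuse_lora()
```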
@@ -21,7 +21,7 @@ The abstract from the paper is:

*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*

-The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
+The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).

This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️
@@ -14,6 +14,10 @@

# LTX-2

<div class="flex flex-wrap space-x-1">
  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.

You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
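The diff does not include a usage snippet at this point in the page; a minimal sketch of what text-to-video inference might look like, assuming the `LTX2Pipeline` entry point added elsewhere in this PR (the checkpoint id, call signature, and output attribute below are placeholder assumptions):

```py
import torch
from diffusers import LTX2Pipeline  # added by this PR

# "Lightricks/LTX-2" is a placeholder repo id; check the Lightricks org for real checkpoints.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")
output = pipe(prompt="A drone shot of a coastline at sunset", num_inference_steps=40)
video = output.frames[0]  # an audio-video model may also return a synchronized audio track
```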
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
    image=[image_1, image_2],
    prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
    num_inference_steps=50
).images[0]
```

## Performance

### torch.compile

Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)

# First call triggers compilation (~7s overhead)
# Subsequent calls run ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```

### Batched Inference with Variable-Length Prompts

When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.

```python
# CFG with different prompt lengths works correctly
image = pipe(
    prompt="A cat",
    negative_prompt="blurry, low quality, distorted",
    true_cfg_scale=3.5,
    num_inference_steps=50,
).images[0]
```

For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).

## QwenImagePipeline

[[autodoc]] QwenImagePipeline
@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
                type_hint=str,
                required=True,
                default="mask_image",
-               description="""Output type from annotation predictions. Availabe options are
+               description="""Output type from annotation predictions. Available options are
                mask_image:
                    -black and white mask image for the given image based on the task type
                mask_overlay:

@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
                type_hint=str,
                required=True,
                default="mask_image",
-               description="""Output type from annotation predictions. Availabe options are
+               description="""Output type from annotation predictions. Available options are
                mask_image:
                    -black and white mask image for the given image based on the task type
                mask_overlay:
@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to

A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.

-- It recieves the iteration variable from the loop wrapper.
+- It receives the iteration variable from the loop wrapper.
- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].
@@ -68,6 +68,20 @@ config = FasterCacheConfig(
pipeline.transformer.enable_cache(config)
```

## FirstBlockCache

[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser change from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
)
apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
```

## TaylorSeer Cache

[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.

@@ -87,8 +101,7 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
+).to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,

@@ -97,4 +110,4 @@ config = TaylorSeerCacheConfig(
    taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```
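Since `CacheMixin.enable_cache` also accepts a [`FirstBlockCacheConfig`] (see the `CacheMixin` docstring update later in this diff), the hook-based call above can equivalently go through the model API; a short sketch:

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import FirstBlockCacheConfig

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
)
# Same effect as apply_first_block_cache(pipeline.transformer, ...) above.
pipeline.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```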
@@ -333,3 +333,31 @@ pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```

### Unified Attention

[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.

This hybrid approach leverages the strengths of both methods:
- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both heads and sequence dimensions

[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where the Ulysses and Ring groups are orthogonal (non-overlapping). Pass a [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set greater than 1 to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```

> [!TIP]
> Use Unified Attention when there are enough devices to arrange in a 2D grid (at least 4 devices).

We ran a benchmark with Ulysses, Ring, and Unified Attention using [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:

| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|------------------|------------------|-------------|------------------|
| ulysses          | 6670.789         | 7.50        | 33.85            |
| ring             | 13076.492        | 3.82        | 56.02            |
| unified_balanced | 11068.705        | 4.52        | 33.85            |

From the table above, it's clear that Ulysses provides the best throughput, but the number of devices it can use is limited to the number of attention heads, a limitation that Unified Attention removes.
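Context parallelism needs a distributed launch; a minimal sketch of a 4-GPU Unified Attention run (started with `torchrun --nproc_per_node=4`; the checkpoint id and the top-level `ContextParallelConfig` import are assumptions for illustration):

```py
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, DiffusionPipeline

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
# 2 x 2 grid over 4 GPUs: Ulysses across attention heads, Ring across the sequence.
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

image = pipeline("a cat", num_inference_steps=28).images[0]
if rank == 0:
    image.save("output.png")
dist.destroy_process_group()
```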
@@ -149,13 +149,13 @@ def get_args():
        "--validation_prompt",
        type=str,
        default=None,
-       help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+       help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_images",
        type=str,
        default=None,
-       help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
+       help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
@@ -140,7 +140,7 @@ def get_args():
        "--validation_prompt",
        type=str,
        default=None,
-       help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
+       help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
    )
    parser.add_argument(
        "--validation_prompt_separator",
@@ -1228,7 +1228,7 @@ def main(args):
        else {"device": accelerator.device, "dtype": weight_dtype}
    )

-   is_fsdp = accelerator.state.fsdp_plugin is not None
+   is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
    if not is_fsdp:
        transformer.to(**transformer_to_kwargs)
@@ -1178,7 +1178,7 @@ def main(args):
        else {"device": accelerator.device, "dtype": weight_dtype}
    )

-   is_fsdp = accelerator.state.fsdp_plugin is not None
+   is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
    if not is_fsdp:
        transformer.to(**transformer_to_kwargs)
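The `getattr` form matters because an `accelerator.state` object may not define `fsdp_plugin` at all outside FSDP runs, in which case plain attribute access raises instead of evaluating to `None`. A standalone sketch of the difference, using a dummy state object:

```py
class DummyState:
    pass  # no `fsdp_plugin` attribute, as in a non-FSDP run

state = DummyState()

# `state.fsdp_plugin is not None` would raise AttributeError here.
is_fsdp = getattr(state, "fsdp_plugin", None) is not None
assert is_fsdp is False
```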
@@ -1695,9 +1695,13 @@ def main(args):
        cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std

        model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
-       cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
+       cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
+       cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
            device=cond_model_input.device
        )
+       cond_model_input_ids = cond_model_input_ids.view(
+           cond_model_input.shape[0], -1, model_input_ids.shape[-1]
+       )

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(model_input)

@@ -1724,6 +1728,9 @@ def main(args):
        packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
        packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)

+       orig_input_shape = packed_noisy_model_input.shape
+       orig_input_ids_shape = model_input_ids.shape
+
        # concatenate the model inputs with the cond inputs
        packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
        model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)

@@ -1742,7 +1749,8 @@ def main(args):
            img_ids=model_input_ids,  # B, image_seq_len, 4
            return_dict=False,
        )[0]
-       model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
+       model_pred = model_pred[:, : orig_input_shape[1], :]
+       model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]

        model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
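The slicing fix works because the old expression indexed against `packed_noisy_model_input` *after* it had been reassigned to the concatenated tensor, so `[:, : packed_noisy_model_input.size(1)]` kept the full sequence, conditioning tokens included. Recording the lengths before concatenation and slicing back against them is the general pattern; a toy illustration:

```py
import torch

b, s_main, s_cond, d = 2, 8, 4, 16
main = torch.randn(b, s_main, d)
cond = torch.randn(b, s_cond, d)

orig_len = main.shape[1]            # record BEFORE concatenating
x = torch.cat([main, cond], dim=1)  # (b, s_main + s_cond, d)

pred = x                            # stand-in for the transformer output
pred = pred[:, :orig_len, :]        # keep only the main-sequence predictions
assert pred.shape == (b, s_main, d)

# The buggy version: pred[:, : x.size(1)] is a no-op once x is the concatenated tensor.
```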
@@ -1513,14 +1513,12 @@ def main(args):
            height=model_input.shape[3],
            width=model_input.shape[4],
        )
-       print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
        model_pred = transformer(
            hidden_states=packed_noisy_model_input,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            timestep=timesteps / 1000,
            img_shapes=img_shapes,
-           txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
            return_dict=False,
        )[0]
        model_pred = QwenImagePipeline._unpack_latents(
@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode

___Note___:

-___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
+___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___


## Running locally with PyTorch
@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
NUM_DEVICES = jax.device_count()

# 1. Let's start by downloading the model and loading it into our pipeline class
-# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
+# Adhering to JAX's functional approach, the model's parameters are returned separately and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
@@ -63,6 +63,8 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
+   "last_time_embedder": "time_embedder",
+   "last_scale_shift_table": "scale_shift_table",
    # Common
    # For all 3D ResNets
    "res_blocks": "resnets",

@@ -372,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
    return connectors


-def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
+def get_ltx2_video_vae_config(
+    version: str, timestep_conditioning: bool = False
+) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    if version == "test":
        config = {
            "model_id": "diffusers-internal-dev/dummy-ltx2",

@@ -396,7 +400,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
-           "timestep_conditioning": False,
+           "timestep_conditioning": timestep_conditioning,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,

@@ -433,7 +437,7 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
-           "timestep_conditioning": False,
+           "timestep_conditioning": timestep_conditioning,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,

@@ -450,8 +454,10 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
    return config, rename_dict, special_keys_remap


-def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
-    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
+def convert_ltx2_video_vae(
+    original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
+) -> Dict[str, Any]:
+    config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
    diffusers_config = config["diffusers_config"]

    with init_empty_weights():

@@ -717,6 +723,9 @@ def get_args():
        help="Latent upsampler filename",
    )

+   parser.add_argument(
+       "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
+   )
    parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
    parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
    parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")

@@ -786,7 +795,9 @@ def main(args):
        original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
    elif combined_ckpt is not None:
        original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
-   vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
+   vae = convert_ltx2_video_vae(
+       original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
+   )
    if not args.full_pipeline and not args.upsample_pipeline:
        vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
@@ -67,6 +67,7 @@ if is_torch_available():
        "SD3LoraLoaderMixin",
        "AuraFlowLoraLoaderMixin",
        "StableDiffusionXLLoraLoaderMixin",
+       "LTX2LoraLoaderMixin",
        "LTXVideoLoraLoaderMixin",
        "LoraLoaderMixin",
        "FluxLoraLoaderMixin",

@@ -121,6 +122,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        HunyuanVideoLoraLoaderMixin,
        KandinskyLoraLoaderMixin,
        LoraLoaderMixin,
+       LTX2LoraLoaderMixin,
        LTXVideoLoraLoaderMixin,
        Lumina2LoraLoaderMixin,
        Mochi1LoraLoaderMixin,
@@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
    return converted_state_dict


def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
    # Remove the prefix
    state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
    converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}

    if non_diffusers_prefix == "diffusion_model":
        rename_dict = {
            "patchify_proj": "proj_in",
            "audio_patchify_proj": "audio_proj_in",
            "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
            "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
            "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
            "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
            "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
            "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
            "q_norm": "norm_q",
            "k_norm": "norm_k",
        }
    else:
        rename_dict = {"aggregate_embed": "text_proj_in"}

    # Apply renaming
    renamed_state_dict = {}
    for key, value in converted_state_dict.items():
        new_key = key[:]
        for old_pattern, new_pattern in rename_dict.items():
            new_key = new_key.replace(old_pattern, new_pattern)
        renamed_state_dict[new_key] = value

    # Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
    final_state_dict = {}
    for key, value in renamed_state_dict.items():
        if key.startswith("adaln_single."):
            new_key = key.replace("adaln_single.", "time_embed.")
            final_state_dict[new_key] = value
        elif key.startswith("audio_adaln_single."):
            new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
            final_state_dict[new_key] = value
        else:
            final_state_dict[key] = value

    # Add transformer prefix
    prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
    final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}

    return final_state_dict


def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
    has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
    if has_diffusion_model:
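A quick illustration of what the converter does to key names (toy tensors; the LoRA key suffixes are made up for the example):

```py
import torch

sd = {
    "diffusion_model.patchify_proj.lora_A.weight": torch.zeros(4, 4),
    "diffusion_model.adaln_single.linear.lora_B.weight": torch.zeros(4, 4),
}
out = _convert_non_diffusers_ltx2_lora_to_diffusers(sd)
print(sorted(out))
# ['transformer.proj_in.lora_A.weight',
#  'transformer.time_embed.linear.lora_B.weight']
```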
@@ -48,6 +48,7 @@ from .lora_conversion_utils import (
    _convert_non_diffusers_flux2_lora_to_diffusers,
    _convert_non_diffusers_hidream_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
+   _convert_non_diffusers_ltx2_lora_to_diffusers,
    _convert_non_diffusers_ltxv_lora_to_diffusers,
    _convert_non_diffusers_lumina2_lora_to_diffusers,
    _convert_non_diffusers_qwen_lora_to_diffusers,

@@ -74,6 +75,7 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
+LTX2_CONNECTOR_NAME = "connectors"

_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
@@ -212,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_unet(
            state_dict,

@@ -639,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_unet(
            state_dict,

@@ -1079,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -1375,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -1657,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        )

        if not (has_lora_keys or has_norm_keys):
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        transformer_lora_state_dict = {
            k: state_dict.get(k)

@@ -2504,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -2701,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -2904,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,
@@ -3011,6 +3013,233 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
        super().unfuse_lora(components=components, **kwargs)


class LTX2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
    """

    _lora_loadable_modules = ["transformer", "connectors"]
    transformer_name = TRANSFORMER_NAME
    connectors_name = LTX2_CONNECTOR_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        final_state_dict = state_dict
        is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
        has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
        if is_non_diffusers_format:
            final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
        if has_connector:
            connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
                state_dict, "text_embedding_projection"
            )
            final_state_dict.update(connectors_state_dict)
        out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
        return out

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        transformer_peft_state_dict = {
            k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
        }
        connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
        self.load_lora_into_transformer(
            transformer_peft_state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
        if connectors_peft_state_dict:
            self.load_lora_into_transformer(
                connectors_peft_state_dict,
                transformer=getattr(self, self.connectors_name)
                if not hasattr(self, "connectors")
                else self.connectors,
                adapter_name=adapter_name,
                metadata=metadata,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
                hotswap=hotswap,
                prefix=self.connectors_name,
            )

    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
        prefix: str = "transformer",
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {prefix}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
            prefix=prefix,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)


class SanaLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
@@ -3104,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -3307,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -3511,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -3711,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -3965,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        )
        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
        if load_into_transformer_2:

@@ -4242,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
        )
        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
        if load_into_transformer_2:

@@ -4462,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -4665,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -4871,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -5077,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,

@@ -5280,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
-           raise ValueError("Invalid LoRA checkpoint.")
+           raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,
@@ -67,6 +67,8 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
    "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
    "Flux2Transformer2DModel": lambda model_cls, weights: weights,
    "ZImageTransformer2DModel": lambda model_cls, weights: weights,
+   "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
+   "LTX2TextConnectors": lambda model_cls, weights: weights,
}
@@ -90,10 +90,6 @@ class ContextParallelConfig:
            )
        if self.ring_degree < 1 or self.ulysses_degree < 1:
            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-       if self.ring_degree > 1 and self.ulysses_degree > 1:
-           raise ValueError(
-               "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-           )
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
@@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
    return x


def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    Perform dimension sharding / reassembly across processes using _all_to_all_single.

    This utility reshapes and redistributes tensor `x` across the given process group, across the sequence dimension
    or the head dimension flexibly by accepting scatter_idx and gather_idx.

    Args:
        x (torch.Tensor):
            Input tensor. Expected shapes:
            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
        scatter_idx (int):
            Dimension along which the tensor is partitioned before all-to-all.
        gather_idx (int):
            Dimension along which the output is reassembled after all-to-all.
        group:
            Distributed process group for the Ulysses group.

    Returns:
        torch.Tensor: Tensor with globally exchanged dimensions.
            - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
            - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
    """
    group_world_size = torch.distributed.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
        # dimension and scatters the head dimension.
        batch_size, seq_len_local, num_heads, head_dim = x.shape
        seq_len = seq_len_local * group_world_size
        num_heads_local = num_heads // group_world_size

        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
        x_temp = (
            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
            .transpose(0, 2)
            .contiguous()
        )

        if group_world_size > 1:
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp
        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
        return out
    elif scatter_idx == 1 and gather_idx == 2:
        # Used after Ulysses sequence parallel attention in unified SP. Gathers the head
        # dimension and scatters back the sequence dimension.
        batch_size, seq_len, num_heads_local, head_dim = x.shape
        num_heads = num_heads_local * group_world_size
        seq_len_local = seq_len // group_world_size

        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
        x_temp = (
            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
            .permute(1, 3, 2, 0, 4)
            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
        )

        if group_world_size > 1:
            output = _all_to_all_single(x_temp, group)
        else:
            output = x_temp
        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
        return output
    else:
        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")


class SeqAllToAllDim(torch.autograd.Function):
    """
    All-to-all operation for unified sequence parallelism. Uses _all_to_all_dim_exchange; see
    _all_to_all_dim_exchange for more info.
    """

    @staticmethod
    def forward(ctx, group, input, scatter_id=2, gather_id=1):
        ctx.group = group
        ctx.scatter_id = scatter_id
        ctx.gather_id = gather_id
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_input = SeqAllToAllDim.apply(
            ctx.group,
            grad_outputs,
            ctx.gather_id,  # reversed
            ctx.scatter_id,  # reversed
        )
        return (None, grad_input, None, None)


class TemplatedRingAttention(torch.autograd.Function):
    @staticmethod
    def forward(

@@ -1237,7 +1334,10 @@ class TemplatedRingAttention(torch.autograd.Function):
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)

-       lse = lse.unsqueeze(-1)
+       # Refer to:
+       # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+       if is_torch_version("<", "2.9.0"):
+           lse = lse.unsqueeze(-1)
        if prev_out is not None:
            out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
            lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)

@@ -1298,7 +1398,7 @@ class TemplatedRingAttention(torch.autograd.Function):

        grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

-       return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+       return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None


class TemplatedUlyssesAttention(torch.autograd.Function):

@@ -1393,7 +1493,69 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
            x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
        )

-       return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+       return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None


def _templated_unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    dropout_p: float,
    is_causal: bool,
    scale: Optional[float],
    enable_gqa: bool,
    return_lse: bool,
    forward_op,
    backward_op,
    _parallel_config: Optional["ParallelConfig"] = None,
    scatter_idx: int = 2,
    gather_idx: int = 1,
):
    """
    Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
    """
    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
    ulysses_group = ulysses_mesh.get_group()

    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
    out = TemplatedRingAttention.apply(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        forward_op,
        backward_op,
        _parallel_config,
    )
    if return_lse:
        context_layer, lse, *_ = out
    else:
        context_layer = out
    # context_layer is of shape (B, S, H_LOCAL, D)
    output = SeqAllToAllDim.apply(
        ulysses_group,
        context_layer,
        gather_idx,
        scatter_idx,
    )
    if return_lse:
        # lse is of shape (B, S, H_LOCAL, 1)
        # Refer to:
        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
        if is_torch_version("<", "2.9.0"):
            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
        lse = lse.squeeze(-1)
        return (output, lse)
    return output


def _templated_context_parallel_attention(
@@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention(
        raise ValueError("GQA is not yet supported for templated attention.")

-   # TODO: add support for unified attention with ring/ulysses degree both being > 1
-   if _parallel_config.context_parallel_config.ring_degree > 1:
+   if (
+       _parallel_config.context_parallel_config.ring_degree > 1
+       and _parallel_config.context_parallel_config.ulysses_degree > 1
+   ):
+       return _templated_unified_attention(
+           query,
+           key,
+           value,
+           attn_mask,
+           dropout_p,
+           is_causal,
+           scale,
+           enable_gqa,
+           return_lse,
+           forward_op,
+           backward_op,
+           _parallel_config,
+       )
+   elif _parallel_config.context_parallel_config.ring_degree > 1:
        return TemplatedRingAttention.apply(
            query,
            key,
@@ -1945,6 +2125,43 @@ def _native_flex_attention(
    return out


def _prepare_additive_attn_mask(
    attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
) -> torch.Tensor:
    """
    Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.

    This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.

    Args:
        attn_mask: 2D tensor [batch_size, seq_len_k]
            - Boolean: True means attend, False means mask out
            - Additive: 0.0 means attend, -inf means mask out
        target_dtype: The dtype to convert the mask to (usually query.dtype)
        reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting

    Returns:
        Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
        reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
    """
    # Check if the mask is boolean or already additive
    if attn_mask.dtype == torch.bool:
        # Convert boolean to additive: True -> 0.0, False -> -inf
        attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
        # Convert to target dtype
        attn_mask = attn_mask.to(dtype=target_dtype)
    else:
        # Already additive mask - just ensure correct dtype
        attn_mask = attn_mask.to(dtype=target_dtype)

    # Optionally reshape to 4D for broadcasting in attention mechanisms
    if reshape_4d:
        batch_size, seq_len_k = attn_mask.shape
        attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)

    return attn_mask


@_AttentionBackendRegistry.register(
    AttentionBackendName.NATIVE,
    constraints=[_check_device, _check_shape],
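A short sketch of the helper on a boolean padding mask (two sequences, the first padded from length 3 up to 5):

```py
import torch

bool_mask = torch.tensor(
    [[True, True, True, False, False],
     [True, True, True, True, True]]
)
additive = _prepare_additive_attn_mask(bool_mask, target_dtype=torch.float16)
print(additive.shape)     # torch.Size([2, 1, 1, 5])
print(additive[0, 0, 0])  # tensor([0., 0., 0., -inf, -inf], dtype=torch.float16)
```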
@@ -1964,6 +2181,19 @@ def _native_attention(
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Native attention backend does not support setting `return_lse=True`.")

+   # Reshape 2D mask to 4D for SDPA
+   # SDPA accepts both boolean masks (torch.bool) and additive masks (float)
+   if (
+       attn_mask is not None
+       and attn_mask.ndim == 2
+       and attn_mask.shape[0] == query.shape[0]
+       and attn_mask.shape[1] == key.shape[1]
+   ):
+       # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
+       # SDPA handles both boolean and additive masks correctly
+       attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+
    if _parallel_config is None:
        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
@@ -2530,10 +2760,34 @@ def _xformers_attention(
         attn_mask = xops.LowerTriangularMask()
     elif attn_mask is not None:
         if attn_mask.ndim == 2:
-            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+            # Convert 2D mask to 4D for xformers
+            # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
+            # xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
+            # Need memory alignment - create larger tensor and slice for alignment
+            original_seq_len = attn_mask.size(1)
+            aligned_seq_len = ((original_seq_len + 7) // 8) * 8  # Round up to multiple of 8
+
+            # Create aligned 4D tensor and slice to ensure proper memory layout
+            aligned_mask = torch.zeros(
+                (batch_size, num_heads_q, seq_len_q, aligned_seq_len),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            # Convert to 4D additive mask (handles both boolean and additive inputs)
+            mask_additive = _prepare_additive_attn_mask(
+                attn_mask, target_dtype=query.dtype
+            )  # [batch, 1, 1, seq_len_k]
+            # Broadcast to [batch, heads, seq_q, seq_len_k]
+            aligned_mask[:, :, :, :original_seq_len] = mask_additive
+            # Mask out the padding (already -inf from zeros -> where with default)
+            aligned_mask[:, :, :, original_seq_len:] = float("-inf")
+
+            # Slice to actual size with proper alignment
+            attn_mask = aligned_mask[:, :, :, :seq_len_kv]
         elif attn_mask.ndim != 4:
             raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
-        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+        elif attn_mask.ndim == 4:
+            attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
 
     if enable_gqa:
         if num_heads_q % num_heads_kv != 0:
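The alignment trick rounds the key length up to the next multiple of 8 before slicing back, so the mask's last dimension stays memory-aligned for xformers' kernels. A sketch of just the padding arithmetic:

    import torch

    original_seq_len = 77
    aligned_seq_len = ((original_seq_len + 7) // 8) * 8  # -> 80

    # Allocate the aligned buffer, fill the tail with -inf so padded keys are ignored,
    # then slice back to the real length; the slice keeps the aligned layout.
    buf = torch.zeros(1, 1, 1, aligned_seq_len)
    buf[..., original_seq_len:] = float("-inf")
    mask = buf[..., :original_seq_len]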
@@ -102,14 +102,14 @@ def get_block(
     attention_head_dim: int,
     norm_type: str,
     act_fn: str,
-    qkv_mutliscales: Tuple[int, ...] = (),
+    qkv_multiscales: Tuple[int, ...] = (),
 ):
     if block_type == "ResBlock":
         block = ResBlock(in_channels, out_channels, norm_type, act_fn)
 
     elif block_type == "EfficientViTBlock":
         block = EfficientViTBlock(
-            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
+            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
         )
 
     else:
@@ -247,7 +247,7 @@ class Encoder(nn.Module):
                 attention_head_dim=attention_head_dim,
                 norm_type="rms_norm",
                 act_fn="silu",
-                qkv_mutliscales=qkv_multiscales[i],
+                qkv_multiscales=qkv_multiscales[i],
             )
             down_block_list.append(block)
 
@@ -339,7 +339,7 @@ class Decoder(nn.Module):
                 attention_head_dim=attention_head_dim,
                 norm_type=norm_type[i],
                 act_fn=act_fn[i],
-                qkv_mutliscales=qkv_multiscales[i],
+                qkv_multiscales=qkv_multiscales[i],
             )
             up_block_list.append(block)
@@ -41,9 +41,11 @@ class CacheMixin:
         Enable caching techniques on the model.
 
         Args:
-            config (`Union[PyramidAttentionBroadcastConfig]`):
+            config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
                 The configuration for applying the caching technique. Currently supported caching techniques are:
                 - [`~hooks.PyramidAttentionBroadcastConfig`]
+                - [`~hooks.FasterCacheConfig`]
+                - [`~hooks.FirstBlockCacheConfig`]
 
         Example:
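With the expanded config union, enabling the new cache type looks something like the following minimal sketch (the model id and threshold value are assumptions for illustration):

    import torch
    from diffusers import FluxTransformer2DModel
    from diffusers.hooks import FirstBlockCacheConfig

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # enable_cache comes from CacheMixin; the config object selects the caching technique.
    transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))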
@@ -20,7 +20,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
 from ..cache_utils import CacheMixin
 from ..controlnets.controlnet import zero_module
@@ -31,6 +31,7 @@ from ..transformers.transformer_qwenimage import (
     QwenImageTransformerBlock,
     QwenTimestepProjEmbeddings,
     RMSNorm,
+    compute_text_seq_len_from_mask,
 )
@@ -136,7 +137,7 @@ class QwenImageControlNetModel(
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
-        The [`FluxTransformer2DModel`] forward method.
+        The [`QwenImageControlNetModel`] forward method.
 
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ class QwenImageControlNetModel(
                 The scale factor for ControlNet outputs.
             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                 A list of tensors that if specified are added to the residuals of transformer blocks.
+            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+                Image shapes for RoPE computation.
+            txt_seq_lens (`List[int]`, *optional*):
+                **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
+                length.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
 
         Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+            the first element is the controlnet block samples.
         """
+        # Handle deprecated txt_seq_lens parameter
+        if txt_seq_lens is not None:
+            deprecate(
+                "txt_seq_lens",
+                "0.39.0",
+                "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
+                "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
+                "and `encoder_hidden_states_mask`.",
+                standard_warn=False,
+            )
 
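In caller terms, the migration this deprecation points at is roughly the following (argument names are taken from the diff; the surrounding variables such as `latents` and `prompt_embeds_mask` are assumed for illustration):

    # Before (deprecated): lengths computed host-side and passed explicitly.
    # controlnet(..., txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), ...)

    # After: pass only the mask; the model infers the text sequence length itself.
    controlnet_block_samples = controlnet(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_hidden_states_mask=prompt_embeds_mask,  # bool or 0/1 float, [batch, text_len]
        timestep=timestep,
        img_shapes=img_shapes,
        return_dict=False,
    )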
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()
             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -186,32 +202,47 @@ class QwenImageControlNetModel(
 
         temb = self.time_text_embed(timestep, hidden_states)
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )
+
+        image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
 
         timestep = timestep.to(hidden_states.dtype)
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
+        # Construct joint attention mask once to avoid reconstructing in every block
+        block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
+        if encoder_hidden_states_mask is not None:
+            # Build joint mask: [text_mask, all_ones_for_image]
+            batch_size, image_seq_len = hidden_states.shape[:2]
+            image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
+            joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+            block_attention_kwargs["attention_mask"] = joint_attention_mask
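The joint mask is just the text mask concatenated with an all-ones image mask, since image tokens are never padding; a toy version of the construction:

    import torch

    batch, text_len, image_len = 2, 5, 16
    text_mask = torch.tensor([[1, 1, 1, 0, 0],
                              [1, 1, 1, 1, 1]], dtype=torch.bool)
    image_mask = torch.ones(batch, image_len, dtype=torch.bool)

    # Joint sequence = [text tokens, image tokens].
    joint_mask = torch.cat([text_mask, image_mask], dim=1)  # shape [2, 21]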
 
         block_samples = ()
-        for index_block, block in enumerate(self.transformer_blocks):
+        for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     encoder_hidden_states,
-                    encoder_hidden_states_mask,
+                    None,  # Don't pass encoder_hidden_states_mask (using attention_mask instead)
                     temb,
                     image_rotary_emb,
+                    block_attention_kwargs,
                 )
 
             else:
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    encoder_hidden_states_mask=None,  # Don't pass (using attention_mask instead)
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
+                    joint_attention_kwargs=block_attention_kwargs,
                 )
             block_samples = block_samples + (hidden_states,)
 
@@ -267,6 +298,15 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[QwenImageControlNetOutput, Tuple]:
+        if txt_seq_lens is not None:
+            deprecate(
+                "txt_seq_lens",
+                "0.39.0",
+                "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
+                "removed in version 0.39.0. The text sequence length is now automatically inferred from "
+                "`encoder_hidden_states` and `encoder_hidden_states_mask`.",
+                standard_warn=False,
+            )
         # ControlNet-Union with multiple conditions
         # only load one ControlNet for saving memories
         if len(self.nets) == 1:
@@ -281,7 +321,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     timestep=timestep,
                     img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                     joint_attention_kwargs=joint_attention_kwargs,
                     return_dict=return_dict,
                 )
 
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, FeedForward
@@ -142,6 +142,32 @@ def apply_rotary_emb_qwen(
     return x_out.type_as(x)
 
 
+def compute_text_seq_len_from_mask(
+    encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
+) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """
+    Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask.
+    """
+    batch_size, text_seq_len = encoder_hidden_states.shape[:2]
+    if encoder_hidden_states_mask is None:
+        return text_seq_len, None, None
+
+    if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
+        raise ValueError(
+            f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
+            f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
+        )
+
+    if encoder_hidden_states_mask.dtype != torch.bool:
+        encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
+
+    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
+    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
+    has_active = encoder_hidden_states_mask.any(dim=1)
+    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    return text_seq_len, per_sample_len, encoder_hidden_states_mask
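The per-sample length is "last active position + 1", which also handles holes in the mask; a toy run of the same computation:

    import torch

    # Valid tokens at positions 0, 2, 3 (non-contiguous): length = 3 + 1 = 4.
    mask = torch.tensor([[True, False, True, True, False]])
    position_ids = torch.arange(mask.shape[1])
    active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
    per_sample_len = active_positions.max(dim=1).values + 1  # tensor([4])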
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
@@ -207,21 +233,50 @@ class QwenEmbedRope(nn.Module):
     def forward(
         self,
         video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
-        txt_seq_lens: List[int],
-        device: torch.device,
+        txt_seq_lens: Optional[List[int]] = None,
+        device: torch.device = None,
+        max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
                 A list of 3 integers [frame, height, width] representing the shape of the video.
-            txt_seq_lens (`List[int]`):
-                A list of integers of length batch_size representing the length of each text prompt.
-            device: (`torch.device`):
+            txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
+                Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
+            device: (`torch.device`, *optional*):
                 The device on which to perform the RoPE computation.
+            max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
+                The maximum text sequence length for RoPE computation. This should match the encoder hidden states
+                sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
+        # Handle deprecated txt_seq_lens parameter
+        if txt_seq_lens is not None:
+            deprecate(
+                "txt_seq_lens",
+                "0.39.0",
+                "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
+                "Please use `max_txt_seq_len` instead. "
+                "The new parameter accepts a single int or tensor value representing the maximum text sequence length.",
+                standard_warn=False,
+            )
+            if max_txt_seq_len is None:
+                # Use max of txt_seq_lens for backward compatibility
+                max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens
+
+        if max_txt_seq_len is None:
+            raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.")
+
+        # Validate batch inference with variable-sized images
+        if isinstance(video_fhw, list) and len(video_fhw) > 1:
+            # Check if all instances have the same size
+            first_fhw = video_fhw[0]
+            if not all(fhw == first_fhw for fhw in video_fhw):
+                logger.warning(
+                    "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. "
+                    "All images in the batch should have the same dimensions (frame, height, width). "
+                    f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
+                    "for RoPE computation, which may lead to incorrect results for other images in the batch."
                )
 
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
@@ -233,8 +288,7 @@ class QwenEmbedRope(nn.Module):
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
-            video_freq = self._compute_video_freqs(frame, height, width, idx)
-            video_freq = video_freq.to(device)
+            video_freq = self._compute_video_freqs(frame, height, width, idx, device)
             vid_freqs.append(video_freq)
 
         if self.scale_rope:
@@ -242,17 +296,23 @@ class QwenEmbedRope(nn.Module):
         else:
             max_vid_index = max(height, width, max_vid_index)
 
-        max_len = max(txt_seq_lens)
-        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        max_txt_seq_len_int = int(max_txt_seq_len)
+        # Create device-specific copy for text freqs without modifying self.pos_freqs
+        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=128)
-    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
+    def _compute_video_freqs(
+        self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
+    ) -> torch.Tensor:
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
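Because `device` is now an argument of the lru_cache-decorated method, it becomes part of the cache key, so each device keeps its own cached frequencies instead of the shared buffers being moved in place. The pattern in isolation (a minimal sketch, not the real class):

    import functools

    import torch

    class FreqCache:
        @functools.lru_cache(maxsize=128)
        def _compute_freqs(self, length: int, device: torch.device = None) -> torch.Tensor:
            # `device` participates in the cache key (torch.device is hashable),
            # so cuda:0 and cpu callers each hit their own cached entry.
            freqs = torch.arange(length, dtype=torch.float32)
            return freqs.to(device) if device is not None else freqs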
@@ -304,14 +364,35 @@ class QwenEmbedLayer3DRope(nn.Module):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs
 
-    def forward(self, video_fhw, txt_seq_lens, device):
+    def forward(
+        self,
+        video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+        max_txt_seq_len: Union[int, torch.Tensor],
+        device: torch.device = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
-            txt_length: [bs] a list of 1 integers representing the length of the text
+        Args:
+            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+                A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer
+                structures.
+            max_txt_seq_len (`int` or `torch.Tensor`):
+                The maximum text sequence length for RoPE computation. This should match the encoder hidden states
+                sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
+            device: (`torch.device`, *optional*):
+                The device on which to perform the RoPE computation.
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
+        # Validate batch inference with variable-sized images
+        # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
+        if isinstance(video_fhw, list) and len(video_fhw) > 1:
+            # Check if this is batch inference (list of layer lists/tuples)
+            first_entry = video_fhw[0]
+            if not all(entry == first_entry for entry in video_fhw):
+                logger.warning(
+                    "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. "
+                    "All images in the batch should have the same layer structure. "
+                    f"Detected sizes: {video_fhw}. Using the first image's layer structure {first_entry} "
+                    "for RoPE computation, which may lead to incorrect results for other images in the batch."
+                )
 
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
@@ -324,11 +405,10 @@ class QwenEmbedLayer3DRope(nn.Module):
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
             if idx != layer_num:
-                video_freq = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self._compute_video_freqs(frame, height, width, idx, device)
             else:
                 ### For the condition image, we set the layer index to -1
-                video_freq = self._compute_condition_freqs(frame, height, width)
-                video_freq = video_freq.to(device)
+                video_freq = self._compute_condition_freqs(frame, height, width, device)
             vid_freqs.append(video_freq)
 
         if self.scale_rope:
@@ -337,17 +417,21 @@ class QwenEmbedLayer3DRope(nn.Module):
             max_vid_index = max(height, width, max_vid_index)
 
         max_vid_index = max(max_vid_index, layer_num)
-        max_len = max(txt_seq_lens)
-        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        max_txt_seq_len_int = int(max_txt_seq_len)
+        # Create device-specific copy for text freqs without modifying self.pos_freqs
+        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width, idx=0):
+    def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
         freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -363,10 +447,13 @@ class QwenEmbedLayer3DRope(nn.Module):
         return freqs.clone().contiguous()
 
     @functools.lru_cache(maxsize=None)
-    def _compute_condition_freqs(self, frame, height, width):
+    def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
-        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
+        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+
+        freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
         freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
@@ -454,7 +541,6 @@ class QwenDoubleStreamAttnProcessor2_0:
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
-        # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
             joint_key,
@@ -762,14 +848,25 @@ class QwenImageTransformer2DModel(
                 Input `hidden_states`.
             encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
-                Mask of the input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
             img_shapes (`List[Tuple[int, int, int]]`, *optional*):
                 Image shapes for RoPE computation.
+            txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
+                Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
+                used to compute RoPE sequence length.
             guidance (`torch.Tensor`, *optional*):
                 Guidance tensor for conditional generation.
             attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             controlnet_block_samples (*optional*):
                 ControlNet block samples to add to the transformer blocks.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                 tuple.
@@ -778,6 +875,15 @@ class QwenImageTransformer2DModel(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        if txt_seq_lens is not None:
+            deprecate(
+                "txt_seq_lens",
+                "0.39.0",
+                "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. "
+                "Please use `encoder_hidden_states_mask` instead. "
+                "The mask-based approach is more flexible and supports variable-length sequences.",
+                standard_warn=False,
+            )
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -810,6 +916,11 @@ class QwenImageTransformer2DModel(
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )
+
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
 
@@ -819,7 +930,17 @@ class QwenImageTransformer2DModel(
             else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
         )
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
+
+        # Construct joint attention mask once to avoid reconstructing in every block
+        # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
+        block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
+        if encoder_hidden_states_mask is not None:
+            # Build joint mask: [text_mask, all_ones_for_image]
+            batch_size, image_seq_len = hidden_states.shape[:2]
+            image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
+            joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+            block_attention_kwargs["attention_mask"] = joint_attention_mask
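The claimed sync savings come from avoiding host-side reductions inside the denoising loop; a hedged illustration of the difference:

    import torch

    mask = torch.ones(2, 77, dtype=torch.bool)  # imagine this lives on the GPU

    # Old style: a Python list each step => device-to-host copy (sync) per call.
    lengths = mask.sum(dim=1).tolist()

    # New style: build the joint mask once and keep everything on device.
    image_mask = torch.ones(2, 1024, dtype=torch.bool)
    joint_mask = torch.cat([mask, image_mask], dim=1)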
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -827,10 +948,10 @@ class QwenImageTransformer2DModel(
                     block,
                     hidden_states,
                     encoder_hidden_states,
-                    encoder_hidden_states_mask,
+                    None,  # Don't pass encoder_hidden_states_mask (using attention_mask instead)
                     temb,
                     image_rotary_emb,
-                    attention_kwargs,
+                    block_attention_kwargs,
                     modulate_index,
                 )
 
@@ -838,10 +959,10 @@ class QwenImageTransformer2DModel(
                 encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    encoder_hidden_states_mask=None,  # Don't pass (using attention_mask instead)
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=attention_kwargs,
+                    joint_attention_kwargs=block_attention_kwargs,
                     modulate_index=modulate_index,
                 )
 
@@ -455,7 +455,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+        return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -579,7 +579,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+        return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
@@ -682,18 +682,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                 type_hint=List[List[Tuple[int, int, int]]],
                 description="The shapes of the images latents, used for RoPE calculation",
             ),
-            OutputParam(
-                name="txt_seq_lens",
-                kwargs_type="denoiser_input_fields",
-                type_hint=List[int],
-                description="The sequence lengths of the prompt embeds, used for RoPE calculation",
-            ),
-            OutputParam(
-                name="negative_txt_seq_lens",
-                kwargs_type="denoiser_input_fields",
-                type_hint=List[int],
-                description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
-            ),
         ]
 
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -708,14 +696,6 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                 )
             ]
         ] * block_state.batch_size
-        block_state.txt_seq_lens = (
-            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
-        )
-        block_state.negative_txt_seq_lens = (
-            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
-            if block_state.negative_prompt_embeds_mask is not None
-            else None
-        )
 
         self.set_block_state(state, block_state)
 
@@ -750,18 +730,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
                 type_hint=List[List[Tuple[int, int, int]]],
                 description="The shapes of the images latents, used for RoPE calculation",
             ),
-            OutputParam(
-                name="txt_seq_lens",
-                kwargs_type="denoiser_input_fields",
-                type_hint=List[int],
-                description="The sequence lengths of the prompt embeds, used for RoPE calculation",
-            ),
-            OutputParam(
-                name="negative_txt_seq_lens",
-                kwargs_type="denoiser_input_fields",
-                type_hint=List[int],
-                description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
-            ),
         ]
 
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -783,15 +751,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
             ]
         ] * block_state.batch_size
 
-        block_state.txt_seq_lens = (
-            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
-        )
-        block_state.negative_txt_seq_lens = (
-            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
-            if block_state.negative_prompt_embeds_mask is not None
-            else None
-        )
 
         self.set_block_state(state, block_state)
 
         return components, state
 
@@ -155,7 +155,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
                 kwargs_type="denoiser_input_fields",
                 description=(
                     "All conditional model inputs for the denoiser. "
-                    "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
+                    "It should contain prompt_embeds/negative_prompt_embeds."
                 ),
             ),
         ]
@@ -182,7 +182,6 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
                 img_shapes=block_state.img_shapes,
                 encoder_hidden_states=block_state.prompt_embeds,
                 encoder_hidden_states_mask=block_state.prompt_embeds_mask,
-                txt_seq_lens=block_state.txt_seq_lens,
                 return_dict=False,
             )
 
@@ -254,10 +253,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
                 getattr(block_state, "prompt_embeds_mask", None),
                 getattr(block_state, "negative_prompt_embeds_mask", None),
             ),
-            "txt_seq_lens": (
-                getattr(block_state, "txt_seq_lens", None),
-                getattr(block_state, "negative_txt_seq_lens", None),
-            ),
         }
 
         transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
@@ -358,10 +353,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
                 getattr(block_state, "prompt_embeds_mask", None),
                 getattr(block_state, "negative_prompt_embeds_mask", None),
             ),
-            "txt_seq_lens": (
-                getattr(block_state, "txt_seq_lens", None),
-                getattr(block_state, "negative_txt_seq_lens", None),
-            ),
         }
 
         transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
 
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+
+import PIL.Image
+import torch
+
 from ...utils import logging
 from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
 from .before_denoise import (
     QwenImageControlNetBeforeDenoiserStep,
     QwenImageCreateMaskLatentsStep,
@@ -394,6 +399,14 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
             + " - for text-to-image generation, all you need to provide is prompt embeddings"
         )
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+            ),
+        ]
+
 
 # ====================
 # 3. DECODE
@@ -467,3 +480,9 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
             + "- to run the controlnet workflow, you need to provide `control_image`\n"
            + "- for text-to-image generation, all you need to provide is `prompt`"
         )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
+        ]
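The additions across these modular files all follow one pattern: each block now advertises what it produces via an `outputs` property returning `OutputParam` entries. A self-contained sketch with a stand-in dataclass (the real class lives in `modular_pipeline_utils`; field names mirror the diff):

    from dataclasses import dataclass
    from typing import Any, List, Optional

    import PIL.Image

    @dataclass
    class OutputParam:  # stand-in mirroring the fields used in the diff
        name: str
        type_hint: Any = None
        description: Optional[str] = None

    class MyAutoBlocks:  # stand-in for a SequentialPipelineBlocks subclass
        @property
        def outputs(self) -> List[OutputParam]:
            return [
                OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
            ]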
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import List, Optional
 
+import PIL.Image
 import torch
 
 from ...utils import logging
 from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
 from .before_denoise import (
     QwenImageCreateMaskLatentsStep,
     QwenImageEditRoPEInputsStep,
@@ -307,6 +310,14 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
             " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
         )
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+            ),
+        ]
+
 
 # ====================
 # 5. AUTO BLOCKS & PRESETS
@@ -334,3 +345,9 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
             "- for edit (img2img) generation, you need to provide `image`\n"
             "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
         )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
+        ]
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+
+import PIL.Image
+import torch
+
 from ...utils import logging
 from ..modular_pipeline import SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
 from .before_denoise import (
     QwenImageEditPlusRoPEInputsStep,
     QwenImagePrepareLatentsStep,
@@ -136,6 +141,14 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
     def description(self):
         return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+            ),
+        ]
+
 
 # ====================
 # 4. DECODE
@@ -179,3 +192,9 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
             "- Each image is resized independently based on its own aspect ratio.\n"
             "- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
         )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
+        ]
@@ -13,9 +13,14 @@
 # limitations under the License.
 
 
+from typing import List
+
+import PIL.Image
+import torch
+
 from ...utils import logging
 from ..modular_pipeline import SequentialPipelineBlocks
-from ..modular_pipeline_utils import InsertableDict
+from ..modular_pipeline_utils import InsertableDict, OutputParam
 from .before_denoise import (
     QwenImageLayeredPrepareLatentsStep,
     QwenImageLayeredRoPEInputsStep,
@@ -134,6 +139,14 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
     def description(self):
         return "Core denoising workflow for QwenImage-Layered img2img task."
 
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
+            ),
+        ]
+
 
 # ====================
 # 4. AUTO BLOCKS & PRESETS
@@ -157,3 +170,9 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
     @property
     def description(self):
         return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
+        ]
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
 from ...models.attention import FeedForward
 from ...models.modeling_utils import ModelMixin
 from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
@@ -252,7 +253,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
         return hidden_states, attention_mask
 
 
-class LTX2TextConnectors(ModelMixin, ConfigMixin):
+class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
     """
     Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
     streams.
 
@@ -21,7 +21,7 @@ import torch
 from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
 from ...models.transformers import LTX2VideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -184,7 +184,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
-class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
+class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
     r"""
     Pipeline for text-to-video generation.
 
@@ -653,6 +653,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if latents is not None:
+            if latents.ndim == 5:
+                # latents are of shape [B, C, F, H, W], need to be packed
+                latents = self._pack_latents(
+                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                )
             return latents.to(device=device, dtype=dtype)
 
         height = height // self.vae_spatial_compression_ratio
@@ -677,29 +682,23 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         self,
         batch_size: int = 1,
         num_channels_latents: int = 8,
+        audio_latent_length: int = 1,  # 1 is just a dummy value
         num_mel_bins: int = 64,
-        num_frames: int = 121,
-        frame_rate: float = 25.0,
-        sampling_rate: int = 16000,
-        hop_length: int = 160,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        duration_s = num_frames / frame_rate
-        latents_per_second = (
-            float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
-        )
-        latent_length = round(duration_s * latents_per_second)
-
         if latents is not None:
-            return latents.to(device=device, dtype=dtype), latent_length
+            if latents.ndim == 4:
+                # latents are of shape [B, C, L, M], need to be packed
+                latents = self._pack_audio_latents(latents)
+            return latents.to(device=device, dtype=dtype)
 
         # TODO: confirm whether this logic is correct
         latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
 
-        shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
+        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -709,7 +708,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
 
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_audio_latents(latents)
-        return latents, latent_length
+        return latents
 
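The `audio_latent_length` the callers now pass in is exactly the value previously derived inside the method; with the former signature defaults the arithmetic works out as follows (the temporal compression ratio of 4 is an assumption for the worked numbers):

    num_frames, frame_rate = 121, 25.0
    sampling_rate, hop_length, temporal_compression = 16000, 160, 4

    duration_s = num_frames / frame_rate                                     # 4.84 s
    latents_per_second = sampling_rate / hop_length / temporal_compression   # 100 / 4 = 25.0
    audio_latent_length = round(duration_s * latents_per_second)             # round(121.0) = 121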
     @property
     def guidance_scale(self):
@@ -750,6 +749,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         num_frames: int = 121,
         frame_rate: float = 24.0,
         num_inference_steps: int = 40,
+        sigmas: Optional[List[float]] = None,
         timesteps: List[int] = None,
         guidance_scale: float = 4.0,
         guidance_rescale: float = 0.0,
@@ -788,6 +788,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             num_inference_steps (`int`, *optional*, defaults to 40):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
@@ -922,6 +926,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
+        if latents is not None:
+            if latents.ndim == 5:
+                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
+            else:
+                logger.warning(
+                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
+                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
+                )
         video_sequence_length = latent_num_frames * latent_height * latent_width
 
         num_channels_latents = self.transformer.config.in_channels
@@ -937,20 +949,30 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             latents,
         )
 
+        duration_s = num_frames / frame_rate
+        audio_latents_per_second = (
+            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
+        )
+        audio_num_frames = round(duration_s * audio_latents_per_second)
+        if audio_latents is not None:
+            if audio_latents.ndim == 4:
+                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
+            else:
+                logger.warning(
+                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
+                    f" cannot be inferred. Make sure the supplied `num_frames` is correct."
+                )
+
         num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
+        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
 
         num_channels_latents_audio = (
             self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
         )
-        audio_latents, audio_num_frames = self.prepare_audio_latents(
+        audio_latents = self.prepare_audio_latents(
            batch_size * num_videos_per_prompt,
             num_channels_latents=num_channels_latents_audio,
+            audio_latent_length=audio_num_frames,
             num_mel_bins=num_mel_bins,
-            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
-            frame_rate=frame_rate,
-            sampling_rate=self.audio_sampling_rate,
-            hop_length=self.audio_hop_length,
             dtype=torch.float32,
             device=device,
             generator=generator,
@@ -958,7 +980,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         )
 
         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         mu = calculate_shift(
             video_sequence_length,
             self.scheduler.config.get("base_image_seq_len", 1024),
 
@@ -689,6 +689,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
conditioning_mask = self._pack_latents(
|
||||
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
).squeeze(-1)
|
||||
if latents.ndim == 5:
|
||||
# latents are of shape [B, C, F, H, W], need to be packed
|
||||
latents = self._pack_latents(
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
|
||||
@@ -737,29 +742,23 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 8,
|
||||
audio_latent_length: int = 1, # 1 is just a dummy value
|
||||
num_mel_bins: int = 64,
|
||||
num_frames: int = 121,
|
||||
frame_rate: float = 25.0,
|
||||
sampling_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
duration_s = num_frames / frame_rate
|
||||
latents_per_second = (
|
||||
float(sampling_rate) / float(hop_length) / float(self.audio_vae_temporal_compression_ratio)
|
||||
)
|
||||
latent_length = round(duration_s * latents_per_second)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype), latent_length
|
||||
if latents.ndim == 4:
|
||||
# latents are of shape [B, C, L, M], need to be packed
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
# TODO: confirm whether this logic is correct
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, latent_length, latent_mel_bins)
|
||||
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -769,7 +768,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents, latent_length
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
@@ -811,6 +810,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
num_frames: int = 121,
|
||||
frame_rate: float = 24.0,
|
||||
num_inference_steps: int = 40,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
@@ -851,6 +851,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
num_inference_steps (`int`, *optional*, defaults to 40):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
@@ -982,6 +986,19 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        )

        # 4. Prepare latent variables
        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
        if latents is not None:
            if latents.ndim == 5:
                _, _, latent_num_frames, latent_height, latent_width = latents.shape  # [B, C, F, H, W]
            else:
                logger.warning(
                    f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be"
                    f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct."
                )
        video_sequence_length = latent_num_frames * latent_height * latent_width

        if latents is None:
            image = self.video_processor.preprocess(image, height=height, width=width)
            image = image.to(device=device, dtype=prompt_embeds.dtype)
@@ -1002,20 +1019,30 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        if self.do_classifier_free_guidance:
            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])

        duration_s = num_frames / frame_rate
        audio_latents_per_second = (
            self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
        )
        audio_num_frames = round(duration_s * audio_latents_per_second)
        if audio_latents is not None:
            if audio_latents.ndim == 4:
                _, _, audio_num_frames, _ = audio_latents.shape  # [B, C, L, M]
            else:
                logger.warning(
                    f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims"
                    f" cannot be inferred. Make sure the supplied `num_frames` is correct."
                )

        num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        num_channels_latents_audio = (
            self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
        )
-        audio_latents, audio_num_frames = self.prepare_audio_latents(
+        audio_latents = self.prepare_audio_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents_audio,
            audio_latent_length=audio_num_frames,
            num_mel_bins=num_mel_bins,
            num_frames=num_frames,  # Video frames, audio frames will be calculated from this
            frame_rate=frame_rate,
            sampling_rate=self.audio_sampling_rate,
            hop_length=self.audio_hop_length,
            dtype=torch.float32,
            device=device,
            generator=generator,
@@ -1023,12 +1050,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
        )

        # 5. Prepare timesteps
-        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
-        latent_height = height // self.vae_spatial_compression_ratio
-        latent_width = width // self.vae_spatial_compression_ratio
-        video_sequence_length = latent_num_frames * latent_height * latent_width
-
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        mu = calculate_shift(
            video_sequence_length,
            self.scheduler.config.get("base_image_seq_len", 1024),
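For orientation, a standalone sketch of the size arithmetic these hunks implement; every constant below is an illustrative assumption for the example, not LTX2's actual configuration:

```python
# Sketch of the latent-sizing math above; all constants are assumed for illustration.
vae_temporal_compression_ratio = 8
vae_spatial_compression_ratio = 32
audio_sampling_rate = 16000                # Hz (assumed)
audio_hop_length = 160                     # samples per mel frame (assumed)
audio_vae_temporal_compression_ratio = 4   # assumed

num_frames, height, width, frame_rate = 121, 512, 768, 25.0

# Video: pixel-space sizes map to latent sizes via the VAE compression ratios.
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1  # 16
latent_height = height // vae_spatial_compression_ratio                     # 16
latent_width = width // vae_spatial_compression_ratio                       # 24
video_sequence_length = latent_num_frames * latent_height * latent_width    # 6144

# Audio: the latent length follows from the clip duration and the mel/VAE rates.
duration_s = num_frames / frame_rate                                  # 4.84 s
audio_latents_per_second = (
    audio_sampling_rate / audio_hop_length / float(audio_vae_temporal_compression_ratio)
)                                                                     # 25.0
audio_num_frames = round(duration_s * audio_latents_per_second)       # 121
```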
src/diffusers/pipelines/ltx2/utils.py (new file, 4 lines)
@@ -0,0 +1,4 @@
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875]

# Reduced schedule for super-resolution stage 2 (subset of distilled values)
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875]
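These constants pair naturally with the `sigmas` override added to the pipeline above (`sigmas = ... if sigmas is None else sigmas`). A hedged usage sketch; `pipe` is assumed to be an already-loaded LTX2 pipeline and the `sigmas` keyword is the override this branch introduces:

```python
# Sketch only: assumes `pipe` is a loaded LTX2 pipeline whose __call__ accepts
# the `sigmas` override introduced in this branch.
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES

video = pipe(
    prompt="a coral reef at dawn",
    sigmas=DISTILLED_SIGMA_VALUES,  # distilled 8-step schedule instead of np.linspace
).frames[0]
```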
@@ -14,7 +14,7 @@
# limitations under the License.
#
# Modifications by Decart AI Team:
-# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
+# - Based on pipeline_wan.py, but with support for receiving a condition video appended to the channel dimension.

import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -672,11 +672,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
-
        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -695,7 +690,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -709,7 +703,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
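These call sites can drop `txt_seq_lens` because the transformer now derives per-sample text lengths from `encoder_hidden_states_mask` itself (the deprecation test further below keeps the old keyword working behind a `FutureWarning`). A small sketch of the equivalence:

```python
import torch

# 1 = valid token, 0 = padding, for a batch of two prompts.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])

# What these pipelines used to precompute and forward as `txt_seq_lens`:
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)  # [3, 5] -- now inferred inside the model from the mask
```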
@@ -909,7 +909,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    return_dict=False,
                )

@@ -920,7 +919,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
@@ -935,7 +933,6 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
@@ -852,7 +852,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    return_dict=False,
                )

@@ -863,7 +862,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
@@ -878,7 +876,6 @@ class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderM
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
                    controlnet_block_samples=controlnet_block_samples,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
@@ -793,11 +793,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
-
        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -821,7 +816,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -836,7 +830,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -1008,11 +1008,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
-
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -1035,7 +1030,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -1050,7 +1044,6 @@ class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -663,6 +663,13 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        else:
            batch_size = prompt_embeds.shape[0]

+        # QwenImageEditPlusPipeline does not currently support batch_size > 1
+        if batch_size > 1:
+            raise ValueError(
+                f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. "
+                "Please process prompts one at a time."
+            )
+
        device = self._execution_device
        # 3. Preprocess image
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
@@ -777,11 +784,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
-
        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -805,7 +807,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -820,7 +821,6 @@ class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -775,11 +775,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
-
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -797,7 +792,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -811,7 +805,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -944,11 +944,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
-
        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -966,7 +961,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -980,7 +974,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]
@@ -781,10 +781,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
        if self.attention_kwargs is None:
            self._attention_kwargs = {}

-        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
-        negative_txt_seq_lens = (
-            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
-        )
        is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long)
        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
@@ -809,7 +805,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
                    encoder_hidden_states_mask=prompt_embeds_mask,
                    encoder_hidden_states=prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    additional_t_cond=is_rgb,
                    return_dict=False,
@@ -825,7 +820,6 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as
                    encoder_hidden_states_mask=negative_prompt_embeds_mask,
                    encoder_hidden_states=negative_prompt_embeds,
                    img_shapes=img_shapes,
-                    txt_seq_lens=negative_txt_seq_lens,
                    attention_kwargs=self.attention_kwargs,
                    additional_t_cond=is_rgb,
                    return_dict=False,
@@ -885,7 +879,7 @@ the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>as

        latents = latents[:, :, 1:]  # remove the first frame as it is the original input

-        latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
+        latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)

        image = self.vae.decode(latents, return_dict=False)[0]  # (b f) c 1 h w
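The `.view` → `.reshape` change matters because `permute` returns a non-contiguous tensor: `.view` requires compatible strides and raises, while `.reshape` falls back to a copy. A self-contained illustration:

```python
import torch

t = torch.randn(2, 3, 4).permute(0, 2, 1)  # permute makes the tensor non-contiguous
print(t.is_contiguous())  # False

try:
    t.view(2, 3, 4)  # incompatible strides -> RuntimeError
except RuntimeError as err:
    print("view failed:", err)

r = t.reshape(2, 3, 4)  # reshape copies when a zero-copy view is impossible
print(r.shape)  # torch.Size([2, 3, 4])
```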
@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_target_modules = ["q", "k", "v", "o"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 8, 8, 3)
@@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in AuraFlow.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)
@@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
-
    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass
@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        "text_encoder",
    )

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 32, 32, 3)
@@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in CogView4.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in CogView4.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
    denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 8, 8, 3)
@@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Flux2.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Flux2.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        "text_encoder_2",
    )

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)
@@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass

@nightly
@require_torch_accelerator
tests/lora/test_lora_layers_ltx2.py (new file, 271 lines)
@@ -0,0 +1,271 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration

from diffusers import (
    AutoencoderKLLTX2Audio,
    AutoencoderKLLTX2Video,
    FlowMatchEulerDiscreteScheduler,
    LTX2Pipeline,
    LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from diffusers.utils.import_utils import is_peft_available

from ..testing_utils import floats_tensor, require_peft_backend


if is_peft_available():
    from peft import LoraConfig


sys.path.append(".")

from .utils import PeftLoraLoaderMixinTests  # noqa: E402


@require_peft_backend
class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    pipeline_class = LTX2Pipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}

    transformer_kwargs = {
        "in_channels": 4,
        "out_channels": 4,
        "patch_size": 1,
        "patch_size_t": 1,
        "num_attention_heads": 2,
        "attention_head_dim": 8,
        "cross_attention_dim": 16,
        "audio_in_channels": 4,
        "audio_out_channels": 4,
        "audio_num_attention_heads": 2,
        "audio_attention_head_dim": 4,
        "audio_cross_attention_dim": 8,
        "num_layers": 1,
        "qk_norm": "rms_norm_across_heads",
        "caption_channels": 32,
        "rope_double_precision": False,
        "rope_type": "split",
    }
    transformer_cls = LTX2VideoTransformer3DModel

    vae_kwargs = {
        "in_channels": 3,
        "out_channels": 3,
        "latent_channels": 4,
        "block_out_channels": (8,),
        "decoder_block_out_channels": (8,),
        "layers_per_block": (1,),
        "decoder_layers_per_block": (1, 1),
        "spatio_temporal_scaling": (True,),
        "decoder_spatio_temporal_scaling": (True,),
        "decoder_inject_noise": (False, False),
        "downsample_type": ("spatial",),
        "upsample_residual": (False,),
        "upsample_factor": (1,),
        "timestep_conditioning": False,
        "patch_size": 1,
        "patch_size_t": 1,
        "encoder_causal": True,
        "decoder_causal": False,
    }
    vae_cls = AutoencoderKLLTX2Video

    audio_vae_kwargs = {
        "base_channels": 4,
        "output_channels": 2,
        "ch_mult": (1,),
        "num_res_blocks": 1,
        "attn_resolutions": None,
        "in_channels": 2,
        "resolution": 32,
        "latent_channels": 2,
        "norm_type": "pixel",
        "causality_axis": "height",
        "dropout": 0.0,
        "mid_block_add_attention": False,
        "sample_rate": 16000,
        "mel_hop_length": 160,
        "is_causal": True,
        "mel_bins": 8,
    }
    audio_vae_cls = AutoencoderKLLTX2Audio

    vocoder_kwargs = {
        "in_channels": 16,  # output_channels * mel_bins = 2 * 8
        "hidden_channels": 32,
        "out_channels": 2,
        "upsample_kernel_sizes": [4, 4],
        "upsample_factors": [2, 2],
        "resnet_kernel_sizes": [3],
        "resnet_dilations": [[1, 3, 5]],
        "leaky_relu_negative_slope": 0.1,
        "output_sampling_rate": 16000,
    }
    vocoder_cls = LTX2Vocoder

    connectors_kwargs = {
        "caption_channels": 32,  # Will be set dynamically from text_encoder
        "text_proj_in_factor": 2,  # Will be set dynamically from text_encoder
        "video_connector_num_attention_heads": 4,
        "video_connector_attention_head_dim": 8,
        "video_connector_num_layers": 1,
        "video_connector_num_learnable_registers": None,
        "audio_connector_num_attention_heads": 4,
        "audio_connector_attention_head_dim": 8,
        "audio_connector_num_layers": 1,
        "audio_connector_num_learnable_registers": None,
        "connector_rope_base_seq_len": 32,
        "rope_theta": 10000.0,
        "rope_double_precision": False,
        "causal_temporal_positioning": False,
        "rope_type": "split",
    }
    connectors_cls = LTX2TextConnectors

    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3"
    text_encoder_cls, text_encoder_id = (
        Gemma3ForConditionalGeneration,
        "hf-internal-testing/tiny-gemma3",
    )

    denoiser_target_modules = ["to_q", "to_k", "to_out.0"]

    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        return (1, 5, 32, 32, 3)

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 5
        num_latent_frames = 2
        latent_height = 8
        latent_width = 8

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "a robot dancing",
            "num_frames": num_frames,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "frame_rate": 25.0,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        # Override to instantiate LTX2-specific components (connectors, audio_vae, vocoder)
        torch.manual_seed(0)
        text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)

        # Update caption_channels and text_proj_in_factor based on text_encoder config
        transformer_kwargs = self.transformer_kwargs.copy()
        transformer_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size

        connectors_kwargs = self.connectors_kwargs.copy()
        connectors_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size
        connectors_kwargs["text_proj_in_factor"] = text_encoder.config.text_config.num_hidden_layers + 1

        torch.manual_seed(0)
        transformer = self.transformer_cls(**transformer_kwargs)

        torch.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs)

        torch.manual_seed(0)
        vocoder = self.vocoder_cls(**self.vocoder_kwargs)

        torch.manual_seed(0)
        connectors = self.connectors_cls(**connectors_kwargs)

        if scheduler_cls is None:
            scheduler_cls = self.scheduler_cls
        scheduler = scheduler_cls(**self.scheduler_kwargs)

        rank = 4
        lora_alpha = rank if lora_alpha is None else lora_alpha

        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=self.text_encoder_target_modules,
            init_lora_weights=False,
            use_dora=use_dora,
        )

        denoiser_lora_config = LoraConfig(
            r=rank,
            lora_alpha=lora_alpha,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=use_dora,
        )

        pipeline_components = {
            "transformer": transformer,
            "vae": vae,
            "audio_vae": audio_vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
        }

        return pipeline_components, text_lora_config, denoiser_lora_config

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in LTX2.")
    def test_modify_padding_mode(self):
        pass
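With the mixin wired into the pipeline, loading an LTX2 LoRA should follow the usual diffusers pattern; a sketch in which both repo ids and the adapter name are placeholders, not confirmed checkpoints:

```python
# Sketch: "Lightricks/LTX-2" and "user/ltx2-style-lora" are hypothetical ids.
import torch
from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("user/ltx2-style-lora", adapter_name="style")
pipe.set_adapters(["style"], adapter_weights=[0.8])
```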
@@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)
@@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in LTXVideo.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
    text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 4, 4, 3)
@@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Lumina2.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
-
    @skip_mps
    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
@@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 7, 16, 16, 3)
@@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Mochi.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
-
    @unittest.skip("Not supported in Mochi.")
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pass
@@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    )
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 8, 8, 3)
@@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Qwen Image.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
    text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 32, 32, 3)
@@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in SANA.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
-
    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
    def test_layerwise_casting_inference_denoiser(self):
        return super().test_layerwise_casting_inference_denoiser()
@@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 9, 32, 32, 3)
@@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in Wan.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

    text_encoder_target_modules = ["q", "k", "v", "o"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 9, 16, 16, 3)
@@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_modify_padding_mode(self):
        pass

-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
-
    def test_layerwise_casting_inference_denoiser(self):
        super().test_layerwise_casting_inference_denoiser()

@@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    text_encoder_cls, text_encoder_id = Qwen3Model, None  # Will be created inline
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

+    supports_text_encoder_loras = False
+
    @property
    def output_shape(self):
        return (1, 32, 32, 3)
@@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    @unittest.skip("Not supported in ZImage.")
    def test_modify_padding_mode(self):
        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
-    def test_simple_inference_with_partial_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
-    def test_simple_inference_with_text_lora(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
-    def test_simple_inference_with_text_lora_and_scale(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
-    def test_simple_inference_with_text_lora_fused(self):
-        pass
-
-    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
-    def test_simple_inference_with_text_lora_save_load(self):
-        pass
@@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests:
    tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
    tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
    tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
+    supports_text_encoder_loras = True

    unet_kwargs = None
    transformer_cls = None
@@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder
        and makes sure it works as expected
        """
+        if not self.supports_text_encoder_loras:
+            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
+        if not self.supports_text_encoder_loras:
+            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
@@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
        and makes sure it works as expected
        """
+        if not self.supports_text_encoder_loras:
+            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA.
        """
+        if not self.supports_text_encoder_loras:
+            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests:
        with different ranks and some adapters removed
        and makes sure it works as expected
        """
+        if not self.supports_text_encoder_loras:
+            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
        components, _, _ = self.get_dummy_components()
        # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
        text_lora_config = LoraConfig(
@@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
        """
+        if not self.supports_text_encoder_loras:
+            pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
+
        components, text_lora_config, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
@@ -15,10 +15,10 @@

import unittest

import pytest
import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
@@ -68,7 +68,6 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
-            "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
        }

    def prepare_init_args_and_inputs_for_common(self):
@@ -91,6 +90,180 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
        expected_set = {"QwenImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    def test_infers_text_seq_len_from_mask(self):
        """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask[:, 2:] = 0  # Only first 2 tokens are valid

        rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )

        # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
        self.assertIsInstance(rope_text_seq_len, int)

        # Verify per_sample_len is computed correctly (max valid position + 1 = 2)
        self.assertIsInstance(per_sample_len, torch.Tensor)
        self.assertEqual(int(per_sample_len.max().item()), 2)

        # Verify mask is normalized to bool dtype
        self.assertTrue(normalized_mask.dtype == torch.bool)
        self.assertEqual(normalized_mask.sum().item(), 2)  # Only 2 True values

        # Verify rope_text_seq_len is at least the sequence length
        self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])

        # Test 2: Verify model runs successfully with inferred values
        inputs["encoder_hidden_states_mask"] = normalized_mask
        with torch.no_grad():
            output = model(**inputs)
        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

        # Test 3: Different mask pattern (padding at beginning)
        encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask2[:, :3] = 0  # First 3 tokens are padding
        encoder_hidden_states_mask2[:, 3:] = 1  # Last 4 tokens are valid

        rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask2
        )

        # Max valid position is 6 (last token), so per_sample_len should be 7
        self.assertEqual(int(per_sample_len2.max().item()), 7)
        self.assertEqual(normalized_mask2.sum().item(), 4)  # 4 True values

        # Test 4: No mask provided (None case)
        rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], None
        )
        self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
        self.assertIsInstance(rope_text_seq_len_none, int)
        self.assertIsNone(per_sample_len_none)
        self.assertIsNone(normalized_mask_none)

    def test_non_contiguous_attention_mask(self):
        """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
        # Pattern: [True, False, True, False, True, False, False]
        encoder_hidden_states_mask[:, 1] = 0
        encoder_hidden_states_mask[:, 3] = 0
        encoder_hidden_states_mask[:, 5:] = 0

        inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )
        self.assertEqual(int(per_sample_len.max().item()), 5)
        self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
        self.assertIsInstance(inferred_rope_len, int)
        self.assertTrue(normalized_mask.dtype == torch.bool)

        inputs["encoder_hidden_states_mask"] = normalized_mask

        with torch.no_grad():
            output = model(**inputs)

        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

    def test_txt_seq_lens_deprecation(self):
        """Test that passing txt_seq_lens raises a deprecation warning."""
        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

        # Prepare inputs with txt_seq_lens (deprecated parameter)
        txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]

        # Remove encoder_hidden_states_mask to use the deprecated path
        inputs_with_deprecated = inputs.copy()
        inputs_with_deprecated.pop("encoder_hidden_states_mask")
        inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens

        # Test that deprecation warning is raised
        with self.assertWarns(FutureWarning) as warning_context:
            with torch.no_grad():
                output = model(**inputs_with_deprecated)

        # Verify the warning message mentions the deprecation
        warning_message = str(warning_context.warning)
        self.assertIn("txt_seq_lens", warning_message)
        self.assertIn("deprecated", warning_message)
        self.assertIn("encoder_hidden_states_mask", warning_message)

        # Verify the model still works correctly despite the deprecation
        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

    def test_layered_model_with_mask(self):
        """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
        # Create layered model config
        init_dict = {
            "patch_size": 2,
            "in_channels": 16,
            "out_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 16,
            "num_attention_heads": 3,
            "joint_attention_dim": 16,
            "axes_dims_rope": (8, 4, 4),  # Must match attention_head_dim (8+4+4=16)
            "use_layer3d_rope": True,  # Enable layered RoPE
            "use_additional_t_cond": True,  # Enable additional time conditioning
        }

        model = self.model_class(**init_dict).to(torch_device)

        # Verify the model uses QwenEmbedLayer3DRope
        from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

        self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)

        # Test single generation with layered structure
        batch_size = 1
        text_seq_len = 7
        img_h, img_w = 4, 4
        layers = 4

        # For layered model: (layers + 1) because we have N layers + 1 combined image
        hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
        encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)

        # Create mask with some padding
        encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
        encoder_hidden_states_mask[0, 5:] = 0  # Only 5 valid tokens

        timestep = torch.tensor([1.0]).to(torch_device)

        # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
        addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)

        # Layer structure: 4 layers + 1 condition image
        img_shapes = [
            [
                (1, img_h, img_w),  # layer 0
                (1, img_h, img_w),  # layer 1
                (1, img_h, img_w),  # layer 2
                (1, img_h, img_w),  # layer 3
                (1, img_h, img_w),  # condition image (last one gets special treatment)
            ]
        ]

        with torch.no_grad():
            output = model(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                timestep=timestep,
                img_shapes=img_shapes,
                additional_t_cond=addition_t_cond,
            )

        self.assertEqual(output.sample.shape[1], hidden_states.shape[1])


class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = QwenImageTransformer2DModel
@@ -101,6 +274,5 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
    def prepare_dummy_input(self, height, width):
        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)

-    @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()
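The tests above pin down the contract of `compute_text_seq_len_from_mask` without showing its body. A behavioural sketch consistent with those assertions (an illustration of the contract, not the repo's implementation):

```python
import torch

def compute_text_seq_len_from_mask_sketch(encoder_hidden_states, mask):
    # Contract implied by the tests:
    #   no mask   -> (seq_len as a Python int, None, None)
    #   with mask -> rope length covering the padded sequence (an int),
    #                per-sample length = last valid index + 1, and a bool mask.
    seq_len = encoder_hidden_states.shape[1]
    if mask is None:
        return seq_len, None, None
    bool_mask = mask.bool()
    positions = torch.arange(seq_len, device=bool_mask.device)
    per_sample_len = torch.where(bool_mask, positions + 1, torch.zeros_like(positions)).amax(dim=1)
    return seq_len, per_sample_len, bool_mask

# e.g. mask [1, 0, 1, 0, 1, 0, 0] -> per-sample length 5, as the test asserts.
```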