Compare commits

...

9 Commits

Author SHA1 Message Date
yiyixuxu
4a56b5083b up 2026-02-10 04:57:17 +01:00
Sayak Paul
20efb79d49 [modular] add modular tests for Z-Image and Wan (#13078)
* add wan modular tests

* style.

* add z-image tests and other fixes.

* style.

* increase tolerance for zimage

* style

* address reviewer feedback.

* address reviewer feedback.

* remove unneeded func

* simplify even more.
2026-02-09 08:27:59 -10:00
Linoy Tsaban
8933686770 Z image lora training (#13056)
* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* fix vae

* fix prompts

* Apply style fixes

* fix license

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-09 15:45:59 +02:00
dg845
baaa8d040b LTX 2 Improve encode_video by Accepting More Input Types (#13057)
* Support different pipeline outputs for LTX 2 encode_video

* Update examples to use improved encode_video function

* Fix comment

* Address review comments

* make style and make quality

* Have non-iterator video inputs respect video_chunks_number

* make style and make quality

* Add warning when encode_video receives a non-denormalized np.ndarray

* make style and make quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-08 19:40:34 -08:00
YiYi Xu
44f4dc0054 [Modular] guard ModularPipeline.blocks attribute (#13014)
* up

* style
2026-02-08 16:12:47 -10:00
YiYi Xu
fd705bd8ff [Modular] refactor Wan: modular pipelines by task etc (#13063)
* initil

* fix init_pipeline etc

* style

* copies

* fix copies

* upup more

* fix test

* add output type (#13091)

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2026-02-07 11:28:27 -10:00
hlky
09dca386d0 ZImageControlNet cfg (#13080)
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-02-07 10:40:55 -10:00
YiYi Xu
10dc589a94 [modular]simplify components manager doc (#13088)
* simplify components manager doc

* Apply suggestion from @yiyixuxu

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Apply suggestion from @stevhliu

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-02-06 09:55:34 -10:00
David El Malih
44b8201d98 docs: improve docstring scheduling_dpmsolver_multistep_inverse.py (#13085)
* Improve docstring scheduling dpmsolver sde

* Update scheduling_dpmsolver_sde.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* run make fix-copies

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-02-06 09:20:05 -08:00
37 changed files with 3494 additions and 995 deletions

View File

@@ -106,8 +106,6 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
@@ -185,8 +183,6 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],

View File

@@ -12,179 +12,85 @@ specific language governing permissions and limitations under the License.
# ComponentsManager
The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading.
The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), and supports offloading.
This guide will show you how to use [`ComponentsManager`] to manage components and device memory.
## Add a component
## Connect to a pipeline
The [`ComponentsManager`] should be created alongside a [`ModularPipeline`] in either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
Create a [`ComponentsManager`] and pass it to a [`ModularPipeline`] with either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
> [!TIP]
> The `collection` parameter is optional but makes it easier to organize and manage components.
<hfoptions id="create">
<hfoption id="from_pretrained">
```py
from diffusers import ModularPipeline, ComponentsManager
import torch
comp = ComponentsManager()
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
manager = ComponentsManager()
pipe = ModularPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
```
</hfoption>
<hfoption id="init_pipeline">
```py
from diffusers import ComponentsManager
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
components = ComponentsManager()
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
from diffusers import ModularPipelineBlocks, ComponentsManager
import torch
manager = ComponentsManager()
blocks = ModularPipelineBlocks.from_pretrained("diffusers/Florence2-image-Annotator", trust_remote_code=True)
pipe= blocks.init_pipeline(components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
```
</hfoption>
</hfoptions>
Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_components`]. The example below uses [`~ModularPipeline.load_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection
Components loaded by the pipeline are automatically registered in the manager. You can inspect them right away.
## Inspect components
Print the [`ComponentsManager`] to see all registered components, including their class, device placement, dtype, memory size, and load ID.
The output below corresponds to the `from_pretrained` example above.
```py
pipe.load_components()
pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
Components:
=============================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID
-----------------------------------------------------------------------------------------------------------------------------
text_encoder_140458257514752 | Qwen3Model | cpu | torch.bfloat16 | 7.49 | Tongyi-MAI/Z-Image-Turbo|text_encoder|null|null
vae_140458257515376 | AutoencoderKL | cpu | torch.bfloat16 | 0.16 | Tongyi-MAI/Z-Image-Turbo|vae|null|null
transformer_140458257515616 | ZImageTransformer2DModel | cpu | torch.bfloat16 | 11.46 | Tongyi-MAI/Z-Image-Turbo|transformer|null|null
-----------------------------------------------------------------------------------------------------------------------------
Other Components:
-----------------------------------------------------------------------------------------------------------------------------
ID | Class | Collection
-----------------------------------------------------------------------------------------------------------------------------
scheduler_140461023555264 | FlowMatchEulerDiscreteScheduler | N/A
tokenizer_140458256346432 | Qwen2Tokenizer | N/A
-----------------------------------------------------------------------------------------------------------------------------
```
Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components.
```py
pipe2.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
pipe2.update_components(**comp_dict)
```
To add individual components, use the [`~ComponentsManager.add`] method. This registers a component with a unique id.
```py
from diffusers import AutoModel
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
component_id = comp.add("text_encoder", text_encoder)
comp
```
Use [`~ComponentsManager.remove`] to remove a component using their id.
```py
comp.remove("text_encoder_139917733042864")
```
## Retrieve a component
The [`ComponentsManager`] provides several methods to retrieve registered components.
### get_one
The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] returns an error.
| Pattern | Example | Description |
|-------------|----------------------------------|-------------------------------------------|
| exact | `comp.get_one(name="unet")` | exact name match |
| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" |
| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" |
| or | `comp.get_one(name="unet&#124;vae")` | name is "unet" or "vae" |
[`~ComponentsManager.get_one`] also filters components by the `collection` argument or `load_id` argument.
```py
comp.get_one(name="unet", collection="sdxl")
```
### get_components_by_names
The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with [`ModularPipeline`] since they provide lists of required component names and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`].
```py
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
{"text_encoder": component1, "unet": component2, "vae": component3}
```
## Duplicate detection
It is recommended to load model components with [`ComponentSpec`] to assign components with a unique id that encodes their loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint.
```py
from diffusers import ComponentSpec, ComponentsManager
from transformers import CLIPTextModel
comp = ComponentsManager()
# Create ComponentSpec for the first text encoder
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from the same repo/subfolder)
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
# Load and add both components - the manager will detect they're the same model
comp.add("text_encoder", spec.load())
comp.add("text_encoder_duplicated", spec_duplicated.load())
```
This returns a warning with instructions for removing the duplicate.
```py
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
'text_encoder_duplicated_139917580682672'
```
You could also add a component without using [`ComponentSpec`] and duplicate detection still works in most cases even if you're adding the same component under a different name.
However, [`ComponentManager`] can't detect duplicates when you load the same component into different objects. In this case, you should load a model with [`ComponentSpec`].
```py
text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
comp.add("text_encoder", text_encoder_2)
'text_encoder_139917732983664'
```
## Collections
Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`].
Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first component.
```py
from diffusers import ComponentSpec, ComponentsManager
comp = ComponentsManager()
# Create ComponentSpec for the first UNet
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
# Create ComponentSpec for a different UNet
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
# Add both UNets to the same collection - the second one will replace the first
comp.add("unet", spec.load(), collection="sdxl")
comp.add("unet", spec2.load(), collection="sdxl")
```
This makes it convenient to work with node-based systems because you can:
- Mark all models as loaded from one node with the `collection` label.
- Automatically replace models when new checkpoints are loaded under the same name.
- Batch delete all models in a collection when a node is removed.
The table shows models (with device, dtype, and memory info) separately from other components like schedulers and tokenizers. If any models have LoRA adapters, IP-Adapters, or quantization applied, that information is displayed in an additional section at the bottom.
## Offloading
The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement if you add or remove components.
```py
comp.enable_auto_cpu_offload(device="cuda")
manager.enable_auto_cpu_offload(device="cuda")
```
All models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low.
You can set your own rules for which models to offload first.
Call [`~ComponentsManager.disable_auto_cpu_offload`] to disable offloading.
```py
manager.disable_auto_cpu_offload()
```

View File

@@ -0,0 +1,347 @@
# DreamBooth training example for Z-Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
The `train_dreambooth_lora_z_image.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).
> [!NOTE]
> **About Z-Image**
>
> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters.
> [!NOTE]
> **Memory consumption**
>
> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_z_image.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
## Memory Optimizations
> [!NOTE]
> Many of these techniques complement each other and can be used together to further reduce memory consumption. However some techniques may be mutually exclusive so be sure to check before launching a training run.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
Enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT]
> Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient_accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the amount of backward/update passes and hence also memory requirements.
* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 1024, but it's good to keep in mind in case you're training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
## Training Examples
### Z-Image Training
To perform DreamBooth with LoRA on Z-Image, run:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases.
### Training with FP8 Quantization
For reduced memory usage with FP8 training:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-fp8"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
```yaml
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
To use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set `--learning_rate=1.0`
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-prodigy"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1.0 \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha:
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
**lora_alpha vs. rank:**
This ratio dictates the LoRA's effective strength:
- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore applying LoRA training onto different types of layers and blocks.
To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide:
- For attention only layers: `--lora_layers="to_k,to_q,to_v,to_out.0"`
- For attention and feed-forward layers: `--lora_layers="to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string.
> [!NOTE]
> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Aspect Ratio Bucketing
We've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
```bash
--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
```
### Bilingual Prompts
Z-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8:
```bash
--instance_prompt="一只sks狗的照片"
--validation_prompt="一只sks狗在桶里的照片"
```
> [!TIP]
> Z-Image excels at text rendering in generated images, especially for Chinese characters. If your use case involves generating images with text, consider including text-related examples in your training data.
## Inference
Once you have trained a LoRA, you can load it for inference:
```python
import torch
from diffusers import ZImagePipeline
pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load your trained LoRA
pipe.load_lora_weights("path/to/your/trained-z-image-lora")
# Generate an image
image = pipe(
prompt="A photo of sks dog in a bucket",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("output.png")
```
---
Since Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

File diff suppressed because it is too large Load Diff

View File

@@ -297,6 +297,8 @@ else:
"ComponentSpec",
"ModularPipeline",
"ModularPipelineBlocks",
"InputParam",
"OutputParam",
]
)
_import_structure["optimization"] = [
@@ -417,6 +419,7 @@ else:
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinBaseModularPipeline",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
@@ -433,8 +436,13 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"Wan22Blocks",
"Wan22Image2VideoBlocks",
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanBlocks",
"WanImage2VideoAutoBlocks",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
"ZImageAutoBlocks",
"ZImageModularPipeline",
@@ -1054,7 +1062,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageTransformer2DModel,
attention_backend,
)
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks, InputParam, OutputParam
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -1156,6 +1164,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
@@ -1172,8 +1181,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
ZImageAutoBlocks,
ZImageModularPipeline,

View File

@@ -45,7 +45,16 @@ else:
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
_import_structure["wan"] = [
"WanBlocks",
"Wan22Blocks",
"WanImage2VideoAutoBlocks",
"Wan22Image2VideoBlocks",
"WanModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"Wan22Image2VideoModularPipeline",
]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
@@ -58,6 +67,7 @@ else:
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -88,6 +98,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
@@ -112,7 +123,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
from .wan import (
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
import sys

View File

@@ -31,9 +31,7 @@ else:
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
@@ -55,9 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline

View File

@@ -201,37 +201,6 @@ class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
)
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# denoise: Flux Kontext
class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents for Flux Kontext. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
@@ -322,7 +291,7 @@ class FluxKontextAutoInputStep(AutoPipelineBlocks):
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
@@ -331,7 +300,7 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
"Core step that performs the denoising process. \n"
+ " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ " - `FluxDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
@@ -340,7 +309,7 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
@@ -349,7 +318,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
"Core step that performs the denoising process. \n"
+ " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ " - `FluxKontextDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."

View File

@@ -55,7 +55,11 @@ else:
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
_import_structure["modular_pipeline"] = [
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -101,7 +105,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -59,46 +57,35 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
class Flux2KleinModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein.
A ModularPipeline for Flux2-Klein (distilled model).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
default_blocks_name = "Flux2KleinAutoBlocks"
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein (base model).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
@property
def requires_unconditional_embeds(self):

View File

@@ -156,6 +156,12 @@ MELLON_PARAM_TEMPLATES = {
"display": "slider",
"required_block_params": ["layers"],
},
"output_type": {
"label": "Output Type",
"type": "dropdown",
"default": "np",
"options": ["np", "pil", "pt"],
},
# ControlNet
"controlnet_conditioning_scale": {
"label": "Controlnet Conditioning Scale",

View File

@@ -54,19 +54,61 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# map regular pipeline to modular pipeline class name
def _create_default_map_fn(pipeline_class_name: str):
"""Create a mapping function that always returns the same pipeline class."""
def _map_fn(config_dict=None):
return pipeline_class_name
return _map_fn
def _flux2_klein_map_fn(config_dict=None):
if config_dict is None:
return "Flux2KleinModularPipeline"
if "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinModularPipeline"
else:
return "Flux2KleinBaseModularPipeline"
def _wan_map_fn(config_dict=None):
if config_dict is None:
return "WanModularPipeline"
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22ModularPipeline"
else:
return "WanModularPipeline"
def _wan_i2v_map_fn(config_dict=None):
if config_dict is None:
return "WanImage2VideoModularPipeline"
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22Image2VideoModularPipeline"
else:
return "WanImage2VideoModularPipeline"
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
("z-image", "ZImageModularPipeline"),
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
("wan", _wan_map_fn),
("wan-i2v", _wan_i2v_map_fn),
("flux", _create_default_map_fn("FluxModularPipeline")),
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
("flux2-klein", _flux2_klein_map_fn),
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
("z-image", _create_default_map_fn("ZImageModularPipeline")),
]
)
@@ -368,7 +410,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
"""
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn()
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
@@ -1547,7 +1590,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if modular_config_dict is not None:
blocks_class_name = modular_config_dict.get("_blocks_class_name")
else:
blocks_class_name = self.get_default_blocks_name(config_dict)
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1555,11 +1598,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
self.blocks = blocks
self._blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
@@ -1606,7 +1649,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
@property
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1615,13 +1660,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self.blocks.inputs:
for input_param in self._blocks.inputs:
params[input_param.name] = input_param.default
return params
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
return self.default_blocks_name
@classmethod
def _load_pipeline_config(
cls,
@@ -1717,7 +1759,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
logger.debug(" try to determine the modular pipeline class from model_index.json")
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
model_name = _get_model(standard_pipeline_class.__name__)
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn(config_dict)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
else:
@@ -1788,7 +1831,15 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Returns:
- The docstring of the pipeline blocks
"""
return self.blocks.doc
return self._blocks.doc
@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)
def register_components(self, **kwargs):
"""
@@ -2524,7 +2575,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
@@ -2578,7 +2629,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
for expected_input_param in self.blocks.inputs:
for expected_input_param in self._blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
@@ -2597,9 +2648,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Run the pipeline
with torch.no_grad():
try:
_, state = self.blocks(self, state)
_, state = self._blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise

View File

@@ -21,16 +21,16 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanAutoImageEncoderStep",
"WanAutoVaeImageEncoderStep",
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
_import_structure["modular_pipeline"] = [
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -39,16 +39,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .decoders import WanImageVaeDecoderStep
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
Wan22AutoBlocks,
WanAutoBlocks,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
from .modular_blocks_wan import WanBlocks
from .modular_blocks_wan22 import Wan22Blocks
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
from .modular_pipeline import (
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .modular_pipeline import WanModularPipeline
else:
import sys

View File

@@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
def __init__(
self,
image_latent_inputs: List[str] = ["first_frame_latents"],
image_latent_inputs: List[str] = ["image_condition_latents"],
additional_batch_inputs: List[str] = [],
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
@@ -294,20 +294,16 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
a single string or list of strings. Defaults to ["first_frame_latents"].
a single string or list of strings. Defaults to ["image_condition_latents"].
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to [].
Examples:
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
# Configure to process multiple image latent inputs
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
# Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to
process image latents and additional batch inputs WanAdditionalInputsStep(
image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"]
)
"""
if not isinstance(image_latent_inputs, list):
@@ -557,81 +553,3 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.first_last_frame_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanImageVaeDecoderStep(ModularPipelineBlocks):
class WanVaeDecoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -56,7 +56,10 @@ class WanImageVaeDecoderStep(ModularPipelineBlocks):
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
)
),
InputParam(
"output_type", default="np", type_hint=str, description="The output type of the decoded videos"
),
]
@property
@@ -87,7 +90,8 @@ class WanImageVaeDecoderStep(ModularPipelineBlocks):
latents = latents.to(vae_dtype)
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")
output_type = getattr(block_state, "output_type", "np")
block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type=output_type)
self.set_block_state(state, block_state)

View File

@@ -89,52 +89,10 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_frame_latents",
"image_condition_latents",
required=True,
type_hint=torch.Tensor,
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
block_state.dtype
)
return components, block_state
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_last_frame_latents",
required=True,
type_hint=torch.Tensor,
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
),
InputParam(
"dtype",
@@ -147,7 +105,7 @@ class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat(
[block_state.latents, block_state.first_last_frame_latents], dim=1
[block_state.latents, block_state.image_condition_latents], dim=1
).to(block_state.dtype)
return components, block_state
@@ -584,29 +542,3 @@ class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
" - `WanLoopAfterDenoiser`\n"
"This block supports image-to-video tasks for Wan2.2."
)
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanFLF2VLoopBeforeDenoiser,
WanLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_image": "image_embeds",
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanFLF2VLoopBeforeDenoiser`\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports FLF2V tasks for wan2.1."
)

View File

@@ -468,7 +468,7 @@ class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
return components, state
class WanVaeImageEncoderStep(ModularPipelineBlocks):
class WanVaeEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -493,7 +493,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames"),
InputParam("num_frames", type_hint=int, default=81),
InputParam("generator"),
]
@@ -564,7 +564,51 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
return components, state
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -590,7 +634,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames"),
InputParam("num_frames", type_hint=int, default=81),
InputParam("generator"),
]
@@ -667,3 +711,49 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int, required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.image_condition_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,474 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanImageVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
Wan22Image2VideoDenoiseStep,
WanDenoiseStep,
WanFLF2VDenoiseStep,
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeImageEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanTextEncoderStep,
WanVaeImageEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# wan2.1
# wan2.1: text2vid
class WanCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanDenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: image2video
## image encoder
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@property
def description(self):
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
## vae encoder
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
block_names = ["image_resize", "vae_encoder"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
## denoise
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
WanImage2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: FLF2v
## image encoder
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "image_encoder"]
@property
def description(self):
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
## vae encoder
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
@property
def description(self):
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
## denoise
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanFLF2VDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_last_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: auto blocks
## image encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Image Encoder step that encode the image to generate the image embeddings"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Vae Image Encoder step that encode the image to generate the image latents"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
## denoise
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanFLF2VCoreDenoiseStep,
WanImage2VideoCoreDenoiseStep,
WanCoreDenoiseStep,
]
block_names = ["flf2v", "image2video", "text2video"]
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2video and image2video tasks."
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
)
# auto pipeline blocks
class WanAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
WanAutoDenoiseStep,
WanImageVaeDecoderStep,
]
block_names = [
"text_encoder",
"image_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
# wan22
# wan2.2: text2vid
## denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22DenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
)
# wan2.2: image2video
## denoise
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
Wan22Image2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
)
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
Wan22Image2VideoCoreDenoiseStep,
Wan22CoreDenoiseStep,
]
block_names = ["image2video", "text2video"]
block_trigger_inputs = ["first_frame_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2video and image2video tasks."
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
)
class Wan22AutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoVaeImageEncoderStep,
Wan22AutoDenoiseStep,
WanImageVaeDecoderStep,
]
block_names = [
"text_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
# presets for wan2.1 and wan2.2
# YiYi Notes: should we move these to doc?
# wan2.1
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", WanDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
IMAGE2VIDEO_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("image_encoder", WanImage2VideoImageEncoderStep),
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
("denoise", WanImage2VideoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
FLF2V_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("last_image_resize", WanImageCropResizeStep),
("image_encoder", WanFLF2VImageEncoderStep),
("vae_encoder", WanFLF2VVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
("denoise", WanFLF2VDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("image_encoder", WanAutoImageEncoderStep),
("vae_encoder", WanAutoVaeImageEncoderStep),
("denoise", WanAutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
# wan2.2 presets
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", Wan22DenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("image_resize", WanImageResizeStep),
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", Wan22DenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
AUTO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("vae_encoder", WanAutoVaeImageEncoderStep),
("denoise", Wan22AutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
# presets all blocks (wan and wan22)
ALL_BLOCKS = {
"wan2.1": {
"text2video": TEXT2VIDEO_BLOCKS,
"image2video": IMAGE2VIDEO_BLOCKS,
"flf2v": FLF2V_BLOCKS,
"auto": AUTO_BLOCKS,
},
"wan2.2": {
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
"auto": AUTO_BLOCKS_WAN22,
},
}

View File

@@ -0,0 +1,83 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
WanDenoiseStep,
)
from .encoders import (
WanTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class WanCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanDenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanDenoiseStep` is used to denoise the latents\n"
)
# ====================
# 2. BLOCKS (Wan2.1 text2video)
# ====================
class WanBlocks(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextEncoderStep,
WanCoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = ["text_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Modular pipeline blocks for Wan2.1.\n"
+ "- `WanTextEncoderStep` is used to encode the text\n"
+ "- `WanCoreDenoiseStep` is used to denoise the latents\n"
+ "- `WanVaeDecoderStep` is used to decode the latents to images"
)

View File

@@ -0,0 +1,88 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
)
from .encoders import (
WanTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22DenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
)
# ====================
# 2. BLOCKS (Wan2.2 text2video)
# ====================
class Wan22Blocks(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextEncoderStep,
Wan22CoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = [
"text_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Modular pipeline for text-to-video using Wan2.2.\n"
+ " - `WanTextEncoderStep` encodes the text\n"
+ " - `Wan22CoreDenoiseStep` denoes the latents\n"
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
)

View File

@@ -0,0 +1,117 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
Wan22Image2VideoDenoiseStep,
)
from .encoders import (
WanImageResizeStep,
WanPrepareFirstFrameLatentsStep,
WanTextEncoderStep,
WanVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. VAE ENCODER
# ====================
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
# ====================
# 2. DENOISE
# ====================
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22Image2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
)
# ====================
# 3. BLOCKS (Wan2.2 Image2Video)
# ====================
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextEncoderStep,
WanImage2VideoVaeEncoderStep,
Wan22Image2VideoCoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = [
"text_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Modular pipeline for image-to-video using Wan2.2.\n"
+ " - `WanTextEncoderStep` encodes the text\n"
+ " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
+ " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n"
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
)

View File

@@ -0,0 +1,203 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanTextEncoderStep,
WanVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. IMAGE ENCODER
# ====================
# wan2.1 I2V (first frame only)
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@property
def description(self):
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
# wan2.1 FLF2V (first and last frame)
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "image_encoder"]
@property
def description(self):
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
# wan2.1 Auto Image Encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
model_name = "wan-i2v"
@property
def description(self):
return (
"Image Encoder step that encode the image to generate the image embeddings"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
# ====================
# 2. VAE ENCODER
# ====================
# wan2.1 I2V (first frame only)
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
# wan2.1 FLF2V (first and last frame)
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanImageResizeStep,
WanImageCropResizeStep,
WanFirstLastFrameVaeEncoderStep,
WanPrepareFirstLastFrameLatentsStep,
]
block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"]
@property
def description(self):
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
# wan2.1 Auto Vae Encoder
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Vae Image Encoder step that encode the image to generate the image latents"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
# ====================
# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise)
# ====================
# wan2.1 I2V core denoise (support both I2V and FLF2V)
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanImage2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# ====================
# 4. BLOCKS (Wan2.1 Image2Video)
# ====================
# wan2.1 Image2Video Auto Blocks
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextEncoderStep,
WanAutoImageEncoderStep,
WanAutoVaeEncoderStep,
WanImage2VideoCoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = [
"text_encoder",
"image_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for image-to-video using Wan.\n"
+ "- for I2V workflow, all you need to provide is `image`"
+ "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
)

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
@@ -30,19 +28,12 @@ class WanModularPipeline(
WanLoraLoaderMixin,
):
"""
A ModularPipeline for Wan.
A ModularPipeline for Wan2.1 text2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanAutoBlocks"
# override the default_blocks_name in base class, which is just return self.default_blocks_name
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22AutoBlocks"
else:
return "WanAutoBlocks"
default_blocks_name = "WanBlocks"
@property
def default_height(self):
@@ -118,3 +109,33 @@ class WanModularPipeline(
if hasattr(self, "scheduler") and self.scheduler is not None:
num_train_timesteps = self.scheduler.config.num_train_timesteps
return num_train_timesteps
class WanImage2VideoModularPipeline(WanModularPipeline):
"""
A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanImage2VideoAutoBlocks"
class Wan22ModularPipeline(WanModularPipeline):
"""
A ModularPipeline for Wan2.2 text2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Wan22Blocks"
class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
"""
A ModularPipeline for Wan2.2 image2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Wan22Image2VideoBlocks"

View File

@@ -248,7 +248,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanImageToVideoPipeline),
("wan-i2v", WanImageToVideoPipeline),
]
)

View File

@@ -13,12 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from fractions import Fraction
from typing import Optional
from itertools import chain
from typing import List, Optional, Union
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
from ...utils import is_av_available
from ...utils import get_logger, is_av_available
logger = get_logger(__name__) # pylint: disable=invalid-name
_CAN_USE_AV = is_av_available()
@@ -101,11 +109,59 @@ def _write_audio(
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
fps: int,
audio: Optional[torch.Tensor],
audio_sample_rate: Optional[int],
output_path: str,
video_chunks_number: int = 1,
) -> None:
video_np = video.cpu().numpy()
"""
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
_, height, width, _ = video_np.shape
Args:
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
usually return with `output_type="np"`).
fps (`int`)
The frames per second (FPS) of the encoded video.
audio (`torch.Tensor`, *optional*):
An audio waveform of shape [audio_channels, samples].
audio_sample_rate: (`int`, *optional*):
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
output_path (`str`):
The path to save the encoded video to.
video_chunks_number (`int`, *optional*, defaults to `1`):
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
number of chunks to use often depends on the tiling config for the video VAE.
"""
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
# Pipeline output_type="pil"; assumes each image is in "RGB" mode
video_frames = [np.array(frame) for frame in video]
video = np.stack(video_frames, axis=0)
video = torch.from_numpy(video)
elif isinstance(video, np.ndarray):
# Pipeline output_type="np"
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
if np.all(is_denormalized):
video = (video * 255).round().astype("uint8")
else:
logger.warning(
"Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "
"values in [0, ..., 255] and will be used as is."
)
video = torch.from_numpy(video)
if isinstance(video, torch.Tensor):
# Split into video_chunks_number along the frame dimension
video = torch.tensor_split(video, video_chunks_number, dim=0)
video = iter(video)
first_chunk = next(video)
_, height, width, _ = first_chunk.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
@@ -119,10 +175,12 @@ def encode_video(
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
video_chunk_cpu = video_chunk.to("cpu").numpy()
for frame_array in video_chunk_cpu:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():

View File

@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -635,10 +635,12 @@ class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
control_image_input = control_image.repeat(2, 1, 1, 1, 1)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
control_image_input = control_image
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
@@ -647,7 +649,7 @@ class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
control_image,
control_image_input,
conditioning_scale=controlnet_conditioning_scale,
)

View File

@@ -657,10 +657,12 @@ class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
control_image_input = control_image.repeat(2, 1, 1, 1, 1)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
control_image_input = control_image
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
@@ -669,7 +671,7 @@ class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
control_image,
control_image_input,
conditioning_scale=controlnet_conditioning_scale,
)

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union
from typing import Callable, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -51,7 +51,14 @@ class DPMSolverSDESchedulerOutput(BaseOutput):
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
def __init__(
self,
x: torch.Tensor,
t0: float,
t1: float,
seed: Optional[Union[int, List[int]]] = None,
**kwargs,
):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
if seed is None:
@@ -79,10 +86,23 @@ class BatchedBrownianTree:
]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def sort(a: float, b: float) -> Tuple[float, float, float]:
"""
Sorts two float values and returns them along with a sign indicating if they were swapped.
def __call__(self, t0, t1):
Args:
a (`float`):
The first value.
b (`float`):
The second value.
Returns:
`Tuple[float, float, float]`:
A tuple containing the sorted values (min, max) and a sign (1.0 if a < b, -1.0 otherwise).
"""
return (a, b, 1.0) if a < b else (b, a, -1.0)
def __call__(self, t0: float, t1: float) -> torch.Tensor:
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
@@ -92,23 +112,29 @@ class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
x (`torch.Tensor`): The tensor whose shape, device and dtype is used to generate random samples.
sigma_min (`float`): The low end of the valid interval.
sigma_max (`float`): The high end of the valid interval.
seed (`int` or `List[int]`): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
with its own seed.
transform (callable): A function that maps sigma to the sampler's
transform (`callable`): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
def __init__(
self,
x: torch.Tensor,
sigma_min: float,
sigma_max: float,
seed: Optional[Union[int, List[int]]] = None,
transform: Callable[[float], float] = lambda x: x,
):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
def __call__(self, sigma: float, sigma_next: float) -> torch.Tensor:
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
@@ -216,19 +242,28 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -238,7 +273,15 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -305,7 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self._step_index = self._begin_index
@property
def init_noise_sigma(self):
def init_noise_sigma(self) -> torch.Tensor:
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
@@ -313,21 +356,21 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
def step_index(self):
def step_index(self) -> Union[int, None]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> Union[int, None]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -369,7 +412,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -378,6 +421,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
num_train_timesteps (`int`, *optional*):
The number of train timesteps. If `None`, uses `self.config.num_train_timesteps`.
"""
self.num_inference_steps = num_inference_steps
@@ -443,7 +488,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None
def _second_order_timesteps(self, sigmas, log_sigmas):
def _second_order_timesteps(self, sigmas: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
def sigma_fn(_t):
return np.exp(-_t)
@@ -459,7 +504,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return timesteps
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -604,14 +649,14 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return sigmas
@property
def state_in_first_order(self):
def state_in_first_order(self) -> bool:
return self.sample is None
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
sample: torch.Tensor,
return_dict: bool = True,
s_noise: float = 1.0,
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
@@ -620,11 +665,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor` or `np.ndarray`):
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor` or `np.ndarray`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
@@ -643,7 +688,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
# Create a noise sampler if it hasn't been created yet
if self.noise_sampler is None:
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
self.noise_sampler = BrownianTreeNoiseSampler(
sample, min_sigma.item(), max_sigma.item(), self.noise_sampler_seed
)
# Define functions to compute sigma and t from each other
def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
@@ -694,7 +741,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
sigma_from = sigma_fn(t)
sigma_to = sigma_fn(t_next)
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
sigma_up = min(
sigma_to,
(sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
ancestral_t = t_fn(sigma_down)
prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
@@ -771,5 +821,5 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -47,6 +47,21 @@ class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -287,7 +302,7 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Wan22AutoBlocks(metaclass=DummyObject):
class Wan22Blocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -302,7 +317,82 @@ class Wan22AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class WanAutoBlocks(metaclass=DummyObject):
class Wan22Image2VideoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Wan22ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImage2VideoAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImage2VideoModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):

View File

@@ -37,6 +37,9 @@ class ModularPipelineTesterMixin:
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(["generator"])
# Output type for the pipeline (e.g., "images" for image pipelines, "videos" for video pipelines)
# Subclasses can override this to change the expected output type
output_name = "images"
def get_generator(self, seed=0):
generator = torch.Generator("cpu").manual_seed(seed)
@@ -163,7 +166,7 @@ class ModularPipelineTesterMixin:
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
output = pipe(**batched_input, output=self.output_name)
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
@@ -197,12 +200,16 @@ class ModularPipelineTesterMixin:
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
output = pipe(**inputs, output=self.output_name)
output_batch = pipe(**batched_inputs, output=self.output_name)
assert output_batch.shape[0] == batch_size
max_diff = torch.abs(output_batch[0] - output[0]).max()
# For batch comparison, we only need to compare the first item
if output_batch.shape[0] == batch_size and output.shape[0] == 1:
output_batch = output_batch[0:1]
max_diff = torch.abs(output_batch - output).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@require_accelerator
@@ -217,19 +224,32 @@ class ModularPipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
output = pipe(**inputs, output=self.output_name)
fp16_inputs = self.get_dummy_inputs()
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
output = output.cpu()
output_fp16 = output_fp16.cpu()
output_fp16 = pipe_fp16(**fp16_inputs, output=self.output_name)
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
output_tensor = output.float().cpu()
output_fp16_tensor = output_fp16.float().cpu()
# Check for NaNs in outputs (can happen with tiny models in FP16)
if torch.isnan(output_tensor).any() or torch.isnan(output_fp16_tensor).any():
pytest.skip("FP16 inference produces NaN values - this is a known issue with tiny models")
max_diff = numpy_cosine_similarity_distance(
output_tensor.flatten().numpy(), output_fp16_tensor.flatten().numpy()
)
# Check if cosine similarity is NaN (which can happen if vectors are zero or very small)
if torch.isnan(torch.tensor(max_diff)):
pytest.skip("Cosine similarity is NaN - outputs may be too small for reliable comparison")
assert max_diff < expected_max_diff, f"FP16 inference is different from FP32 inference (max_diff: {max_diff})"
@require_accelerator
def test_to_device(self):
@@ -251,14 +271,16 @@ class ModularPipelineTesterMixin:
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline().to("cpu")
output = pipe(**self.get_dummy_inputs(), output="images")
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline().to(torch_device)
output = pipe(**self.get_dummy_inputs(), output="images")
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
@@ -278,7 +300,7 @@ class ModularPipelineTesterMixin:
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output=self.output_name)
assert images.shape[0] == batch_size * num_images_per_prompt
@@ -293,8 +315,7 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in [base_pipe, offload_pipe]:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image = pipe(**inputs, output=self.output_name)
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -315,8 +336,7 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image = pipe(**inputs, output=self.output_name)
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -331,13 +351,13 @@ class ModularGuiderTesterMixin:
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output="images")
out_no_cfg = pipe(**inputs, output=self.output_name)
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output="images")
out_cfg = pipe(**inputs, output=self.output_name)
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()

View File

View File

@@ -0,0 +1,49 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from diffusers.modular_pipelines import WanBlocks, WanModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestWanModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = WanModularPipeline
pipeline_blocks_class = WanBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-wan-modular-pipe"
params = frozenset(["prompt", "height", "width", "num_frames"])
batch_params = frozenset(["prompt"])
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
output_name = "videos"
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 16,
"width": 16,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
@pytest.mark.skip(reason="num_videos_per_prompt")
def test_num_images_per_prompt(self):
pass

View File

@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import ZImageAutoBlocks, ZImageModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestZImageModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = ZImageModularPipeline
pipeline_blocks_class = ZImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-zimage-modular-pipe"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-3)