Compare commits

..

9 Commits

Author SHA1 Message Date
Sayak Paul
8cd0f86ad8 Merge branch 'main' into enable-cp-kernels 2026-02-09 16:28:41 +05:30
Sayak Paul
dab372dd27 Merge branch 'main' into enable-cp-kernels 2026-01-26 21:58:00 +08:00
Sayak Paul
79438572e0 Merge branch 'main' into enable-cp-kernels 2026-01-19 10:28:00 +05:30
sayakpaul
2268583f39 up 2026-01-11 20:05:26 +05:30
Sayak Paul
dfbd4857b2 Merge branch 'main' into enable-cp-kernels 2025-12-17 12:14:40 +08:00
Sayak Paul
9bd83616bf Merge branch 'main' into enable-cp-kernels 2025-12-10 12:33:18 +08:00
sayakpaul
f732ff1144 up 2025-12-09 15:30:33 +05:30
sayakpaul
7a8f85b047 up 2025-12-09 14:59:01 +05:30
sayakpaul
82d20e64a5 up 2025-12-09 14:39:07 +05:30
11 changed files with 454 additions and 2433 deletions

View File

@@ -1,347 +0,0 @@
# DreamBooth training example for Z-Image
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
The `train_dreambooth_lora_z_image.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).
> [!NOTE]
> **About Z-Image**
>
> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters.
> [!NOTE]
> **Memory consumption**
>
> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_z_image.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
## Memory Optimizations
> [!NOTE]
> Many of these techniques complement each other and can be used together to further reduce memory consumption. However some techniques may be mutually exclusive so be sure to check before launching a training run.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the VAE, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
Enable FP8 training by passing `--do_fp8_training`.
> [!IMPORTANT]
> Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient_accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the amount of backward/update passes and hence also memory requirements.
* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 1024, but it's good to keep in mind in case you're training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
## Training Examples
### Z-Image Training
To perform DreamBooth with LoRA on Z-Image, run:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases.
### Training with FP8 Quantization
For reduced memory usage with FP8 training:
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-fp8"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
```yaml
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
To use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set `--learning_rate=1.0`
```bash
export MODEL_NAME="Tongyi-MAI/Z-Image"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-z-image-lora-prodigy"
accelerate launch train_dreambooth_lora_z_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=5.0 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1.0 \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha:
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
**lora_alpha vs. rank:**
This ratio dictates the LoRA's effective strength:
- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). With this change, we may also want to explore applying LoRA training onto different types of layers and blocks.
To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide:
- For attention only layers: `--lora_layers="to_k,to_q,to_v,to_out.0"`
- For attention and feed-forward layers: `--lora_layers="to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2"`
> [!NOTE]
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string.
> [!NOTE]
> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
### Aspect Ratio Bucketing
We've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
```bash
--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
```
### Bilingual Prompts
Z-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8:
```bash
--instance_prompt="一只sks狗的照片"
--validation_prompt="一只sks狗在桶里的照片"
```
> [!TIP]
> Z-Image excels at text rendering in generated images, especially for Chinese characters. If your use case involves generating images with text, consider including text-related examples in your training data.
## Inference
Once you have trained a LoRA, you can load it for inference:
```python
import torch
from diffusers import ZImagePipeline
pipe = ZImagePipeline.from_pretrained("Tongyi-MAI/Z-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load your trained LoRA
pipe.load_lora_weights("path/to/your/trained-z-image-lora")
# Generate an image
image = pipe(
prompt="A photo of sks dog in a bucket",
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("output.png")
```
---
Since Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

File diff suppressed because it is too large Load Diff

View File

@@ -297,8 +297,6 @@ else:
"ComponentSpec",
"ModularPipeline",
"ModularPipelineBlocks",
"InputParam",
"OutputParam",
]
)
_import_structure["optimization"] = [
@@ -1062,7 +1060,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageTransformer2DModel,
attention_backend,
)
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks, InputParam, OutputParam
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,

View File

@@ -264,6 +264,10 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -278,7 +282,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -603,22 +611,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1069,6 +1094,231 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1107,6 +1357,46 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1940,7 +2230,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1958,17 +2248,35 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -2115,7 +2423,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -2130,33 +2438,68 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2787,7 +3130,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2815,6 +3158,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -31,7 +31,9 @@ else:
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
@@ -53,7 +55,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline

View File

@@ -201,6 +201,37 @@ class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
)
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# denoise: Flux Kontext
class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents for Flux Kontext. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
@@ -291,7 +322,7 @@ class FluxKontextAutoInputStep(AutoPipelineBlocks):
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep]
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
@@ -300,7 +331,7 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
"Core step that performs the denoising process. \n"
+ " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
@@ -309,7 +340,7 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep]
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
@@ -318,7 +349,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
"Core step that performs the denoising process. \n"
+ " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxKontextDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."

View File

@@ -37,9 +37,6 @@ class ModularPipelineTesterMixin:
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(["generator"])
# Output type for the pipeline (e.g., "images" for image pipelines, "videos" for video pipelines)
# Subclasses can override this to change the expected output type
output_name = "images"
def get_generator(self, seed=0):
generator = torch.Generator("cpu").manual_seed(seed)
@@ -166,7 +163,7 @@ class ModularPipelineTesterMixin:
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output=self.output_name)
output = pipe(**batched_input, output="images")
assert len(output) == batch_size, "Output is different from expected batch size"
def test_inference_batch_single_identical(
@@ -200,16 +197,12 @@ class ModularPipelineTesterMixin:
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
output = pipe(**inputs, output=self.output_name)
output_batch = pipe(**batched_inputs, output=self.output_name)
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
# For batch comparison, we only need to compare the first item
if output_batch.shape[0] == batch_size and output.shape[0] == 1:
output_batch = output_batch[0:1]
max_diff = torch.abs(output_batch - output).max()
max_diff = torch.abs(output_batch[0] - output[0]).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
@require_accelerator
@@ -224,32 +217,19 @@ class ModularPipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output=self.output_name)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs()
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
output_fp16 = pipe_fp16(**fp16_inputs, output=self.output_name)
output = output.cpu()
output_fp16 = output_fp16.cpu()
output_tensor = output.float().cpu()
output_fp16_tensor = output_fp16.float().cpu()
# Check for NaNs in outputs (can happen with tiny models in FP16)
if torch.isnan(output_tensor).any() or torch.isnan(output_fp16_tensor).any():
pytest.skip("FP16 inference produces NaN values - this is a known issue with tiny models")
max_diff = numpy_cosine_similarity_distance(
output_tensor.flatten().numpy(), output_fp16_tensor.flatten().numpy()
)
# Check if cosine similarity is NaN (which can happen if vectors are zero or very small)
if torch.isnan(torch.tensor(max_diff)):
pytest.skip("Cosine similarity is NaN - outputs may be too small for reliable comparison")
assert max_diff < expected_max_diff, f"FP16 inference is different from FP32 inference (max_diff: {max_diff})"
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
@require_accelerator
def test_to_device(self):
@@ -271,16 +251,14 @@ class ModularPipelineTesterMixin:
def test_inference_is_not_nan_cpu(self):
pipe = self.get_pipeline().to("cpu")
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
output = pipe(**self.get_dummy_inputs(), output="images")
assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
pipe = self.get_pipeline().to(torch_device)
inputs = self.get_dummy_inputs()
output = pipe(**inputs, output=self.output_name)
output = pipe(**self.get_dummy_inputs(), output="images")
assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
@@ -300,7 +278,7 @@ class ModularPipelineTesterMixin:
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output=self.output_name)
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
@@ -315,7 +293,8 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in [base_pipe, offload_pipe]:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output=self.output_name)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -336,7 +315,8 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output=self.output_name)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
@@ -351,13 +331,13 @@ class ModularGuiderTesterMixin:
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output=self.output_name)
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output=self.output_name)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()

View File

@@ -1,49 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from diffusers.modular_pipelines import WanBlocks, WanModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestWanModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = WanModularPipeline
pipeline_blocks_class = WanBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-wan-modular-pipe"
params = frozenset(["prompt", "height", "width", "num_frames"])
batch_params = frozenset(["prompt"])
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
output_name = "videos"
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 16,
"width": 16,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
@pytest.mark.skip(reason="num_videos_per_prompt")
def test_num_images_per_prompt(self):
pass

View File

@@ -1,44 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import ZImageAutoBlocks, ZImageModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestZImageModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = ZImageModularPipeline
pipeline_blocks_class = ZImageAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-zimage-modular-pipe"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=5e-3)