mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-06 00:31:48 +08:00
Compare commits
7 Commits
Ando233-ra
...
modular-sa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db9483f781 | ||
|
|
369e123bde | ||
|
|
20995862b0 | ||
|
|
a9269d2cf2 | ||
|
|
c167fe335e | ||
|
|
e340b52a92 | ||
|
|
71ce634d1e |
100
.claude/CLAUDE.md
Normal file
100
.claude/CLAUDE.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Build, Lint, and Test Commands
|
||||
|
||||
```bash
|
||||
# Install in development mode
|
||||
pip install -e ".[dev]"
|
||||
|
||||
# Run full test suite (requires beefy machine)
|
||||
make test
|
||||
# Or directly:
|
||||
python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
|
||||
# Run a single test file
|
||||
python -m pytest tests/<TEST_FILE>.py
|
||||
|
||||
# Run slow tests (downloads many GBs of models)
|
||||
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
|
||||
# Format code (ruff + doc-builder)
|
||||
make style
|
||||
|
||||
# Check code quality without modifying
|
||||
make quality
|
||||
|
||||
# Fast fixup for modified files only (recommended before commits)
|
||||
make fixup
|
||||
|
||||
# Fix copied code snippets and dummy objects
|
||||
make fix-copies
|
||||
|
||||
# Check repository consistency (dummies, inits, repo structure)
|
||||
make repo-consistency
|
||||
```
|
||||
|
||||
## Code Architecture
|
||||
|
||||
Diffusers is built on three core component types that work together:
|
||||
|
||||
### Pipelines (`src/diffusers/pipelines/`)
|
||||
- End-to-end inference workflows combining models and schedulers
|
||||
- Base class: `DiffusionPipeline` (in `pipeline_utils.py`)
|
||||
- Follow **single-file policy**: each pipeline in its own directory
|
||||
- Loaded via `DiffusionPipeline.from_pretrained()` which reads `model_index.json`
|
||||
- Components registered via `register_modules()` become pipeline attributes
|
||||
- ~99 pipeline implementations (Stable Diffusion, SDXL, Flux, etc.)
|
||||
|
||||
### Models (`src/diffusers/models/`)
|
||||
- Configurable neural network architectures extending PyTorch's Module
|
||||
- Base classes: `ModelMixin` + `ConfigMixin` (in `modeling_utils.py`)
|
||||
- **Do NOT follow single-file policy**: use shared building blocks (`attention.py`, `embeddings.py`, `resnet.py`)
|
||||
- Key subdirectories:
|
||||
- `autoencoders/`: VAEs for latent space compression
|
||||
- `unets/`: Diffusion model architectures (UNet2DConditionModel, etc.)
|
||||
- `transformers/`: Transformer-based models (Flux, SD3, etc.)
|
||||
- `controlnets/`: ControlNet variants
|
||||
|
||||
### Schedulers (`src/diffusers/schedulers/`)
|
||||
- Guide denoising process during inference
|
||||
- Base class: `SchedulerMixin` + `ConfigMixin` (in `scheduling_utils.py`)
|
||||
- Follow **single-file policy**: one scheduler per file
|
||||
- Key methods: `set_timesteps()`, `step()`, `timesteps` property
|
||||
- Easily swappable via `ConfigMixin.from_config()`
|
||||
- ~55 scheduler algorithms (DDPM, DDIM, Euler, DPM-Solver, etc.)
|
||||
|
||||
### Supporting Systems
|
||||
|
||||
- **Loaders** (`src/diffusers/loaders/`): Mixins for LoRA, IP-Adapter, textual inversion, single-file loading
|
||||
- **Quantizers** (`src/diffusers/quantizers/`): BitsAndBytes, GGUF, TorchAO, Quanto support
|
||||
- **Hooks** (`src/diffusers/hooks/`): Runtime optimizations (offloading, layer skipping, caching)
|
||||
- **Guiders** (`src/diffusers/guiders/`): Guidance algorithms (CFG, PAG, etc.)
|
||||
|
||||
## Configuration System
|
||||
|
||||
All components use `ConfigMixin` for serialization:
|
||||
- Constructor arguments stored via `register_to_config(**kwargs)`
|
||||
- Instantiate from config: `Component.from_config(config_dict)`
|
||||
- Save/load as JSON files
|
||||
|
||||
## Key Design Principles
|
||||
|
||||
1. **Usability over Performance**: Models load at float32/CPU by default
|
||||
2. **Simple over Easy**: Explicit > implicit; expose complexity rather than hide it
|
||||
3. **Single-file policy**: Pipelines and schedulers are self-contained; models share building blocks
|
||||
4. **Copy-paste over abstraction**: Prefer duplicated code over hasty abstractions for contributor-friendliness
|
||||
|
||||
## Code Style
|
||||
|
||||
- Uses `ruff` for linting and formatting (line length: 119)
|
||||
- Documentation follows [Google style](https://google.github.io/styleguide/pyguide.html)
|
||||
- Use `# Copied from` mechanism for sharing code between similar files
|
||||
- Avoid lambda functions and advanced PyTorch operators for readability
|
||||
|
||||
## Testing
|
||||
|
||||
- Tests use `pytest` with `pytest-xdist` for parallelization
|
||||
- Slow tests gated by `RUN_SLOW=yes` environment variable
|
||||
- Test dependencies: `pip install -e ".[test]"`
|
||||
75
_modular_model_index.json
Normal file
75
_modular_model_index.json
Normal file
@@ -0,0 +1,75 @@
|
||||
{
|
||||
"_blocks_class_name": "SequentialPipelineBlocks",
|
||||
"_class_name": "Flux2ModularPipeline",
|
||||
"_diffusers_version": "0.36.0.dev0",
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
{
|
||||
"repo": "hf-internal-testing/tiny-flux2",
|
||||
"revision": null,
|
||||
"subfolder": "scheduler",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"FlowMatchEulerDiscreteScheduler"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"Mistral3ForConditionalGeneration",
|
||||
{
|
||||
"repo": "hf-internal-testing/tiny-flux2",
|
||||
"revision": null,
|
||||
"subfolder": "text_encoder",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"Mistral3ForConditionalGeneration"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"AutoProcessor",
|
||||
{
|
||||
"repo": "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor",
|
||||
"revision": null,
|
||||
"subfolder": "",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"AutoProcessor"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"transformer": [
|
||||
"diffusers",
|
||||
"Flux2Transformer2DModel",
|
||||
{
|
||||
"repo": "hf-internal-testing/tiny-flux2",
|
||||
"revision": null,
|
||||
"subfolder": "transformer",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"Flux2Transformer2DModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKLFlux2",
|
||||
{
|
||||
"repo": "hf-internal-testing/tiny-flux2",
|
||||
"revision": null,
|
||||
"subfolder": "vae",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"AutoencoderKLFlux2"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
]
|
||||
}
|
||||
239
custom_model_automodel_guide.md
Normal file
239
custom_model_automodel_guide.md
Normal file
@@ -0,0 +1,239 @@
|
||||
# Loading Custom Models with `AutoModel` and `trust_remote_code`
|
||||
|
||||
This guide shows how to create a custom model class that lives outside the `diffusers` library and load it via `AutoModel` with `trust_remote_code=True`.
|
||||
|
||||
## How It Works
|
||||
|
||||
When `AutoModel.from_pretrained()` (or `from_config()`) is called with `trust_remote_code=True`, it:
|
||||
|
||||
1. Loads the `config.json` from the model repository.
|
||||
2. Checks for an `"auto_map"` key in the config that maps `"AutoModel"` to a `"<module_file>.<ClassName>"` reference.
|
||||
3. Downloads the referenced Python module from the repository.
|
||||
4. Dynamically imports and instantiates the class from that module.
|
||||
|
||||
This allows anyone to define and share completely custom model architectures without requiring changes to the `diffusers` library itself.
|
||||
|
||||
## Step 1: Define Your Custom Model
|
||||
|
||||
Create a Python file (e.g., `modeling_my_model.py`) that defines your model class. The class must inherit from `ModelMixin` and `ConfigMixin`, and use the `@register_to_config` decorator on `__init__`.
|
||||
|
||||
```python
|
||||
# modeling_my_model.py
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
```
|
||||
|
||||
Key requirements:
|
||||
|
||||
- **`ModelMixin`** provides `save_pretrained()` / `from_pretrained()` for weight serialization.
|
||||
- **`ConfigMixin`** provides `save_config()` / `from_config()` and the `config.json` machinery.
|
||||
- **`@register_to_config`** automatically captures all `__init__` parameters into `config.json` so the model can be reconstructed from config alone.
|
||||
|
||||
## Step 2: Save the Model Locally
|
||||
|
||||
```python
|
||||
from modeling_my_model import MyCustomModel
|
||||
|
||||
model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3)
|
||||
model.save_pretrained("./my-custom-model")
|
||||
```
|
||||
|
||||
This creates a directory with:
|
||||
|
||||
```
|
||||
my-custom-model/
|
||||
├── config.json
|
||||
└── diffusion_pytorch_model.safetensors
|
||||
```
|
||||
|
||||
The generated `config.json` will look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"_class_name": "MyCustomModel",
|
||||
"_diffusers_version": "0.32.0",
|
||||
"in_channels": 3,
|
||||
"hidden_dim": 128,
|
||||
"out_channels": 3
|
||||
}
|
||||
```
|
||||
|
||||
## Step 3: Add the `auto_map` and Model File to the Repository
|
||||
|
||||
To make `AutoModel` aware of your custom class, you need to:
|
||||
|
||||
1. **Copy `modeling_my_model.py` into the saved model directory.**
|
||||
2. **Add an `"auto_map"` entry to `config.json`** that points `AutoModel` to your class.
|
||||
|
||||
The `auto_map` value format is `"<module_name_without_.py>.<ClassName>"`:
|
||||
|
||||
```json
|
||||
{
|
||||
"_class_name": "MyCustomModel",
|
||||
"_diffusers_version": "0.32.0",
|
||||
"in_channels": 3,
|
||||
"hidden_dim": 128,
|
||||
"out_channels": 3,
|
||||
"auto_map": {
|
||||
"AutoModel": "modeling_my_model.MyCustomModel"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Your final directory structure should be:
|
||||
|
||||
```
|
||||
my-custom-model/
|
||||
├── config.json # with auto_map added
|
||||
├── diffusion_pytorch_model.safetensors
|
||||
└── modeling_my_model.py # your custom model code
|
||||
```
|
||||
|
||||
## Step 4: Load with `AutoModel`
|
||||
|
||||
### From a Local Directory
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True)
|
||||
print(model)
|
||||
```
|
||||
|
||||
### From the Hugging Face Hub
|
||||
|
||||
First, push the model directory to a Hub repository:
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo("your-username/my-custom-model", exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path="./my-custom-model",
|
||||
repo_id="your-username/my-custom-model",
|
||||
)
|
||||
```
|
||||
|
||||
Then load it:
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"your-username/my-custom-model",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Initializing from Config (Random Weights)
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_config("./my-custom-model", trust_remote_code=True)
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch import nn
|
||||
from diffusers import ModelMixin, ConfigMixin, AutoModel
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
|
||||
|
||||
# 1. Define
|
||||
class MyCustomModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# 2. Save
|
||||
model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3)
|
||||
model.save_pretrained("./my-custom-model")
|
||||
|
||||
# 3. Manually add auto_map to config.json and copy modeling file
|
||||
import json, shutil
|
||||
|
||||
config_path = "./my-custom-model/config.json"
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config["auto_map"] = {"AutoModel": "modeling_my_model.MyCustomModel"}
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
shutil.copy("modeling_my_model.py", "./my-custom-model/modeling_my_model.py")
|
||||
|
||||
# 4. Load via AutoModel
|
||||
loaded_model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True)
|
||||
|
||||
# 5. Verify
|
||||
x = torch.randn(1, 3, 32, 32)
|
||||
with torch.no_grad():
|
||||
out_original = model(x)
|
||||
out_loaded = loaded_model(x)
|
||||
|
||||
assert torch.allclose(out_original, out_loaded)
|
||||
print("Models produce identical outputs!")
|
||||
```
|
||||
|
||||
## Using Relative Imports in Custom Code
|
||||
|
||||
If your custom model depends on additional modules, you can use relative imports. For example, if your model uses a custom attention layer defined in a separate file:
|
||||
|
||||
```
|
||||
my-custom-model/
|
||||
├── config.json
|
||||
├── diffusion_pytorch_model.safetensors
|
||||
├── modeling_my_model.py # imports from .my_attention
|
||||
└── my_attention.py # custom attention implementation
|
||||
```
|
||||
|
||||
In `modeling_my_model.py`:
|
||||
|
||||
```python
|
||||
from .my_attention import MyAttention
|
||||
```
|
||||
|
||||
The dynamic module loader will automatically resolve and download all relatively imported files.
|
||||
|
||||
## Security Note
|
||||
|
||||
`trust_remote_code=True` executes arbitrary Python code from the model repository. Only use it with repositories you trust. You can globally disable remote code execution by setting the environment variable:
|
||||
|
||||
```bash
|
||||
export DIFFUSERS_DISABLE_REMOTE_CODE=1
|
||||
```
|
||||
@@ -194,8 +194,6 @@
|
||||
title: Model accelerators and hardware
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/helios
|
||||
title: Helios
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
- local: using-diffusers/sdxl
|
||||
@@ -352,8 +350,6 @@
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/glm_image_transformer2d
|
||||
title: GlmImageTransformer2DModel
|
||||
- local: api/models/helios_transformer3d
|
||||
title: HeliosTransformer3DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
@@ -460,8 +456,6 @@
|
||||
title: AutoencoderKLQwenImage
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/autoencoder_rae
|
||||
title: AutoencoderRAE
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
@@ -631,6 +625,7 @@
|
||||
title: Image-to-image
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
@@ -679,8 +674,6 @@
|
||||
title: ConsisID
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/helios
|
||||
title: Helios
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/hunyuan_video15
|
||||
@@ -752,10 +745,6 @@
|
||||
title: FlowMatchEulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_heun_discrete
|
||||
title: FlowMatchHeunDiscreteScheduler
|
||||
- local: api/schedulers/helios_dmd
|
||||
title: HeliosDMDScheduler
|
||||
- local: api/schedulers/helios
|
||||
title: HeliosScheduler
|
||||
- local: api/schedulers/heun
|
||||
title: HeunDiscreteScheduler
|
||||
- local: api/schedulers/ipndm
|
||||
|
||||
@@ -23,7 +23,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
|
||||
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
|
||||
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
|
||||
- [`HeliosLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).
|
||||
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
|
||||
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
|
||||
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
|
||||
@@ -87,10 +86,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
|
||||
|
||||
## HeliosLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin
|
||||
|
||||
## HunyuanVideoLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoencoderRAE
|
||||
|
||||
The Representation Autoencoder (RAE) model introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx.
|
||||
|
||||
RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).
|
||||
|
||||
The following RAE models are released and supported in Diffusers:
|
||||
|
||||
| Model | Encoder | Latent shape (224px input) |
|
||||
|:------|:--------|:---------------------------|
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
|
||||
|
||||
## Loading a pretrained model
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderRAE
|
||||
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
```
|
||||
|
||||
## Encoding and decoding a real image
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderRAE
|
||||
from diffusers.utils import load_image
|
||||
from torchvision.transforms.functional import to_tensor, to_pil_image
|
||||
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
|
||||
image = image.convert("RGB").resize((224, 224))
|
||||
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]
|
||||
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent # (1, 768, 16, 16)
|
||||
recon = model.decode(latents).sample # (1, 3, 256, 256)
|
||||
|
||||
recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
|
||||
recon_image.save("recon.png")
|
||||
```
|
||||
|
||||
## Latent normalization
|
||||
|
||||
Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.
|
||||
|
||||
```python
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
|
||||
# Latent normalization is handled automatically inside encode/decode
|
||||
# when the checkpoint config includes latents_mean/latents_std.
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent # normalized latents
|
||||
recon = model.decode(latents).sample
|
||||
```
|
||||
|
||||
## AutoencoderRAE
|
||||
|
||||
[[autodoc]] AutoencoderRAE
|
||||
- encode
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -1,35 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HeliosTransformer3DModel
|
||||
|
||||
A 14B Real-Time Autoregressive Diffusion Transformer model (supports T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University, ByteDance, et al.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch

from diffusers import HeliosTransformer3DModel
|
||||
|
||||
# Best Quality
|
||||
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
# Intermediate Weight
|
||||
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
# Best Efficiency
|
||||
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HeliosTransformer3DModel
|
||||
|
||||
[[autodoc]] HeliosTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -1,465 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# Helios
|
||||
|
||||
[Helios: Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University, ByteDance, et al., by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan.
|
||||
|
||||
* <u>We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.</u> We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page).
|
||||
|
||||
The following Helios models are supported in Diffusers:
|
||||
|
||||
- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler.
|
||||
- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler.
|
||||
- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Helios models in the right sidebar for more examples of video generation.
|
||||
|
||||
### Optimizing Memory and Inference Speed
|
||||
|
||||
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
|
||||
|
||||
<hfoptions id="optimization">
|
||||
<hfoption id="memory">
|
||||
|
||||
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
|
||||
|
||||
The Helios model below requires ~19GB of VRAM.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
# group-offloading
|
||||
pipeline = HeliosPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Base",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.enable_group_offload(
|
||||
onload_device=torch.device("cuda"),
|
||||
offload_device=torch.device("cpu"),
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=1,
|
||||
use_stream=True,
|
||||
record_stream=True,
|
||||
)
|
||||
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="inference speed">
|
||||
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Base",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
# attention backend
|
||||
# pipeline.transformer.set_attention_backend("flash")
|
||||
pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs
|
||||
|
||||
# torch.compile
|
||||
torch.backends.cudnn.benchmark = True
|
||||
pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Generation with Helios-Base
|
||||
|
||||
The example below demonstrates how to use Helios-Base to generate video based on text, image or video.
|
||||
|
||||
<hfoptions id="Helios-Base usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPipeline
|
||||
from diffusers.utils import export_to_video, load_video, load_image
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Base",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
# For Text-to-Video
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
|
||||
|
||||
# For Image-to-Video
|
||||
prompt = """
|
||||
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
|
||||
illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
|
||||
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
|
||||
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
|
||||
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
|
||||
respect for nature’s might.
|
||||
"""
|
||||
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=load_image(image_path).resize((640, 384)),
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_i2v_output.mp4", fps=24)
|
||||
|
||||
# For Video-to-Video
|
||||
prompt = """
|
||||
A bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees
|
||||
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
|
||||
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
|
||||
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
|
||||
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
|
||||
"""
|
||||
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
video=load_video(video_path),
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_v2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Generation with Helios-Mid
|
||||
|
||||
The example below demonstrates how to use Helios-Mid to generate video based on text, image or video.
|
||||
|
||||
<hfoptions id="Helios-Mid usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPyramidPipeline
|
||||
from diffusers.utils import export_to_video, load_video, load_image
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPyramidPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Mid",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
# For Text-to-Video
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
pyramid_num_inference_steps_list=[20, 20, 20],
|
||||
guidance_scale=5.0,
|
||||
use_zero_init=True,
|
||||
zero_steps=1,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24)
|
||||
|
||||
# For Image-to-Video
|
||||
prompt = """
|
||||
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
|
||||
illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
|
||||
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
|
||||
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
|
||||
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
|
||||
respect for nature’s might.
|
||||
"""
|
||||
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=load_image(image_path).resize((640, 384)),
|
||||
num_frames=99,
|
||||
pyramid_num_inference_steps_list=[20, 20, 20],
|
||||
guidance_scale=5.0,
|
||||
use_zero_init=True,
|
||||
zero_steps=1,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24)
|
||||
|
||||
# For Video-to-Video
|
||||
prompt = """
|
||||
A bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees
|
||||
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
|
||||
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
|
||||
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
|
||||
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
|
||||
"""
|
||||
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
video=load_video(video_path),
|
||||
num_frames=99,
|
||||
pyramid_num_inference_steps_list=[20, 20, 20],
|
||||
guidance_scale=5.0,
|
||||
use_zero_init=True,
|
||||
zero_steps=1,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Generation with Helios-Distilled
|
||||
|
||||
The example below demonstrates how to use Helios-Distilled to generate video based on text, image or video.
|
||||
|
||||
<hfoptions id="Helios-Distilled usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPyramidPipeline
|
||||
from diffusers.utils import export_to_video, load_video, load_image
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPyramidPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Distilled",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
# For Text-to-Video
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=240,
|
||||
pyramid_num_inference_steps_list=[2, 2, 2],
|
||||
guidance_scale=1.0,
|
||||
is_amplify_first_chunk=True,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24)
|
||||
|
||||
# For Image-to-Video
|
||||
prompt = """
|
||||
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
|
||||
illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
|
||||
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
|
||||
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
|
||||
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
|
||||
respect for nature’s might.
|
||||
"""
|
||||
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=load_image(image_path).resize((640, 384)),
|
||||
num_frames=240,
|
||||
pyramid_num_inference_steps_list=[2, 2, 2],
|
||||
guidance_scale=1.0,
|
||||
is_amplify_first_chunk=True,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24)
|
||||
|
||||
# For Video-to-Video
|
||||
prompt = """
|
||||
A bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees
|
||||
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
|
||||
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
|
||||
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
|
||||
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
|
||||
"""
|
||||
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
video=load_video(video_path),
|
||||
num_frames=240,
|
||||
pyramid_num_inference_steps_list=[2, 2, 2],
|
||||
guidance_scale=1.0,
|
||||
is_amplify_first_chunk=True,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
## HeliosPipeline
|
||||
|
||||
[[autodoc]] HeliosPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HeliosPyramidPipeline
|
||||
|
||||
[[autodoc]] HeliosPyramidPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HeliosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput
|
||||
@@ -1,20 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HeliosScheduler
|
||||
|
||||
`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
|
||||
|
||||
## HeliosScheduler
|
||||
[[autodoc]] HeliosScheduler
|
||||
|
||||
scheduling_helios
|
||||
@@ -1,20 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HeliosDMDScheduler
|
||||
|
||||
`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
|
||||
|
||||
## HeliosDMDScheduler
|
||||
[[autodoc]] HeliosDMDScheduler
|
||||
|
||||
scheduling_helios_dmd
|
||||
@@ -332,49 +332,4 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
|
||||
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Dependencies
|
||||
|
||||
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and returns a warning if a package is missing or incompatible.
|
||||
|
||||
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import PipelineBlock
|
||||
|
||||
class MyCustomBlock(PipelineBlock):
|
||||
_requirements = {
|
||||
"transformers": ">=4.44.0",
|
||||
"sentencepiece": ">=0.2.0"
|
||||
}
|
||||
```
|
||||
|
||||
When there are blocks with different requirements, Diffusers merges their requirements.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
|
||||
class BlockA(PipelineBlock):
|
||||
_requirements = {"transformers": ">=4.44.0"}
|
||||
# ...
|
||||
|
||||
class BlockB(PipelineBlock):
|
||||
_requirements = {"sentencepiece": ">=0.2.0"}
|
||||
# ...
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict({
|
||||
"block_a": BlockA,
|
||||
"block_b": BlockB,
|
||||
})
|
||||
```
|
||||
|
||||
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
|
||||
|
||||
```md
|
||||
# missing package
|
||||
xyz-package was specified in the requirements but wasn't found in the current environment.
|
||||
|
||||
# version mismatch
|
||||
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
|
||||
```
|
||||
</hfoptions>
|
||||
@@ -97,32 +97,5 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
|
||||
> )
|
||||
> ```
|
||||
|
||||
### Saving custom models
|
||||
|
||||
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
|
||||
|
||||
```py
|
||||
# my_model.py
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin):
|
||||
...
|
||||
|
||||
MyCustomModel.register_for_auto_class("AutoModel")
|
||||
|
||||
model = MyCustomModel(...)
|
||||
model.save_pretrained("./my_model")
|
||||
```
|
||||
|
||||
The saved `config.json` will include the `auto_map` field.
|
||||
|
||||
```json
|
||||
{
|
||||
"auto_map": {
|
||||
"AutoModel": "my_model.MyCustomModel"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
|
||||
@@ -60,7 +60,7 @@ export_to_video(video.frames[0], "output.mp4", fps=8)
|
||||
<tr>
|
||||
<th style="text-align: center;">Face Image</th>
|
||||
<th style="text-align: center;">Video</th>
|
||||
<th style="text-align: center;">Description</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true" style="height: auto; width: 600px;"></td>
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
# Helios
|
||||
|
||||
[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are:
|
||||
|
||||
- Without commonly used anti-drifting strategies (e.g., self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence.
|
||||
- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU.
|
||||
- Introducing optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models.
|
||||
|
||||
This guide will walk you through using Helios for use cases.
|
||||
|
||||
## Load Model Checkpoints
|
||||
|
||||
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HeliosPipeline, HeliosPyramidPipeline
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# For Best Quality
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
|
||||
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Intermediate Weight
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# For Best Efficiency
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
## Text-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Image-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Image</th>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Interactive-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Resources
|
||||
|
||||
Learn more about Helios with the following resources.
|
||||
- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
|
||||
- The research paper, [Helios: Real-Time Long Video Generation Model](https://huggingface.co/papers/), for more details.
|
||||
@@ -132,8 +132,6 @@
|
||||
sections:
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
- local: using-diffusers/helios
|
||||
title: Helios
|
||||
|
||||
- title: Resources
|
||||
isExpanded: false
|
||||
|
||||
@@ -26,14 +26,6 @@ http://www.apache.org/licenses/LICENSE-2.0
|
||||
<th>项目名称</th>
|
||||
<th>描述</th>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/PKU-YuanGroup/Helios"> helios </a></td>
|
||||
<td>Helios:比1.3B更低开销、更快且更强的14B的实时长视频生成模型</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/PKU-YuanGroup/ConsisID"> consisid </a></td>
|
||||
<td>ConsisID:零样本身份保持的文本到视频生成模型</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
|
||||
<td>Stable Diffusion内置到Blender</td>
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
# Helios
|
||||
|
||||
[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时,拥有媲美强大基线模型的生成质量,并在统一架构下原生集成了文生视频(T2V)、图生视频(I2V)和视频生视频(V2V)任务。Helios 的主要特性包括:
|
||||
|
||||
- 无需常用的防漂移策略(例如:自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样),我们的模型即可生成高质量且高度连贯的分钟级视频。
|
||||
- 无需标准的加速技术(例如:KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化),作为一款 14B 规模的视频生成模型,我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。
|
||||
- 引入了多项优化方案,在降低显存消耗的同时,显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片(sharding)等基础设施,即可使用与图像模型相当的批大小(batch sizes)来训练 14B 的视频生成模型。
|
||||
|
||||
本指南将引导您完成 Helios 在不同场景下的使用。
|
||||
|
||||
## Load Model Checkpoints
|
||||
|
||||
模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HeliosPipeline, HeliosPyramidPipeline
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# For Best Quality
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
|
||||
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Intermediate Weight
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# For Best Efficiency
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
## Text-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Image-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Image</th>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Interactive-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Resources
|
||||
|
||||
通过以下资源了解有关 Helios 的更多信息:
|
||||
|
||||
- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能;
|
||||
- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。
|
||||
120
example.py
Normal file
120
example.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTesterConfig:
|
||||
model_class = QwenImageTransformer2DModel
|
||||
pretrained_model_name_or_path = ""
|
||||
pretrained_model_kwargs = {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
# __init__ parameters:
|
||||
# patch_size: int = 2
|
||||
# in_channels: int = 64
|
||||
# out_channels: Optional[int] = 16
|
||||
# num_layers: int = 60
|
||||
# attention_head_dim: int = 128
|
||||
# num_attention_heads: int = 24
|
||||
# joint_attention_dim: int = 3584
|
||||
# guidance_embeds: bool = False
|
||||
# axes_dims_rope: Tuple[int, int, int] = <complex>
|
||||
return {}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
# forward() parameters:
|
||||
# hidden_states: torch.Tensor
|
||||
# encoder_hidden_states: torch.Tensor
|
||||
# encoder_hidden_states_mask: torch.Tensor
|
||||
# timestep: torch.LongTensor
|
||||
# img_shapes: Optional[List[Tuple[int, int, int]]]
|
||||
# txt_seq_lens: Optional[List[int]]
|
||||
# guidance: torch.Tensor
|
||||
# attention_kwargs: Optional[Dict[str, Any]]
|
||||
# controlnet_block_samples
|
||||
# return_dict: bool = True
|
||||
# TODO: Fill in dummy inputs
|
||||
return {}
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (1, 1)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (1, 1)
|
||||
|
||||
|
||||
class TestQwenImageTransformerModel(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestQwenImageTransformerTorchCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
# TODO: Implement dynamic input generation
|
||||
return {}
|
||||
|
||||
|
||||
class TestQwenImageTransformerLora(QwenImageTransformerTesterConfig, LoraTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoraHotSwappingForModel(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
# TODO: Implement dynamic input generation
|
||||
return {}
|
||||
@@ -1232,49 +1232,22 @@ def main(args):
|
||||
id_token=args.id_token,
|
||||
)
|
||||
|
||||
def encode_video(video):
|
||||
def encode_video(video, bar):
|
||||
bar.update(1)
|
||||
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
|
||||
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
latent_dist = vae.encode(video).latent_dist
|
||||
return latent_dist
|
||||
|
||||
# Distribute video encoding across processes: each process only encodes its own shard
|
||||
num_videos = len(train_dataset.instance_videos)
|
||||
num_procs = accelerator.num_processes
|
||||
local_rank = accelerator.process_index
|
||||
local_count = len(range(local_rank, num_videos, num_procs))
|
||||
|
||||
progress_encode_bar = tqdm(
|
||||
range(local_count),
|
||||
desc="Encoding videos",
|
||||
disable=not accelerator.is_local_main_process,
|
||||
range(0, len(train_dataset.instance_videos)),
|
||||
desc="Loading Encode videos",
|
||||
)
|
||||
|
||||
encoded_videos = [None] * num_videos
|
||||
for i, video in enumerate(train_dataset.instance_videos):
|
||||
if i % num_procs == local_rank:
|
||||
encoded_videos[i] = encode_video(video)
|
||||
progress_encode_bar.update(1)
|
||||
train_dataset.instance_videos = [
|
||||
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
|
||||
]
|
||||
progress_encode_bar.close()
|
||||
|
||||
# Broadcast encoded latent distributions so every process has the full set
|
||||
if num_procs > 1:
|
||||
import torch.distributed as dist
|
||||
|
||||
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
||||
|
||||
ref_params = next(v for v in encoded_videos if v is not None).parameters
|
||||
for i in range(num_videos):
|
||||
src = i % num_procs
|
||||
if encoded_videos[i] is not None:
|
||||
params = encoded_videos[i].parameters.contiguous()
|
||||
else:
|
||||
params = torch.empty_like(ref_params)
|
||||
dist.broadcast(params, src=src)
|
||||
encoded_videos[i] = DiagonalGaussianDistribution(params)
|
||||
|
||||
train_dataset.instance_videos = encoded_videos
|
||||
|
||||
def collate_fn(examples):
|
||||
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
# Training AutoencoderRAE
|
||||
|
||||
This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.
|
||||
|
||||
It follows the same high-level training recipe as the official RAE stage-1 setup:
|
||||
- frozen encoder
|
||||
- train decoder
|
||||
- pixel reconstruction loss
|
||||
- optional encoder feature consistency loss
|
||||
|
||||
## Quickstart
|
||||
|
||||
### Resume or finetune from pretrained weights
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
|
||||
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
|
||||
--train_data_dir /path/to/imagenet_like_folder \
|
||||
--output_dir /tmp/autoencoder-rae \
|
||||
--resolution 256 \
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
```
|
||||
|
||||
### Train from scratch with a pretrained encoder
|
||||
The following command launches RAE training with "facebook/dinov2-with-registers-base" as the base.
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
|
||||
--train_data_dir /path/to/imagenet_like_folder \
|
||||
--output_dir /tmp/autoencoder-rae \
|
||||
--resolution 256 \
|
||||
--encoder_type dinov2 \
|
||||
--encoder_name_or_path facebook/dinov2-with-registers-base \
|
||||
--encoder_input_size 224 \
|
||||
--patch_size 16 \
|
||||
--image_size 256 \
|
||||
--decoder_hidden_size 1152 \
|
||||
--decoder_num_hidden_layers 28 \
|
||||
--decoder_num_attention_heads 16 \
|
||||
--decoder_intermediate_size 4096 \
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
```
|
||||
|
||||
Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.
|
||||
|
||||
Dataset format is expected to be `ImageFolder`-compatible:
|
||||
|
||||
```text
|
||||
train_data_dir/
|
||||
class_a/
|
||||
img_0001.jpg
|
||||
class_b/
|
||||
img_0002.jpg
|
||||
```
|
||||
@@ -1,405 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import AutoencoderRAE
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.")
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to an ImageFolder-style dataset root.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model."
|
||||
)
|
||||
parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
|
||||
parser.add_argument("--resolution", type=int, default=256)
|
||||
parser.add_argument("--center_crop", action="store_true")
|
||||
parser.add_argument("--random_flip", action="store_true")
|
||||
|
||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=10)
|
||||
parser.add_argument("--max_train_steps", type=int, default=None)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-8)
|
||||
parser.add_argument("--lr_scheduler", type=str, default="cosine")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
||||
|
||||
parser.add_argument("--checkpointing_steps", type=int, default=1000)
|
||||
parser.add_argument("--validation_steps", type=int, default=500)
|
||||
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). "
|
||||
"When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path "
|
||||
"into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2")
|
||||
parser.add_argument("--encoder_hidden_size", type=int, default=768)
|
||||
parser.add_argument("--encoder_patch_size", type=int, default=14)
|
||||
parser.add_argument("--encoder_num_hidden_layers", type=int, default=12)
|
||||
parser.add_argument("--encoder_input_size", type=int, default=224)
|
||||
parser.add_argument("--patch_size", type=int, default=16)
|
||||
parser.add_argument("--image_size", type=int, default=256)
|
||||
parser.add_argument("--num_channels", type=int, default=3)
|
||||
|
||||
parser.add_argument("--decoder_hidden_size", type=int, default=1152)
|
||||
parser.add_argument("--decoder_num_hidden_layers", type=int, default=28)
|
||||
parser.add_argument("--decoder_num_attention_heads", type=int, default=16)
|
||||
parser.add_argument("--decoder_intermediate_size", type=int, default=4096)
|
||||
|
||||
parser.add_argument("--noise_tau", type=float, default=0.0)
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.0)
|
||||
parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--reconstruction_loss_type",
|
||||
type=str,
|
||||
choices=["l1", "mse"],
|
||||
default="l1",
|
||||
help="Pixel reconstruction loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_loss_weight",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Weight for encoder feature consistency loss in the training loop.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_encoder_loss",
|
||||
action="store_true",
|
||||
help="Enable encoder feature consistency loss term in the training loop.",
|
||||
)
|
||||
parser.add_argument("--report_to", type=str, default="tensorboard")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_transforms(args):
|
||||
image_transforms = [
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
]
|
||||
if args.random_flip:
|
||||
image_transforms.append(transforms.RandomHorizontalFlip())
|
||||
image_transforms.append(transforms.ToTensor())
|
||||
return transforms.Compose(image_transforms)
|
||||
|
||||
|
||||
def compute_losses(
|
||||
model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
|
||||
):
|
||||
decoded = model(pixel_values).sample
|
||||
|
||||
if decoded.shape[-2:] != pixel_values.shape[-2:]:
|
||||
raise ValueError(
|
||||
"Training requires matching reconstruction and target sizes, got "
|
||||
f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}."
|
||||
)
|
||||
|
||||
if reconstruction_loss_type == "l1":
|
||||
reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())
|
||||
else:
|
||||
reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())
|
||||
|
||||
encoder_loss = torch.zeros_like(reconstruction_loss)
|
||||
if use_encoder_loss and encoder_loss_weight > 0:
|
||||
base_model = model.module if hasattr(model, "module") else model
|
||||
target_encoder_input = base_model._resize_and_normalize(pixel_values)
|
||||
reconstructed_encoder_input = base_model._resize_and_normalize(decoded)
|
||||
|
||||
encoder_forward_kwargs = {"model": base_model.encoder}
|
||||
if base_model.config.encoder_type == "mae":
|
||||
encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size
|
||||
with torch.no_grad():
|
||||
target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)
|
||||
reconstructed_tokens = base_model._encoder_forward_fn(
|
||||
images=reconstructed_encoder_input, **encoder_forward_kwargs
|
||||
)
|
||||
encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
|
||||
|
||||
loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
|
||||
return decoded, loss, reconstruction_loss, encoder_loss
|
||||
|
||||
|
||||
def _strip_final_layernorm_affine(state_dict, prefix=""):
|
||||
"""Remove final layernorm weight/bias so the model keeps its default init (identity)."""
|
||||
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
|
||||
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
|
||||
|
||||
|
||||
def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):
|
||||
"""Load pretrained HF transformers encoder weights into the model's encoder."""
|
||||
if encoder_type == "dinov2":
|
||||
from transformers import Dinov2WithRegistersModel
|
||||
|
||||
hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
|
||||
state_dict = hf_encoder.state_dict()
|
||||
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
|
||||
elif encoder_type == "siglip2":
|
||||
from transformers import SiglipModel
|
||||
|
||||
hf_encoder = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
|
||||
state_dict = {f"vision_model.{k}": v for k, v in hf_encoder.state_dict().items()}
|
||||
state_dict = _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")
|
||||
elif encoder_type == "mae":
|
||||
from transformers import ViTMAEForPreTraining
|
||||
|
||||
hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
|
||||
state_dict = hf_encoder.state_dict()
|
||||
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder_type: {encoder_type}")
|
||||
|
||||
model.encoder.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
def main():
    """Entry point for stage-1 AutoencoderRAE training.

    Parses CLI arguments, sets up the Accelerate environment, builds the
    dataset/dataloader and model, then runs the reconstruction-training loop
    with periodic validation logging and checkpointing. The encoder is kept
    frozen; only the decoder is optimized.
    """
    args = parse_args()
    # The pixel-space reconstruction loss compares at full resolution, so the
    # data-pipeline resolution must match the model's configured image size.
    if args.resolution != args.image_size:
        raise ValueError(
            f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) "
            "for stage-1 reconstruction loss."
        )

    # --- Accelerate / logging setup -------------------------------------
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        project_config=accelerator_project_config,
        log_with=args.report_to,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Log accelerator state from every process (not just rank 0) to aid
    # debugging of multi-process launches.
    logger.info(accelerator.state, main_process_only=False)

    if args.seed is not None:
        set_seed(args.seed)

    # Only rank 0 creates the output directory; the barrier ensures other
    # ranks don't proceed before it exists.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # --- Data ------------------------------------------------------------
    dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))

    def collate_fn(examples):
        # ImageFolder yields (image, label) tuples; labels are unused here.
        pixel_values = torch.stack([example[0] for example in examples]).float()
        return {"pixel_values": pixel_values}

    train_dataloader = DataLoader(
        dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # --- Model -----------------------------------------------------------
    if args.pretrained_model_name_or_path is not None:
        model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)
        logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}")
    else:
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=args.encoder_hidden_size,
            encoder_patch_size=args.encoder_patch_size,
            encoder_num_hidden_layers=args.encoder_num_hidden_layers,
            decoder_hidden_size=args.decoder_hidden_size,
            decoder_num_hidden_layers=args.decoder_num_hidden_layers,
            decoder_num_attention_heads=args.decoder_num_attention_heads,
            decoder_intermediate_size=args.decoder_intermediate_size,
            patch_size=args.patch_size,
            encoder_input_size=args.encoder_input_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            noise_tau=args.noise_tau,
            reshape_to_2d=args.reshape_to_2d,
            use_encoder_loss=args.use_encoder_loss,
            scaling_factor=args.scaling_factor,
        )
        # Optionally initialize the (frozen) encoder from a pretrained
        # HF checkpoint (dinov2 / siglip2 / mae, handled by the helper).
        if args.encoder_name_or_path is not None:
            _load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)
            logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}")
    # Stage-1 training only updates the decoder; the encoder stays frozen.
    model.encoder.requires_grad_(False)
    model.decoder.requires_grad_(True)
    model.train()

    # --- Optimizer / LR scheduler ----------------------------------------
    # Only pass trainable parameters (the decoder) to the optimizer.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # If --max_train_steps is unset, derive it from epochs; remember that we
    # did so, because `accelerator.prepare` may shard the dataloader and
    # change the number of update steps per epoch.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # Warmup/total steps are scaled by world size, following the convention
    # that the scheduler is stepped once per process per optimizer step.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    # Recompute step counts after `prepare` (the dataloader length can change
    # under distributed sharding), then re-derive epochs from the step budget.
    if overrode_max_train_steps:
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        accelerator.init_trackers("train_autoencoder_rae", config=vars(args))

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    # global_step counts *optimizer* steps (i.e. after gradient accumulation),
    # not dataloader iterations.
    global_step = 0

    # --- Training loop ----------------------------------------------------
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            # `accumulate` handles gradient-accumulation bookkeeping and
            # no-syncs gradients on non-boundary micro-steps.
            with accelerator.accumulate(model):
                pixel_values = batch["pixel_values"]

                _, loss, reconstruction_loss, encoder_loss = compute_losses(
                    model,
                    pixel_values,
                    reconstruction_loss_type=args.reconstruction_loss_type,
                    use_encoder_loss=args.use_encoder_loss,
                    encoder_loss_weight=args.encoder_loss_weight,
                )

                accelerator.backward(loss)
                # Clip only when gradients are synced (i.e. on a real
                # optimizer step); optimizer/scheduler stepping inside
                # `accumulate` is a no-op on accumulation-only micro-steps.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Everything below runs once per completed optimizer step.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                logs = {
                    "loss": loss.detach().item(),
                    "reconstruction_loss": reconstruction_loss.detach().item(),
                    "encoder_loss": encoder_loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)

                # NOTE(review): "validation" re-evaluates the *current
                # training batch* under no_grad (no held-out split) — the
                # metrics mainly confirm the forward pass without gradient
                # noise, not generalization.
                if global_step % args.validation_steps == 0:
                    with torch.no_grad():
                        _, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(
                            model,
                            pixel_values,
                            reconstruction_loss_type=args.reconstruction_loss_type,
                            use_encoder_loss=args.use_encoder_loss,
                            encoder_loss_weight=args.encoder_loss_weight,
                        )
                    accelerator.log(
                        {
                            "val/loss": val_loss.detach().item(),
                            "val/reconstruction_loss": val_reconstruction_loss.detach().item(),
                            "val/encoder_loss": val_encoder_loss.detach().item(),
                        },
                        step=global_step,
                    )

                # Periodic checkpoint (rank 0 only); saves in
                # `save_pretrained` format so it can be reloaded via
                # `--pretrained_model_name_or_path`.
                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(save_path)
                        logger.info(f"Saved checkpoint to {save_path}")

            # Break out of both loops once the step budget is exhausted.
            if global_step >= args.max_train_steps:
                break

        if global_step >= args.max_train_steps:
            break

    # --- Final save -------------------------------------------------------
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir)
    accelerator.end_training()
|
||||
|
||||
|
||||
# Standard entry-point guard: run training only when executed as a script,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()
|
||||
73
modular_model_index.json
Normal file
73
modular_model_index.json
Normal file
@@ -0,0 +1,73 @@
|
||||
{
|
||||
"_blocks_class_name": "SequentialPipelineBlocks",
|
||||
"_class_name": "Flux2ModularPipeline",
|
||||
"_diffusers_version": "0.36.0.dev0",
|
||||
"scheduler": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev",
|
||||
"revision": null,
|
||||
"subfolder": "scheduler",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"FlowMatchEulerDiscreteScheduler"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"text_encoder": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"revision": null,
|
||||
"subfolder": "text_encoder",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"Mistral3ForConditionalGeneration"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"tokenizer": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"revision": null,
|
||||
"subfolder": "tokenizer",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"AutoProcessor"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"transformer": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "diffusers/FLUX.2-dev-bnb-4bit",
|
||||
"revision": null,
|
||||
"subfolder": "transformer",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"Flux2Transformer2DModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
"vae": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev",
|
||||
"revision": null,
|
||||
"subfolder": "vae",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"AutoencoderKLFlux2"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
]
|
||||
}
|
||||
56
pr_review/12498.md
Normal file
56
pr_review/12498.md
Normal file
@@ -0,0 +1,56 @@
|
||||
Now let me look at what `dequantize_gguf_tensor` does to understand the fix better:
|
||||
|
||||
## Code Review: GGUF fix for unquantized types
|
||||
|
||||
### 1. Summary of Changes
|
||||
|
||||
The PR fixes a bug in the `_fused_mul_mat_gguf` function (lines 79-105) where unquantized GGUF tensor types (F32, F16, BF16) were incorrectly handled.
|
||||
|
||||
**Before:** When `qweight_type` was an unquantized type, the code directly performed matrix multiplication: `x @ qweight.T`
|
||||
|
||||
**After:** It now calls `dequantize_gguf_tensor(qweight)` first, then performs the matrix multiplication: `x @ weight.T`
|
||||
|
||||
The issue was that even "unquantized" GGUF tensors are stored in an 8-bit tensor format and need to be converted to their proper data type representation before use.
|
||||
|
||||
### 2. Potential Issues or Bugs
|
||||
|
||||
**None identified.** The fix is correct and addresses a real bug:
|
||||
|
||||
- The `dequantize_gguf_tensor` function (lines 509-527) checks if the tensor has a `quant_type` attribute and handles the appropriate conversion
|
||||
- For BF16 specifically, there's a dedicated `dequantize_blocks_BF16` function (lines 428-429) that properly converts the 8-bit storage format
|
||||
- The fix aligns with how the native path already works in `forward_native` (lines 593-599), which always calls `dequantize_gguf_tensor`
|
||||
|
||||
### 3. Code Quality Observations
|
||||
|
||||
**Strengths:**
|
||||
- The fix is minimal and surgical - only changes what's necessary
|
||||
- Maintains consistency with the `forward_native` path which already uses `dequantize_gguf_tensor`
|
||||
- The variable naming (`weight` instead of reusing `qweight`) makes it clear a transformation occurred
|
||||
|
||||
**Minor observation:**
|
||||
- The comment on line 80 "there is no need to call any kernel for fp16/bf16" is now slightly misleading since we DO need to call dequantization logic. Consider updating it to something like: "no need to call specialized GGUF kernel for fp16/bf16, but still need to dequantize from 8-bit storage"
|
||||
|
||||
### 4. Security Considerations
|
||||
|
||||
**No security concerns.** The change:
|
||||
- Doesn't introduce any external input handling
|
||||
- Doesn't modify control flow in a way that could bypass security checks
|
||||
- Only fixes a data type conversion issue
|
||||
|
||||
### 5. Suggestions for Improvement
|
||||
|
||||
1. **Update the comment** on line 80 in `src/diffusers/quantizers/gguf/utils.py:80`:
|
||||
```python
|
||||
# unquantized types still need dequantization from 8-bit storage, but don't need specialized kernels
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
weight = dequantize_gguf_tensor(qweight)
|
||||
return x @ weight.T
|
||||
```
|
||||
|
||||
2. **Consider adding a test** to prevent regression of this issue. A test should verify that unquantized GGUF tensors produce correct output shapes and values.
|
||||
|
||||
3. **Documentation:** The PR description mentions torch 2.8/2.9 build availability. This might be worth tracking in a GitHub issue if not already done.
|
||||
|
||||
### Verdict
|
||||
|
||||
**Approve with minor comment update suggestion.** The fix correctly addresses a real shape mismatch bug where GGUF's 8-bit storage format wasn't being properly converted for unquantized types. The logic is sound and aligns with the existing native implementation path.
|
||||
186
pr_review/12744.md
Normal file
186
pr_review/12744.md
Normal file
@@ -0,0 +1,186 @@
|
||||
I'll provide a comprehensive code review of this MagCache PR.
|
||||
|
||||
## Summary of Changes
|
||||
|
||||
This PR implements MagCache (Magnitude-aware Cache), a training-free inference acceleration technique for diffusion transformers. The implementation:
|
||||
|
||||
- Adds a `MagCacheConfig` class for configuration
|
||||
- Implements `MagCacheHeadHook` and `MagCacheBlockHook` following the existing ModelHook pattern
|
||||
- Includes calibration mode to compute magnitude ratios for any transformer model
|
||||
- Provides pre-computed `FLUX_MAG_RATIOS` for Flux models
|
||||
- Adds comprehensive documentation and tests
|
||||
|
||||
## Potential Issues and Bugs
|
||||
|
||||
### 1. **Critical: Missing Hook Removal in `disable_cache()`**
|
||||
```python
|
||||
# In cache_utils.py, line ~127
|
||||
elif isinstance(self._cache_config, MagCacheConfig):
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
|
||||
```
|
||||
|
||||
**Issue**: The code only removes the leader/head hook but not the block hooks (`_MAG_CACHE_BLOCK_HOOK`). This will leave hooks attached when disabling the cache.
|
||||
|
||||
**Fix**: Add removal of block hooks:
|
||||
```python
|
||||
elif isinstance(self._cache_config, MagCacheConfig):
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
|
||||
```
|
||||
|
||||
### 2. **Shape Mismatch Handling Logic Issue**
|
||||
In `mag_cache.py` lines 224-248, the shape mismatch handling has a potential issue:
|
||||
|
||||
```python
|
||||
elif (
|
||||
output.ndim == 3
|
||||
and res.ndim == 3
|
||||
and output.shape[0] == res.shape[0]
|
||||
and output.shape[2] == res.shape[2]
|
||||
):
|
||||
diff = output.shape[1] - res.shape[1]
|
||||
if diff > 0:
|
||||
output = output.clone()
|
||||
output[:, diff:, :] = output[:, diff:, :] + res
|
||||
```
|
||||
|
||||
**Issue**: This assumes text tokens come first and image tokens come last. This may not be universal across all models (e.g., some models interleave tokens differently).
|
||||
|
||||
**Suggestion**: Add a comment explaining this assumption or add configuration to specify the concatenation strategy.
|
||||
|
||||
### 3. **Residual Calculation Fallback is Unsafe**
|
||||
In `mag_cache.py` line 343:
|
||||
|
||||
```python
|
||||
else:
|
||||
# Fallback for completely mismatched shapes
|
||||
residual = out_hidden
|
||||
```
|
||||
|
||||
**Issue**: This fallback doesn't compute a residual at all—it just uses the output. This will cause incorrect behavior in subsequent steps.
|
||||
|
||||
**Suggestion**: Either raise an error or add a warning that calibration is required for this model architecture.
|
||||
|
||||
### 4. **Device Mismatch Handling is Incomplete**
|
||||
```python
|
||||
if res.device != output.device:
|
||||
res = res.to(output.device)
|
||||
```
|
||||
|
||||
**Issue**: This only handles device mismatch for the residual, but doesn't handle dtype mismatches which could occur with mixed precision training.
|
||||
|
||||
**Suggestion**: Add dtype handling:
|
||||
```python
|
||||
if res.device != output.device or res.dtype != output.dtype:
|
||||
res = res.to(device=output.device, dtype=output.dtype)
|
||||
```
|
||||
|
||||
### 5. **Calibration Logging Could Be Missed**
|
||||
The calibration results are printed to stdout (line 380) and logged. However, if the user has logging disabled or redirected, they might miss this critical information.
|
||||
|
||||
**Suggestion**: Consider returning calibration results from the pipeline or raising a more visible notification.
|
||||
|
||||
### 6. **Test Suite is Skipped**
|
||||
```python
|
||||
@unittest.skip("MagCache unit tests are skipped.")
|
||||
class MagCacheTests(unittest.TestCase):
|
||||
```
|
||||
|
||||
**Issue**: All unit tests are skipped, which means the core logic isn't being validated in CI.
|
||||
|
||||
**Action Required**: Remove the skip decorator before merging or add a comment explaining why it's temporarily skipped.
|
||||
|
||||
## Code Quality Observations
|
||||
|
||||
### Strengths:
|
||||
1. **Well-structured**: Follows existing patterns (ModelHook, StateManager) consistently
|
||||
2. **Good documentation**: Comprehensive docstrings and inline comments
|
||||
3. **Calibration mode**: Clever design allowing model-agnostic usage
|
||||
4. **Error handling**: Validates configuration upfront
|
||||
5. **Interpolation logic**: Smart handling of different step counts via `nearest_interp()`
|
||||
|
||||
### Areas for Improvement:
|
||||
|
||||
1. **Magic Numbers**: Several hardcoded values could be constants:
|
||||
```python
|
||||
eps = 1e-8 # Line 335 in _perform_calibration_step
|
||||
expected_atol = 0.1 # Line 2989 in test
|
||||
```
|
||||
|
||||
2. **Code Duplication**: The logic for handling tuple returns appears multiple times. Consider extracting to a helper method.
|
||||
|
||||
3. **Type Hints**: Some methods lack return type hints (e.g., `nearest_interp`)
|
||||
|
||||
4. **Compiler Disable Decorator**: The `@torch.compiler.disable` decorator is used but not explained. Add a comment about why compilation is disabled.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Low Risk:
|
||||
- No external network calls
|
||||
- No file system access beyond logging
|
||||
- No execution of arbitrary code
|
||||
- Tensor operations are standard PyTorch
|
||||
|
||||
### Observations:
|
||||
1. **Device Transfer**: The `.to(device)` calls are safe but could consume unexpected memory if tensors are large
|
||||
2. **State Management**: The state is properly isolated and reset between inference runs
|
||||
|
||||
## Suggestions for Improvement
|
||||
|
||||
### 1. Add Configuration Validation
|
||||
```python
|
||||
def __post_init__(self):
|
||||
# Existing checks...
|
||||
|
||||
# Add bounds checking
|
||||
if not 0.0 <= self.retention_ratio <= 1.0:
|
||||
raise ValueError(f"retention_ratio must be in [0, 1], got {self.retention_ratio}")
|
||||
if self.max_skip_steps < 1:
|
||||
raise ValueError(f"max_skip_steps must be >= 1, got {self.max_skip_steps}")
|
||||
if self.threshold <= 0:
|
||||
raise ValueError(f"threshold must be positive, got {self.threshold}")
|
||||
```
|
||||
|
||||
### 2. Add Metrics/Statistics
|
||||
Consider adding optional statistics collection:
|
||||
- How many blocks were skipped per step
|
||||
- Average accumulated error
|
||||
- Total compute savings
|
||||
|
||||
This would help users optimize their thresholds.
|
||||
|
||||
### 3. Improve Documentation Example
|
||||
The documentation example could show expected speedup or quality metrics to set user expectations.
|
||||
|
||||
### 4. Add Gradient Mode Check
|
||||
```python
|
||||
if torch.is_grad_enabled():
|
||||
logger.warning("MagCache is designed for inference only. Gradients are enabled but will not flow correctly through cached blocks.")
|
||||
```
|
||||
|
||||
### 5. Consider Memory Cleanup
|
||||
The `previous_residual` is held in state indefinitely. Consider adding explicit cleanup:
|
||||
```python
|
||||
def cleanup(self):
|
||||
if self.previous_residual is not None:
|
||||
del self.previous_residual
|
||||
self.previous_residual = None
|
||||
```
|
||||
|
||||
## Minor Issues
|
||||
|
||||
1. **Line 26**: The import appears to be unused — either remove it or use it in the logger initialization
|
||||
2. **Line 332**: Comment says "Fallback to matching tail" but logic is unclear
|
||||
3. **Documentation**: The TIP about batched CFG could include more detail about why this works
|
||||
|
||||
## Conclusion
|
||||
|
||||
This is a **well-implemented feature** with good design patterns and documentation. The main concerns are:
|
||||
|
||||
1. **Critical**: Fix the missing block hook removal in `disable_cache()` (Line 127)
|
||||
2. **Important**: Unskip and fix the unit tests
|
||||
3. **Recommended**: Improve shape mismatch handling with better error messages
|
||||
|
||||
The implementation is production-ready once these issues are addressed. The calibration mode is particularly clever and makes this genuinely model-agnostic.
|
||||
|
||||
**Recommendation**: Request changes for items #1 and #2, then approve once fixed.
|
||||
99
pr_review/13028.md
Normal file
99
pr_review/13028.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# PR #13028: [Modular] add explicit workflow support
|
||||
|
||||
**Author:** @yiyixuxu
|
||||
**Branch:** `modular-workflow` -> `main`
|
||||
**Files changed:** `modular_pipeline.py`, `modular_pipeline_utils.py`, `qwenimage/modular_blocks_qwenimage.py`
|
||||
**+298 / -165**
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
This PR adds a `_workflow_map` class attribute to `SequentialPipelineBlocks` that maps named workflows (e.g., `"text2image"`, `"inpainting"`) to their trigger inputs. Users can then call `get_workflow("text2image")` to get the execution blocks for that workflow. The PR also refactors `get_execution_blocks` into `ConditionalPipelineBlocks` and `SequentialPipelineBlocks`, moves `combine_inputs`/`combine_outputs` to module-level functions, and improves docstrings.
|
||||
|
||||
## Main Concern: "Workflow" as a New Concept
|
||||
|
||||
Modular Diffusers already requires users to learn: **Pipelines**, **Blocks** (Sequential, Conditional, Auto, Loop), **Steps**, **Components**, **Inputs/Outputs**, **Trigger Inputs**, **Execution Blocks**, **PipelineState**, and **BlockState**. Adding "workflow" as yet another term increases cognitive overhead.
|
||||
|
||||
The underlying feature is useful — named presets for trigger inputs are genuinely helpful for discoverability. But "workflow" may not be the right label:
|
||||
|
||||
1. **Overloaded term**: "Workflow" is heavily used in the AI/ML ecosystem (ComfyUI workflows, orchestration workflows, CI/CD workflows). Users may expect something more complex than what this is.
|
||||
|
||||
2. **It's really a task/mode, not a workflow**: `"text2image"`, `"inpainting"`, `"image2image"` are *tasks* or *modes*. The rest of diffusers already uses "task" terminology — `AutoPipelineForText2Image`, `AutoPipelineForInpainting`, etc. Calling the same concept "workflow" in Modular Diffusers creates inconsistency.
|
||||
|
||||
3. **It's a thin wrapper**: `get_workflow("text2image")` is just `get_execution_blocks(prompt=True)`. Users still need to understand `get_execution_blocks` and trigger inputs to do anything beyond the predefined workflows. The abstraction doesn't save much complexity.
|
||||
|
||||
**Suggestion**: Consider `_task_map` / `get_task()` / `task_names` to align with existing diffusers terminology, or `_mode_map` / `get_mode()` / `mode_names` for something more neutral. The existing `auto_pipeline.py` already uses "task" internally — `_get_task_class()` maps pipeline class names to task-specific variants (text2image, image2image, inpainting), and the public API follows the `AutoPipelineFor<Task>` naming pattern. These are the exact same concepts this PR calls "workflows." Alternatively, this could simply be better documentation on `get_execution_blocks` with named examples, rather than a new API surface.
|
||||
|
||||
## Code Issues
|
||||
|
||||
### Behavioral change: `outputs` -> `intermediate_outputs` in traversal
|
||||
|
||||
`modular_pipeline.py` — In `SequentialPipelineBlocks.get_execution_blocks`, the old `_traverse_trigger_blocks` tracked `block.outputs` to propagate available values to downstream blocks. The new code tracks `block.intermediate_outputs` instead:
|
||||
|
||||
```python
|
||||
# Old
|
||||
if hasattr(block, "outputs"):
|
||||
for out in block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
# New
|
||||
if hasattr(block, "intermediate_outputs"):
|
||||
for out in block.intermediate_outputs:
|
||||
active_inputs[out.name] = True
|
||||
```
|
||||
|
||||
`intermediate_outputs` and `outputs` can differ — `intermediate_outputs` includes values passed between blocks in the pipeline state, while `outputs` are the final outputs. This could change which downstream conditional blocks get triggered. If this is intentional, it should be called out explicitly in the PR description since it affects existing behavior.
|
||||
|
||||
### `_workflow_map` on base class, implementations only on `SequentialPipelineBlocks`
|
||||
|
||||
`_workflow_map = None` is defined on `ModularPipelineBlocks` (the base class), but `workflow_names` and `get_workflow()` are only implemented on `SequentialPipelineBlocks`. The base class stubs raise `NotImplementedError`. This is misleading — it suggests workflows *could* be implemented for other block types. If workflows are intentionally only for `SequentialPipelineBlocks`, define `_workflow_map` there and don't add stubs to the base class.
|
||||
|
||||
### `get_execution_blocks` no longer filters None values
|
||||
|
||||
Old code:
|
||||
```python
|
||||
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
```
|
||||
|
||||
New code:
|
||||
```python
|
||||
active_inputs = dict(kwargs)
|
||||
```
|
||||
|
||||
This is a behavioral change to the public `get_execution_blocks` API. The old code explicitly stripped `None` values so users could write `get_execution_blocks(prompt="a cat", image=None)` and `image` wouldn't trigger anything. The new code passes `None` through. It happens to still work because `select_block` checks `is not None` internally, but callers can no longer rely on the documented filtering behavior. This should be noted.
|
||||
|
||||
### `default_block_name` changed from property to instance attribute
|
||||
|
||||
In `AutoPipelineBlocks`, `default_block_name` was a `@property` that derived the default from `block_trigger_inputs` on every access. It's now set as an instance attribute in `__init__`. This is mostly fine, but the new code also adds a validation that raises an error if `default_block_name` is already non-`None` before `__init__` assigns it — so subclasses that accidentally set `default_block_name` as a class attribute will now break. This is a stricter contract that should be documented.
|
||||
|
||||
### Typo
|
||||
|
||||
`modular_pipeline.py` — `# currentlyonly ConditionalPipelineBlocks` should be `# currently only`.
|
||||
|
||||
### `_get_trigger_inputs()` called multiple times in `__repr__`
|
||||
|
||||
In `SequentialPipelineBlocks.__repr__`, `self._get_trigger_inputs()` is called 3 times (condition check, trigger inputs display, example input). This recursively traverses all blocks each time. Should be computed once and reused.
|
||||
|
||||
### Duplicate `format_workflow` calls in `__repr__` and `doc`
|
||||
|
||||
Both `SequentialPipelineBlocks.__repr__` and `SequentialPipelineBlocks.doc` build the description + workflow string independently with identical logic:
|
||||
|
||||
```python
|
||||
description = self.description
|
||||
if self._workflow_map is not None:
|
||||
workflow_str = format_workflow(self._workflow_map)
|
||||
description = f"{self.description}\n\n{workflow_str}"
|
||||
```
|
||||
|
||||
This should be extracted into a property or helper.
|
||||
|
||||
### No tests
|
||||
|
||||
The PR description mentions "I will add a test suite for this too!" but there are no tests included. Workflow resolution, edge cases (empty workflow map, missing workflow name, workflows with overlapping triggers), and the `get_execution_blocks` refactoring should all be tested before merge.
|
||||
|
||||
## Refactoring Quality
|
||||
|
||||
The refactoring of `get_execution_blocks` from a monolithic method on `SequentialPipelineBlocks` into separate implementations on `ConditionalPipelineBlocks` and `SequentialPipelineBlocks` is a good separation of concerns. Moving `combine_inputs`/`combine_outputs` to module-level functions is also reasonable since they don't depend on instance state.
|
||||
|
||||
The improved `AutoPipelineBlocks` docstring with the example is a significant documentation improvement.
|
||||
97
pr_review/13075.md
Normal file
97
pr_review/13075.md
Normal file
@@ -0,0 +1,97 @@
|
||||
I'll review this PR that addresses PyTorch version compatibility for distributed operations.
|
||||
|
||||
## Summary of Changes
|
||||
|
||||
The PR refactors the `gather_size_by_comm` function in `_modeling_parallel.py` to handle PyTorch versions prior to 2.6 that don't have the `torch.accelerator` API. The changes replace a single ternary expression with a multi-level conditional that:
|
||||
|
||||
1. First checks if "cpu" is in the backend string
|
||||
2. Then checks if `torch.accelerator` exists (PyTorch >= 2.6)
|
||||
3. Falls back to CUDA as a default device
|
||||
|
||||
## Potential Issues or Bugs
|
||||
|
||||
**1. Device Type Inconsistency**
|
||||
The original code returns a string `"cpu"` but the new code returns `torch.device("cuda")` objects. This inconsistency could cause issues:
|
||||
|
||||
```python
|
||||
gather_device = "cpu" # str
|
||||
# vs
|
||||
gather_device = torch.device("cuda") # torch.device object
|
||||
```
|
||||
|
||||
**Recommendation:** Use `torch.device()` consistently:
|
||||
```python
|
||||
if "cpu" in comm_backends:
|
||||
gather_device = torch.device("cpu")
|
||||
elif hasattr(torch, "accelerator"):
|
||||
acc = torch.accelerator.current_accelerator()
|
||||
gather_device = torch.device(acc if acc is not None else "cuda")
|
||||
else:
|
||||
gather_device = torch.device("cuda")
|
||||
```
|
||||
|
||||
**2. Unclear Accelerator Return Behavior**
|
||||
The comment states "Fall back to CUDA when no accelerator is returned" but it's unclear when `torch.accelerator.current_accelerator()` would return `None`. This should be verified or documented.
|
||||
|
||||
**3. Missing Type Information**
|
||||
What type does `torch.accelerator.current_accelerator()` return? If it returns a string like `"cuda"` or `"mps"`, the code should handle it consistently. If it returns a device object, the logic might need adjustment.
|
||||
|
||||
## Code Quality Observations
|
||||
|
||||
**Positive:**
|
||||
- Clear comments explaining the fallback logic
|
||||
- Proper use of `hasattr()` for backward compatibility
|
||||
- Addresses the reported issue #13074
|
||||
|
||||
**Areas for Improvement:**
|
||||
|
||||
1. **Device type consistency** (mentioned above)
|
||||
|
||||
2. **Consider alternative hardware accelerators:** The fallback to CUDA might not be appropriate for all systems (e.g., MPS on macOS, XPU on Intel). Consider:
|
||||
```python
|
||||
else:
|
||||
# Fallback for PyTorch < 2.6
|
||||
if torch.cuda.is_available():
|
||||
gather_device = torch.device("cuda")
|
||||
else:
|
||||
gather_device = torch.device("cpu")
|
||||
```
|
||||
|
||||
3. **Code style:** The expanded conditional is more readable but could benefit from extracting into a helper function if this pattern appears elsewhere:
|
||||
```python
|
||||
def _get_gather_device(comm_backends: str) -> torch.device:
|
||||
"""Determine device for distributed gather operations."""
|
||||
# ... implementation
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
No significant security issues identified. This is primarily a compatibility fix for internal device selection logic.
|
||||
|
||||
## Suggestions for Improvement
|
||||
|
||||
1. **Add a test case** to verify behavior on PyTorch < 2.6 (if not already covered)
|
||||
|
||||
2. **Document the behavior** more explicitly:
|
||||
```python
|
||||
# Determine gather device based on backend and PyTorch version
|
||||
# Priority: CPU backend > torch.accelerator (>= 2.6) > CUDA fallback (< 2.6)
|
||||
```
|
||||
|
||||
3. **Consider this more defensive approach:**
|
||||
```python
|
||||
if "cpu" in comm_backends:
|
||||
gather_device = torch.device("cpu")
|
||||
elif hasattr(torch, "accelerator"):
|
||||
acc = torch.accelerator.current_accelerator()
|
||||
gather_device = torch.device(acc if acc else "cuda")
|
||||
elif torch.cuda.is_available():
|
||||
gather_device = torch.device("cuda")
|
||||
else:
|
||||
# Fallback to CPU if no GPU available
|
||||
gather_device = torch.device("cpu")
|
||||
```
|
||||
|
||||
## Verdict
|
||||
|
||||
The PR addresses the compatibility issue but has a **type inconsistency bug** that should be fixed before merging. The string vs `torch.device` object mismatch could cause runtime errors. Once that's addressed, the change is sound for backward compatibility.
|
||||
66
pr_review/13116.md
Normal file
66
pr_review/13116.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# PR #13116: [tests] tests for `modules_to_not_convert`
|
||||
|
||||
**Author:** @sayakpaul
|
||||
**Branch:** `fix-modules-no-convert-torchao` -> `main`
|
||||
**Files changed:** `tests/models/testing_utils/quantization.py`, `tests/models/transformers/test_models_transformer_flux.py`
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
This PR fixes the `modules_to_not_convert` tests that were effectively dead code. They existed in the base `QuantizationTesterMixin` but never ran because no test class defined `modules_to_not_convert_for_test`. The PR activates these tests for Flux and fixes several underlying bugs that would have caused them to fail.
|
||||
|
||||
## Key Changes
|
||||
|
||||
1. **BnB config key fix**: `BitsAndBytesConfig` uses `llm_int8_skip_modules`, not `modules_to_not_convert`. The base test was setting the wrong key, so modules were never actually excluded.
|
||||
|
||||
2. **TorchAO `_verify_if_layer_quantized` fix**: Previously only checked `isinstance(module, torch.nn.Linear)`, which is always true for TorchAO (it doesn't replace the module class). Now properly checks weight tensor types (`AffineQuantizedTensor`, `LinearActivationQuantizedTensor`).
|
||||
|
||||
3. **`_is_module_quantized` fix**: Now passes `quant_config_kwargs` to `_verify_if_layer_quantized`. Previously it passed `{}`, which caused BnB to always check for `Int8Params` even on 4-bit models.
|
||||
|
||||
4. **Cleanup**: Removes unused guard blocks (`is_gguf_available`, `is_torchao_available`) that only contained `pass`.
|
||||
|
||||
5. **Activates tests**: Adds `modules_to_not_convert_for_test` returning `["norm_out.linear"]` to BnB, Quanto, TorchAo, and ModelOpt Flux test classes.
|
||||
|
||||
## Issues
|
||||
|
||||
### `to_not_convert_key` parameter pollutes the base class interface
|
||||
|
||||
`quantization.py:271-273` — The new `to_not_convert_key` parameter on `_test_quantization_modules_to_not_convert` exists solely for BnB's naming quirk (`llm_int8_skip_modules` vs `modules_to_not_convert`). Every other backend uses the default. This leaks a BnB-specific detail into the shared base method.
|
||||
|
||||
BnB already has its own `test_bnb_modules_to_not_convert` that could handle the key translation locally — either by building the correct `config_kwargs` with `llm_int8_skip_modules` before calling `_create_quantized_model` directly, or by overriding the test. This keeps the base method clean and isolates BnB's naming quirk in `BitsAndBytesTesterMixin` where it belongs.
|
||||
|
||||
### Code duplication in TorchAO `test_torchao_modules_to_not_convert`
|
||||
|
||||
`quantization.py:915-950` — The TorchAO test inlines ~30 lines from `_test_quantization_modules_to_not_convert` to skip the memory footprint comparison. If the base method is updated in the future, this copy won't get the fix. Consider parameterizing the base method instead:
|
||||
|
||||
```python
|
||||
def _test_quantization_modules_to_not_convert(
|
||||
self, config_kwargs, modules_to_not_convert, check_memory_footprint=True,
|
||||
):
|
||||
# ... existing module-walking logic ...
|
||||
|
||||
if check_memory_footprint:
|
||||
# Compare memory footprint with fully quantized model
|
||||
...
|
||||
```
|
||||
|
||||
Then TorchAO could simply call:
|
||||
```python
|
||||
self._test_quantization_modules_to_not_convert(
|
||||
TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude,
|
||||
check_memory_footprint=False,
|
||||
)
|
||||
```
|
||||
|
||||
### TorchAO imports inside method body
|
||||
|
||||
`quantization.py:822-823` — The `torchao` imports are placed inside `_verify_if_layer_quantized`. While functional (avoids import errors when torchao isn't installed), these could be placed at module level under the existing `is_torchao_available()` guard for consistency with how `bnb` and `QLinear` imports are handled. Minor style point.
|
||||
|
||||
### `_is_module_quantized` callers not updated
|
||||
|
||||
`quantization.py:368` — The `_test_dequantize` method still calls `self._is_module_quantized(module)` without `quant_config_kwargs`. This happens to work correctly (for BnB, checking `Int8Params` after dequantization correctly returns False; for TorchAO, the weight won't be an `AffineQuantizedTensor`), but it means BnB dequantize for 4-bit models asserts the weight is not `Int8Params` rather than asserting it's not `Params4bit`. Consider updating for correctness.
|
||||
|
||||
### Missing GGUF test coverage
|
||||
|
||||
GGUF's `GGUFTesterMixin` doesn't have a `test_gguf_modules_to_not_convert` method. If GGUF is expected to support `modules_to_not_convert`, a test should be added. If not, a comment explaining why would be helpful.
|
||||
144
pr_review/pr_12700_flashpack.md
Normal file
144
pr_review/pr_12700_flashpack.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# PR #12700 — FlashPack Integration Review
|
||||
|
||||
**URL**: https://github.com/huggingface/diffusers/pull/12700
|
||||
**State**: OPEN
|
||||
**Branch**: `flashpack` → `main`
|
||||
|
||||
## Summary
|
||||
|
||||
Adds FlashPack as a new weight serialization format for faster model loading. FlashPack packs model weights into a single contiguous file (`model.flashpack`) that can be loaded efficiently, especially for larger models. The PR integrates it across `ModelMixin` (save/load), `DiffusionPipeline` (save/load/download), and supporting utilities.
|
||||
|
||||
## Files Changed
|
||||
|
||||
- `setup.py` / `dependency_versions_table.py` — add `flashpack` dependency
|
||||
- `src/diffusers/utils/constants.py` — `FLASHPACK_WEIGHTS_NAME`, `FLASHPACK_FILE_EXTENSION`
|
||||
- `src/diffusers/utils/import_utils.py` — `is_flashpack_available()`
|
||||
- `src/diffusers/utils/__init__.py` — re-exports
|
||||
- `src/diffusers/models/model_loading_utils.py` — `load_flashpack_checkpoint()`, dispatch in `load_state_dict()`
|
||||
- `src/diffusers/models/modeling_utils.py` — `save_pretrained(use_flashpack=...)`, `from_pretrained(use_flashpack=..., flashpack_kwargs=...)`
|
||||
- `src/diffusers/pipelines/pipeline_utils.py` — pipeline-level `save_pretrained`, `from_pretrained`, `download` with `use_flashpack`
|
||||
- `src/diffusers/pipelines/pipeline_loading_utils.py` — `load_sub_model`, `_get_ignore_patterns`, `get_class_obj_and_candidates`, `filter_model_files`
|
||||
|
||||
---
|
||||
|
||||
## Issues
|
||||
|
||||
### 1. `use_flashpack=True` default in `DiffusionPipeline.download()`
|
||||
|
||||
```python
|
||||
# pipeline_utils.py, in download()
|
||||
use_flashpack = kwargs.pop("use_flashpack", True)
|
||||
```
|
||||
|
||||
This defaults to `True`, meaning `download()` will always try to download FlashPack files by default. Every other call site defaults to `False`. This looks like a bug — it would change download behavior for all users even if they never asked for FlashPack. Should be `False`.
|
||||
|
||||
### 2. `load_flashpack_checkpoint` is unused in the `from_pretrained` path
|
||||
|
||||
`load_flashpack_checkpoint()` is added to `model_loading_utils.py` and wired into `load_state_dict()`. However, in `ModelMixin.from_pretrained`, when `use_flashpack=True`, the code **early-returns** after calling `flashpack.mixin.assign_from_file()` directly — it never goes through `load_state_dict()`. So `load_flashpack_checkpoint` is dead code in the `from_pretrained` flow. Either:
|
||||
- Remove it if FlashPack always uses its own assign path, or
|
||||
- Use it consistently (load state dict → assign to model, like safetensors/pickle)
|
||||
|
||||
### 3. `resolved_model_file` may be undefined when `use_flashpack=True` and file fetch fails
|
||||
|
||||
```python
|
||||
# modeling_utils.py, from_pretrained
|
||||
elif use_flashpack:
|
||||
try:
|
||||
resolved_model_file = _get_model_file(...)
|
||||
except IOError as e:
|
||||
logger.error(...)
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning("Defaulting to unsafe serialization...")
|
||||
```
|
||||
|
||||
If the `IOError` is caught and `allow_pickle` is truthy, `resolved_model_file` is never set but is used later at `flashpack.mixin.assign_from_file(model=model, path=resolved_model_file[0], ...)`. This would crash with an `UnboundLocalError` (`resolved_model_file` is a local variable that was never assigned on this path). The fallback logic (copied from the safetensors block) doesn't make sense for FlashPack — there's no pickle fallback for FlashPack. The `except` block should just re-raise unconditionally.
|
||||
|
||||
### 4. `resolved_model_file[0]` assumes a list, but `_get_model_file` returns a string
|
||||
|
||||
```python
|
||||
flashpack.mixin.assign_from_file(
|
||||
model=model,
|
||||
path=resolved_model_file[0], # indexing into a string
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
`_get_model_file` returns a single file path (string), not a list. `resolved_model_file[0]` would give the first character of the path. Should be just `resolved_model_file`.
|
||||
|
||||
### 5. `device_map` handling assumes `device_map[""]` exists
|
||||
|
||||
```python
|
||||
flashpack_device = device_map[""]
|
||||
```
|
||||
|
||||
`device_map` can be a dict with arbitrary keys (layer names, module names), not just `{"": device}`. This would raise `KeyError` for any non-trivial device map. Should handle the general case or document the constraint.
|
||||
|
||||
### 6. `FlashPack` prefix stripping in `get_class_obj_and_candidates` is unexplained
|
||||
|
||||
```python
|
||||
if class_name.startswith("FlashPack"):
|
||||
class_name = class_name.removeprefix("FlashPack")
|
||||
```
|
||||
|
||||
This is injected into a general-purpose utility function with no explanation of when/why a class name would have a `FlashPack` prefix. This seems like it handles a specific config format but there's no corresponding code that writes `FlashPack`-prefixed class names. If this is for some external convention, it should be documented. If not needed, remove it.
|
||||
|
||||
### 7. Duplicated availability check pattern
|
||||
|
||||
The `is_flashpack_available()` check + import + error message pattern is repeated 3 times:
|
||||
- `load_flashpack_checkpoint()` in `model_loading_utils.py`
|
||||
- `save_pretrained()` in `modeling_utils.py`
|
||||
- `from_pretrained()` in `modeling_utils.py`
|
||||
|
||||
Each has slightly different wording. Should be consolidated — e.g., a helper or just use a single `require_flashpack()` function, consistent with how other optional deps are handled.
|
||||
|
||||
### 8. `save_pretrained` error message says "load" instead of "save"
|
||||
|
||||
```python
|
||||
# modeling_utils.py, save_pretrained, use_flashpack=True branch
|
||||
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
|
||||
```
|
||||
|
||||
This is in the **save** path, but the message says "load". Should say "save".
|
||||
|
||||
### 9. No `config.json` saved alongside FlashPack weights in `save_pretrained`
|
||||
|
||||
When `use_flashpack=True` in `ModelMixin.save_pretrained`, the model config is saved normally at the top of the method, but the FlashPack branch calls `flashpack.serialization.pack_to_file()` with `target_dtype=self.dtype`. It's not clear if FlashPack's own `config.json` (mentioned in the benchmark script as `flashpack_config.json`) is the same as diffusers' `config.json`. If they're different files, loading back with `from_pretrained(use_flashpack=True)` might fail to reconstruct the model architecture since `from_config` needs the diffusers config.
|
||||
|
||||
### 10. `output_loading_info` warning placement
|
||||
|
||||
```python
|
||||
if output_loading_info:
|
||||
logger.warning("`output_loading_info` is not supported with FlashPack.")
|
||||
return model, {}
|
||||
```
|
||||
|
||||
This returns an empty dict silently. The warning is fine, but returning `{}` instead of a proper `loading_info` structure (with `missing_keys`, `unexpected_keys`, etc.) could break code that destructures the result.
|
||||
|
||||
### 11. No tests included
|
||||
|
||||
The PR has no test files. At minimum there should be:
|
||||
- Unit tests for `load_flashpack_checkpoint` (mocking `flashpack`)
|
||||
- Unit tests for save/load roundtrip with `use_flashpack=True`
|
||||
- Integration test for pipeline save/load
|
||||
|
||||
### 12. FlashPack doesn't support sharding
|
||||
|
||||
The `save_pretrained` FlashPack branch ignores `max_shard_size` entirely and always saves a single file. This is fine for the format but should either:
|
||||
- Log a warning if `max_shard_size` is explicitly set alongside `use_flashpack=True`
|
||||
- Document this limitation
|
||||
|
||||
---
|
||||
|
||||
## Minor Issues
|
||||
|
||||
- The benchmark in the PR description shows FlashPack is actually **slower** for fp16 SD v1.5 (0.95x). The claimed benefit is only for bf16. This should be prominently noted.
|
||||
- `FLASHPACK_WEIGHTS_NAME = "model.flashpack"` breaks the diffusers naming convention (`diffusion_pytorch_model.*` for other formats).
|
||||
- The PR modifies `_get_ignore_patterns` but doesn't handle the case where both `use_safetensors` and `use_flashpack` are True.
|
||||
- `filter_model_files` adds `FLASHPACK_WEIGHTS_NAME` to the known weights list but there are no corresponding tests for this filtering.
|
||||
|
||||
---
|
||||
|
||||
## Verdict
|
||||
|
||||
The PR needs significant work before it's mergeable. The critical issues are the `use_flashpack=True` default in `download()`, the `resolved_model_file[0]` indexing bug, the dead code path with `load_flashpack_checkpoint`, and the lack of tests. The integration pattern also doesn't feel consistent with how other formats (safetensors, GGUF) are integrated — FlashPack bypasses the standard state dict loading path entirely via its own `assign_from_file`, making it a special case that's harder to maintain.
|
||||
286
pr_review/teacache_pr_12652_review.md
Normal file
286
pr_review/teacache_pr_12652_review.md
Normal file
@@ -0,0 +1,286 @@
|
||||
# TeaCache PR #12652 Review Notes
|
||||
|
||||
## PR Overview
|
||||
|
||||
- **PR**: https://github.com/huggingface/diffusers/pull/12652
|
||||
- **Title**: Implement TeaCache
|
||||
- **Author**: LawJarp-A (Prajwal A)
|
||||
- **Status**: Open
|
||||
- **Changes**: +1335 / -22 lines across 6 files
|
||||
|
||||
### What is TeaCache?
|
||||
|
||||
[TeaCache](https://huggingface.co/papers/2411.19108) (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by **1.5x-2.6x** by reusing transformer block computations when consecutive timestep embeddings are similar.
|
||||
|
||||
### Algorithm
|
||||
|
||||
1. Extract modulated input from first transformer block (after norm1 + timestep embedding)
|
||||
2. Compute relative L1 distance vs previous timestep
|
||||
3. Apply model-specific polynomial rescaling: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]`
|
||||
4. Accumulate rescaled distance across timesteps
|
||||
5. If accumulated < threshold → Reuse cached residual (FAST)
|
||||
6. If accumulated >= threshold → Full transformer pass (SLOW, update cache)
|
||||
|
||||
---
|
||||
|
||||
## The Mid-Forward Intercept Problem
|
||||
|
||||
### Why TeaCache is Model-Specific
|
||||
|
||||
TeaCache needs to intercept **within** a model's forward method, not just at module boundaries:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Model Forward │
|
||||
│ │
|
||||
│ PREPROCESSING (must always run) │
|
||||
│ ├── x_embedder(hidden_states) │
|
||||
│ ├── time_text_embed(timestep, ...) │
|
||||
│ └── context_embedder(encoder_hidden_states) │
|
||||
│ │
|
||||
│ ═══════════════════════════════════════════════════════════│
|
||||
│ DECISION POINT ◄── TeaCache needs to intercept HERE │
|
||||
│ └── Extract: transformer_blocks[0].norm1(hs, temb)[0] │
|
||||
│ ═══════════════════════════════════════════════════════════│
|
||||
│ │
|
||||
│ CACHEABLE REGION (can be skipped if cached) │
|
||||
│ ├── for block in transformer_blocks: ... │
|
||||
│ └── for block in single_transformer_blocks: ... │
|
||||
│ │
|
||||
│ POSTPROCESSING (must always run) │
|
||||
│ ├── norm_out(hidden_states, temb) │
|
||||
│ └── proj_out(hidden_states) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
PyTorch hooks only intercept at **module boundaries** (before/after `forward()`), not within a forward method. The `for` loop over blocks is Python control flow - there's no hook point to skip it.
|
||||
|
||||
### Workaround: Custom Forward Replacement
|
||||
|
||||
The PR replaces the entire model forward with a custom implementation that has cache logic inserted at the right point. This works but requires maintaining separate forward functions for each model.
|
||||
|
||||
---
|
||||
|
||||
## Comparison of Caching Approaches
|
||||
|
||||
### TeaCache vs FirstBlockCache vs FasterCache
|
||||
|
||||
| Aspect | TeaCache | FirstBlockCache | FasterCache |
|
||||
|--------|----------|-----------------|-------------|
|
||||
| **Hook target** | Model forward | Transformer blocks | Attention layers |
|
||||
| **Decision signal** | Modulated input (norm1 output) | Block output residual | Iteration count |
|
||||
| **Where signal is** | Inside first block | Block boundary | Attention output |
|
||||
| **Model-specific needs** | norm1 structure | Block output format | Attention class type |
|
||||
| **Model-agnostic?** | ❌ No | ✅ Yes | ✅ Yes |
|
||||
|
||||
### Why FirstBlockCache is Model-Agnostic
|
||||
|
||||
FirstBlockCache uses the **first block's output residual** as its signal:
|
||||
|
||||
```python
|
||||
# FirstBlockCache: hooks individual blocks
|
||||
def new_forward(self, module, *args, **kwargs):
|
||||
original_hidden_states = args[0]
|
||||
output = self.fn_ref.original_forward(*args, **kwargs) # Run block fully
|
||||
residual = output - original_hidden_states # Signal from OUTPUT
|
||||
should_compute = self._compare_residual(residual)
|
||||
...
|
||||
```
|
||||
|
||||
It doesn't need to understand block internals - just input and output.
|
||||
|
||||
### Why FasterCache is Model-Agnostic
|
||||
|
||||
FasterCache hooks **attention layers** (not blocks) using class type checking:
|
||||
|
||||
```python
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES):
|
||||
# Hook this attention module
|
||||
```
|
||||
|
||||
All transformer models use standardized attention classes.
|
||||
|
||||
---
|
||||
|
||||
## Model Architecture Analysis
|
||||
|
||||
### Models That Fit TeaCache Pattern
|
||||
|
||||
Models with `norm1(hidden_states, temb)` returning modulated input:
|
||||
|
||||
| Model | norm1 Signature | Modulation Location | Single Residual |
|
||||
|-------|----------------|---------------------|-----------------|
|
||||
| FLUX 1 | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ |
|
||||
| FLUX Kontext | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ |
|
||||
| Mochi | `norm1(hs, temb) → (tensor, g, s, g)` | Inside norm1 | ✅ |
|
||||
| Lumina2 | `norm1(hs, temb) → (tensor, gate)` | Inside norm1 | ✅ |
|
||||
|
||||
### Models That DON'T Fit Pattern
|
||||
|
||||
| Model | norm1 Signature | Modulation Location | Issue |
|
||||
|-------|----------------|---------------------|-------|
|
||||
| **FLUX 2** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm |
|
||||
| **Wan** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm |
|
||||
| **ZImage** | `attention_norm1(x) → tensor` | Outside norm1 | Plain LayerNorm |
|
||||
| **CogVideoX** | N/A (uses `emb` directly) | N/A | Dual residual needed |
|
||||
|
||||
### FLUX 1 vs FLUX 2 Architecture Difference
|
||||
|
||||
**FLUX 1** (AdaLayerNorm - modulation inside):
|
||||
```python
|
||||
class FluxTransformerBlock:
|
||||
self.norm1 = AdaLayerNormZero(dim) # Takes temb!
|
||||
|
||||
def forward(self, hidden_states, temb, ...):
|
||||
norm_hs, gate = self.norm1(hidden_states, emb=temb) # Modulation inside
|
||||
```
|
||||
|
||||
**FLUX 2** (Plain LayerNorm - modulation outside):
|
||||
```python
|
||||
class Flux2TransformerBlock:
|
||||
self.norm1 = nn.LayerNorm(dim) # NO temb!
|
||||
|
||||
def forward(self, hidden_states, temb_mod_params_img, ...):
|
||||
(shift_msa, scale_msa, gate_msa), ... = temb_mod_params_img
|
||||
norm_hs = self.norm1(hidden_states) # Plain norm
|
||||
norm_hs = (1 + scale_msa) * norm_hs + shift_msa # Modulation outside
|
||||
```
|
||||
|
||||
FLUX 2 follows the Wan/ZImage pattern and would need a separate custom forward.
|
||||
|
||||
---
|
||||
|
||||
## CogVideoX: The Architectural Outlier
|
||||
|
||||
CogVideoX has two unique requirements that don't fit the pattern:
|
||||
|
||||
### 1. Different Modulated Input Source
|
||||
|
||||
```python
|
||||
# Other models: extract from norm1
|
||||
modulated_inp = block.norm1(hidden_states, temb)[0]
|
||||
|
||||
# CogVideoX: uses timestep embedding directly
|
||||
modulated_inp = emb # Just the embedding, computed before blocks!
|
||||
```
|
||||
|
||||
### 2. Dual Residual Caching
|
||||
|
||||
CogVideoX blocks return and modify TWO tensors:
|
||||
```python
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb, ...):
|
||||
# Both are modified!
|
||||
return hidden_states, encoder_hidden_states
|
||||
```
|
||||
|
||||
Requires caching two residuals:
|
||||
```python
|
||||
state.previous_residual = hs_output - hs_input
|
||||
state.previous_residual_encoder = enc_output - enc_input # Extra!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recommendations
|
||||
|
||||
### Simplification: FLUX-Only Support
|
||||
|
||||
Given the architectural diversity, recommend supporting only FLUX 1 and FLUX Kontext initially:
|
||||
|
||||
```python
|
||||
_MODEL_CONFIG = {
|
||||
"FluxKontext": {
|
||||
"forward_func": _flux_teacache_forward,
|
||||
"coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02],
|
||||
},
|
||||
"Flux": {
|
||||
"forward_func": _flux_teacache_forward,
|
||||
"coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### What to Remove from PR
|
||||
|
||||
1. **CogVideoX support** - Dual residual architecture doesn't fit
|
||||
2. **Mochi support** - Can be added later if needed
|
||||
3. **Lumina2 support** - Can be added later if needed
|
||||
4. **FLUX 2 support** - Different architecture (plain LayerNorm)
|
||||
|
||||
### Estimated Code Reduction
|
||||
|
||||
| Component | Original (PR) | FLUX-Only |
|
||||
|-----------|---------------|-----------|
|
||||
| Forward functions | 4 (~400 lines) | 1 (~100 lines) |
|
||||
| Model configs | 10 entries | 2 entries |
|
||||
| State fields | 8 | 5 |
|
||||
| Utility functions | 6 | 3 |
|
||||
| **Total teacache.py** | ~900 lines | ~350 lines |
|
||||
|
||||
### Simplified State
|
||||
|
||||
```python
|
||||
class TeaCacheState(BaseState):
|
||||
def __init__(self):
|
||||
self.cnt = 0
|
||||
self.num_steps = 0
|
||||
self.accumulated_rel_l1_distance = 0.0
|
||||
self.previous_modulated_input = None
|
||||
self.previous_residual = None
|
||||
# Removed: previous_residual_encoder (CogVideoX)
|
||||
# Removed: cache_dict (Lumina2)
|
||||
# Removed: uncond_seq_len (Lumina2)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why Custom Forwards Are Necessary
|
||||
|
||||
Despite the maintenance burden, custom forwards are the pragmatic approach for TeaCache because:
|
||||
|
||||
1. **Mid-forward intercept required** - Need to access `norm1` output before blocks run
|
||||
2. **Architectural diversity** - Models differ in where/how modulation happens
|
||||
3. **Block-level hooks insufficient** - Can't extract modulated input from block hooks
|
||||
4. **Algorithm requirements** - TeaCache paper specifically uses modulated input as signal
|
||||
|
||||
### Alternative Approaches Considered
|
||||
|
||||
| Approach | Works? | Issue |
|
||||
|----------|--------|-------|
|
||||
| Block-level hooks (like FirstBlockCache) | ❌ | Can't access modulated input inside block |
|
||||
| Attention-level hooks (like FasterCache) | ❌ | Different algorithm, not TeaCache |
|
||||
| Hook norm1 directly | ⚠️ | norm1 interface varies per model |
|
||||
| Hybrid (FirstBlockCache signal + TeaCache algorithm) | ⚠️ | Loses "optimal" signal per paper |
|
||||
|
||||
---
|
||||
|
||||
## PR Code Quality Issues (From Review)
|
||||
|
||||
1. **torch.compile incompatibility** - `.item()` calls in `_compute_rel_l1_distance` create graph breaks
|
||||
2. **Boundary check bug** - when `num_steps=0`, `state.num_steps - 1` evaluates to `-1`, so the check `state.cnt == state.num_steps - 1` compares against `-1` and never matches
|
||||
3. **Incomplete Lumina2 state reset** - `cache_dict` and `uncond_seq_len` not reset
|
||||
4. **Model auto-detection fragility** - Substring matching relies on iteration order
|
||||
|
||||
---
|
||||
|
||||
## Extension Path
|
||||
|
||||
If support for additional models is needed later:
|
||||
|
||||
1. **Mochi** - Same pattern as FLUX, just add coefficients and reuse `_flux_teacache_forward` or create similar
|
||||
2. **Lumina2** - Same pattern but needs per-sequence-length caching for CFG
|
||||
3. **FLUX 2 / Wan / ZImage** - Need separate forwards that extract modulated input differently
|
||||
4. **CogVideoX** - Needs dual residual support, significant additional complexity
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
- **TeaCache requires custom forwards** due to mid-forward intercept requirement
|
||||
- **FLUX 1 + FLUX Kontext only** is the recommended scope for initial implementation
|
||||
- **~60% code reduction** possible by removing unsupported models
|
||||
- **Clear extension path** for adding models later as needed
|
||||
- **Maintenance burden** is acceptable given the architectural constraints
|
||||
129
release_notes/v0.37.0.md
Normal file
129
release_notes/v0.37.0.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# Diffusers v0.37.0 Release Notes
|
||||
|
||||
*Release based on 191 commits since v0.36.0*
|
||||
|
||||
---
|
||||
|
||||
## Highlights
|
||||
|
||||
- **Modular Pipelines overhaul**: Major investment in the modular pipeline system with explicit workflow support, improved loaders, documentation, and modular implementations for Wan, Flux2, Z-Image, Qwen, and Mellon pipelines.
|
||||
- **New pipelines and models**: Cosmos Predict2.5, LTX 2.0 Video, LongCat-Image, Fibo Edit, Z-Image Omni Base, and more.
|
||||
- **Distributed inference improvements**: Unified Sequence Parallel attention, Ulysses Anything Attention, and context parallel support in native flash attention.
|
||||
- **Python 3.8 dropped**: Sunset Python 3.8 and cleaned up explicit `typing` exports.
|
||||
|
||||
---
|
||||
|
||||
## New Pipelines and Models
|
||||
|
||||
- **Cosmos Predict2.5**: Base inference pipeline, scheduler, and checkpoint conversion; 14B model support (#12852, #12863)
|
||||
- **Cosmos Transfer2.5**: General transfer pipelines for segmentation, depth, blur, and edge (#13066)
|
||||
- **LTX 2.0 Video Pipelines**: New video generation pipelines (#12915), distilled checkpoint support (#12934), single-file loading (#12983), LoRA support (#12933), long multi-prompt (#12614)
|
||||
- **LongCat-Image**: New pipeline with offloading/quantization support and regional compile acceleration (#12828, #12963, #12699, #13019, #13021)
|
||||
- **Fibo Edit Pipeline**: New editing pipeline (#12930)
|
||||
- **Z-Image Omni Base**: New implementation (#12857)
|
||||
- **Z-Image Turbo ControlNet**: ControlNet support for Z-Image Turbo (#12792)
|
||||
- **Z-Image Inpaint Pipeline**: Inpainting support (#13006)
|
||||
- **Z-Image ControlNet CFG**: CFG support for Z-Image ControlNet (#13080)
|
||||
- **Chroma Inpaint Pipeline**: New inpainting pipeline for Chroma (#12848)
|
||||
- **Flux2 Klein**: New model variant (#12982)
|
||||
- **Qwen Image Edit 2511**: New editing support (#12839)
|
||||
- **Qwen Image Layered Support** (#12853)
|
||||
|
||||
## Modular Pipelines
|
||||
|
||||
- Explicit workflow support for modular pipelines (#13028)
|
||||
- Modular implementations for: Wan (#13063), Flux2 (#12763), Z-Image (#12808), Qwen (#12872), Mellon (#12978, #12924, #13051)
|
||||
- Improved loader support (#13025)
|
||||
- Custom block tests (#12557)
|
||||
- Auto-docstring generation and documentation refactors (#12958)
|
||||
- Quick start guide (#13029)
|
||||
- Guard `ModularPipeline.blocks` attribute (#13014)
|
||||
- Better docstrings and template pipeline card (#13072, #12932)
|
||||
|
||||
## Core Improvements
|
||||
|
||||
- **Device-type device maps with offloading support** (#12811)
|
||||
- **`disable_mmap` in pipeline `from_pretrained`** (#12854)
|
||||
- **`apply_lora_scale` helper** to remove boilerplate (#12994)
|
||||
- **MagCache support**: Caching mechanism for faster inference (#12744)
|
||||
- **Mambo-G Guidance**: New guider implementation (#12862)
|
||||
- **Laplace Scheduler for DDPM** (#11320)
|
||||
- **Custom sigmas in UniPCMultistepScheduler** (#12109)
|
||||
- **Control-LoRA support** (#10686)
|
||||
- **Latent Perceptual Loss (LPL) for SDXL** (#11573)
|
||||
- **MultiControlNet support for SD3 Inpainting** (#11251)
|
||||
- Remove 8-bit device restriction (#12972)
|
||||
- Graceful error for unsupported attn-backend / context-parallel combos (#12832)
|
||||
- Handle progress bar and logging in distributed environments (#12806)
|
||||
- Remove unneeded autoencoder methods from `AutoencoderMixin` subclasses (#12873)
|
||||
- Remove k-diffusion support (#13152)
|
||||
- Flag Flax schedulers as deprecated (#13031)
|
||||
|
||||
## Distributed Inference
|
||||
|
||||
- **Unified Sequence Parallel attention** (#12693)
|
||||
- **Ulysses Anything Attention** (#12996)
|
||||
- **Context parallel in native flash attention** (#12829)
|
||||
- NPU Ulysses attention support (#12919)
|
||||
- Fix Wan 2.1 I2V context parallel (#12909)
|
||||
- Fix Qwen-Image context parallel (#12970)
|
||||
|
||||
## LoRA
|
||||
|
||||
- Z-Image LoRA training (#13056)
|
||||
- Fix non-diffusers LoRA key handling for Flux2 (#13119)
|
||||
- Fix LoRA loading for Flux2 Klein with adaptive block enumeration (#13030)
|
||||
- Fix wrong LTX2 LoRA mixin (#13144)
|
||||
|
||||
## Bug Fixes
|
||||
|
||||
- Fix QwenImageEditPlus on NPU (#13017)
|
||||
- Fix MT5Tokenizer → use `T5Tokenizer` for Transformers v5.0+ compatibility (#12877)
|
||||
- Fix Wan/WanI2V patchification (#13038)
|
||||
- Fix LTX-2 inference with `num_videos_per_prompt > 1` and CFG (#13121)
|
||||
- Fix Flux2 img2img prediction (#12855)
|
||||
- Fix QwenImage `txt_seq_lens` handling (#12702)
|
||||
- Fix `prefix_token_len` bug (#12845)
|
||||
- Fix ftfy imports in Wan and SkyReels-V2 (#12314, #13113)
|
||||
- Fix `is_fsdp` determination (#12960)
|
||||
- Fix GLM-Image `get_image_features` API (#13052)
|
||||
- Fix Wan 2.2 when either transformer isn't present (#13055)
|
||||
- Fix guider issue (#13147)
|
||||
- Fix torchao quantizer for new versions (#12901)
|
||||
- Fix GGUF for unquantized types with unquantize kernels (#12498)
|
||||
- Make Qwen hidden states contiguous for torchao (#13081)
|
||||
- Make Flux hidden states contiguous (#13068)
|
||||
- Fix Kandinsky 5 hardcoded CUDA autocast (#12814)
|
||||
- Fix `aiter` availability check (#13059)
|
||||
- Fix attention mask check for unsupported backends (#12892)
|
||||
- Allow `prompt` and `prior_token_ids` simultaneously in `GlmImagePipeline` (#13092)
|
||||
- GLM-Image batch support (#13007)
|
||||
- Cosmos 2.5 Video2World frame extraction fix (#13018)
|
||||
- ResNet: only use contiguous in training mode (#12977)
|
||||
|
||||
## Testing and CI
|
||||
|
||||
- Refactor model tests (#12822)
|
||||
- Refactor Wan model tests (#13082)
|
||||
- Accept `recompile_limit` from user in tests (#13150)
|
||||
- CodeQL workflow for security analysis (#12917)
|
||||
- Upgrade GitHub Actions for Node 24 compatibility (#12865, #12866)
|
||||
- Fix `setuptools` / `pkg_resources` CI bugs (#13129, #13132)
|
||||
- CUDA 12.9 upgrade (#13045)
|
||||
- FSDP option for Flux2 (#12860)
|
||||
|
||||
## Documentation
|
||||
|
||||
- Custom code AutoModel guide (#13099)
|
||||
- Remote inference docs (#12372)
|
||||
- Improved distributed inference docs (#12810, #12827, #12971)
|
||||
- Improved caching docs (#12684)
|
||||
- Numerous scheduler docstring improvements (#12798, #12871, #12928, #12931, #12936, #12992, #13010, #13020, #13023, #13024, #13027, #13044, #13083, #13085, #13122, #13127, #13130)
|
||||
- Various typo and syntax fixes
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- **Python 3.8 support removed** (#12524)
|
||||
- **k-diffusion removed** (#13152)
|
||||
- **Flax schedulers flagged as deprecated** (#13031)
|
||||
- ControlNet implementations outside the controlnet module removed (#12152)
|
||||
183
scripts/compare_test_coverage.py
Normal file
183
scripts/compare_test_coverage.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compare test coverage between main and model-test-refactor branches
|
||||
for the Flux transformer tests.
|
||||
|
||||
Usage:
|
||||
python scripts/compare_test_coverage.py
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
|
||||
|
||||
TEST_FILE = "tests/models/transformers/test_models_transformer_flux.py"
|
||||
BRANCHES = ["main", "model-test-refactor"]
|
||||
|
||||
|
||||
def run_command(cmd, capture=True):
    """Execute *cmd* through the shell and return (stdout, stderr, returncode)."""
    completed = subprocess.run(cmd, shell=True, capture_output=capture, text=True)
    return completed.stdout, completed.stderr, completed.returncode
|
||||
|
||||
|
||||
def get_current_branch():
    """Return the name of the branch git currently has checked out."""
    branch_output, _, _ = run_command("git branch --show-current")
    return branch_output.strip()
|
||||
|
||||
|
||||
def stash_changes():
    """Stash any uncommitted working-tree changes via `git stash`."""
    run_command("git stash")
|
||||
|
||||
|
||||
def pop_stash():
    """Restore the most recently stashed changes via `git stash pop`."""
    run_command("git stash pop")
|
||||
|
||||
|
||||
def checkout_branch(branch):
    """Check out *branch*; return True on success, False (after printing the error) otherwise."""
    _, stderr, code = run_command(f"git checkout {branch}")
    if code == 0:
        return True
    print(f"Failed to checkout {branch}: {stderr}")
    return False
|
||||
|
||||
|
||||
def collect_tests(test_file):
    """Collect test node ids from *test_file* using pytest's --collect-only output."""
    stdout, stderr, code = run_command(f"python -m pytest {test_file} --collect-only -q 2>/dev/null")
    # Test node ids contain "::"; summary lines start with "=".
    return [line.strip() for line in stdout.strip().split("\n") if "::" in line and not line.startswith("=")]
|
||||
|
||||
|
||||
def run_tests_verbose(test_file):
    """Run *test_file* under `pytest -v` and bucket each test by outcome.

    Returns a dict with keys "passed", "failed", "errors" (lists of test ids)
    and "skipped" (list of `(test_id, reason)` tuples).

    Fix vs. original: the SKIPPED branch re-checked `"SKIPPED" in line` (always
    true inside that branch) and carried a nested `if "[" in line else ""` that
    was already guarded by the same condition — collapsed to one conditional
    with identical behavior.
    """
    stdout, _, _ = run_command(f"python -m pytest {test_file} -v --tb=no 2>&1")

    results = {"passed": [], "skipped": [], "failed": [], "errors": []}

    for line in stdout.split("\n"):
        if " PASSED" in line:
            results["passed"].append(line.split(" PASSED")[0].strip())
        elif " SKIPPED" in line:
            test_name = line.split(" SKIPPED")[0].strip()
            # pytest prints skip reasons in brackets after the marker.
            reason = line.split("[")[-1].rstrip("]") if "[" in line else ""
            results["skipped"].append((test_name, reason))
        elif " FAILED" in line:
            results["failed"].append(line.split(" FAILED")[0].strip())
        elif " ERROR" in line:
            results["errors"].append(line.split(" ERROR")[0].strip())

    return results
|
||||
|
||||
|
||||
def compare_results(main_results, pr_results):
|
||||
"""Compare test results between branches."""
|
||||
print("\n" + "=" * 70)
|
||||
print("COVERAGE COMPARISON REPORT")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n## Test Counts")
|
||||
print(f"{'Category':<20} {'main':<15} {'PR':<15} {'Diff':<10}")
|
||||
print("-" * 60)
|
||||
|
||||
for category in ["passed", "skipped", "failed", "errors"]:
|
||||
main_count = len(main_results[category])
|
||||
pr_count = len(pr_results[category])
|
||||
diff = pr_count - main_count
|
||||
diff_str = f"+{diff}" if diff > 0 else str(diff)
|
||||
print(f"{category:<20} {main_count:<15} {pr_count:<15} {diff_str:<10}")
|
||||
|
||||
main_tests = set(main_results["passed"] + [t[0] for t in main_results["skipped"]])
|
||||
pr_tests = set(pr_results["passed"] + [t[0] for t in pr_results["skipped"]])
|
||||
|
||||
missing_in_pr = main_tests - pr_tests
|
||||
new_in_pr = pr_tests - main_tests
|
||||
|
||||
if missing_in_pr:
|
||||
print("\n## Tests in main but MISSING in PR:")
|
||||
for test in sorted(missing_in_pr):
|
||||
print(f" - {test}")
|
||||
|
||||
if new_in_pr:
|
||||
print("\n## NEW tests in PR (not in main):")
|
||||
for test in sorted(new_in_pr):
|
||||
print(f" + {test}")
|
||||
|
||||
print("\n## Skipped Tests Comparison")
|
||||
main_skipped = {t[0]: t[1] for t in main_results["skipped"]}
|
||||
pr_skipped = {t[0]: t[1] for t in pr_results["skipped"]}
|
||||
|
||||
newly_skipped = set(pr_skipped.keys()) - set(main_skipped.keys())
|
||||
no_longer_skipped = set(main_skipped.keys()) - set(pr_skipped.keys())
|
||||
|
||||
if newly_skipped:
|
||||
print("\nNewly skipped in PR:")
|
||||
for test in sorted(newly_skipped):
|
||||
print(f" - {test}: {pr_skipped.get(test, 'unknown reason')}")
|
||||
|
||||
if no_longer_skipped:
|
||||
print("\nNo longer skipped in PR (now running):")
|
||||
for test in sorted(no_longer_skipped):
|
||||
print(f" + {test}")
|
||||
|
||||
if not newly_skipped and not no_longer_skipped:
|
||||
print("\nNo changes in skipped tests.")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
|
||||
def main():
    """Compare test results for TEST_FILE across the branches in BRANCHES.

    Stashes local changes, runs the test file on each branch, prints a
    comparison report, then restores the original branch and the stash.

    Fix vs. original: the `finally` block popped the stash *before* checking
    out the original branch, so an exception raised mid-loop would pop the
    stash onto whatever branch happened to be checked out, potentially
    conflicting or applying the changes to the wrong tree. We now return to
    the original branch first, then pop.
    """
    original_branch = get_current_branch()
    print(f"Current branch: {original_branch}")

    results = {}

    print("Stashing uncommitted changes...")
    stash_changes()

    try:
        for branch in BRANCHES:
            print(f"\n--- Analyzing branch: {branch} ---")

            if not checkout_branch(branch):
                print(f"Skipping {branch}")
                continue

            print(f"Collecting and running tests from {TEST_FILE}...")
            results[branch] = run_tests_verbose(TEST_FILE)

            print(f" Passed: {len(results[branch]['passed'])}")
            print(f" Skipped: {len(results[branch]['skipped'])}")
            print(f" Failed: {len(results[branch]['failed'])}")

        if "main" in results and "model-test-refactor" in results:
            compare_results(results["main"], results["model-test-refactor"])
        else:
            print("Could not compare - missing results from one or both branches")

    finally:
        # Return to the original branch BEFORE popping: the stash was created
        # there, and popping on another branch can conflict or mis-apply.
        checkout_branch(original_branch)
        print("\nRestoring stashed changes...")
        pop_stash()
|
||||
|
||||
|
||||
# Entry point: run the branch-comparison workflow when executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -1,403 +0,0 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
from diffusers import AutoencoderRAE
|
||||
|
||||
|
||||
# Decoder ViT geometries keyed by size name: hidden width, MLP width,
# attention heads, and transformer depth.
DECODER_CONFIGS = {
    "ViTB": {
        "decoder_hidden_size": 768,
        "decoder_intermediate_size": 3072,
        "decoder_num_attention_heads": 12,
        "decoder_num_hidden_layers": 12,
    },
    "ViTL": {
        "decoder_hidden_size": 1024,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 24,
    },
    "ViTXL": {
        "decoder_hidden_size": 1152,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 28,
    },
}

# Default Hugging Face model id for each supported encoder family; used when
# --encoder_name_or_path is not given.
ENCODER_DEFAULT_NAME_OR_PATH = {
    "dinov2": "facebook/dinov2-with-registers-base",
    "siglip2": "google/siglip2-base-patch16-256",
    "mae": "facebook/vit-mae-base",
}

# Fallback encoder hidden sizes, used if the HF config cannot be read.
ENCODER_HIDDEN_SIZE = {
    "dinov2": 768,
    "siglip2": 768,
    "mae": 768,
}

# Fallback encoder patch sizes, used if the HF config cannot be read.
ENCODER_PATCH_SIZE = {
    "dinov2": 14,
    "siglip2": 16,
    "mae": 16,
}

# Repo subdirectory (per encoder type) under which decoder variant folders live.
DEFAULT_DECODER_SUBDIR = {
    "dinov2": "decoders/dinov2/wReg_base",
    "mae": "decoders/mae/base_p16",
    "siglip2": "decoders/siglip2/base_p16_i256",
}

# Repo subdirectory (per encoder type) under which latent-stats folders live.
DEFAULT_STATS_SUBDIR = {
    "dinov2": "stats/dinov2/wReg_base",
    "mae": "stats/mae/base_p16",
    "siglip2": "stats/siglip2/base_p16_i256",
}

# Checkpoint filenames probed, in order, inside a decoder variant directory.
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
# Stats filenames probed inside a dataset directory.
STATS_FILE_CANDIDATES = ("stat.pt",)
|
||||
|
||||
|
||||
def dataset_case_candidates(name: str) -> tuple[str, ...]:
    """Return casing variants of *name* (plus ImageNet defaults) to probe in the repo."""
    variants = [name, name.lower(), name.upper(), name.title()]
    variants.extend(["imagenet1k", "ImageNet1k"])
    return tuple(variants)
|
||||
|
||||
|
||||
class RepoAccessor:
|
||||
def __init__(self, repo_or_path: str, cache_dir: str | None = None):
|
||||
self.repo_or_path = repo_or_path
|
||||
self.cache_dir = cache_dir
|
||||
self.local_root: Path | None = None
|
||||
self.repo_id: str | None = None
|
||||
self.repo_files: set[str] | None = None
|
||||
|
||||
root = Path(repo_or_path)
|
||||
if root.exists() and root.is_dir():
|
||||
self.local_root = root
|
||||
else:
|
||||
self.repo_id = repo_or_path
|
||||
self.repo_files = set(HfApi().list_repo_files(repo_or_path))
|
||||
|
||||
def exists(self, relative_path: str) -> bool:
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
if self.local_root is not None:
|
||||
return (self.local_root / relative_path).is_file()
|
||||
return relative_path in self.repo_files
|
||||
|
||||
def fetch(self, relative_path: str) -> Path:
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
if self.local_root is not None:
|
||||
return self.local_root / relative_path
|
||||
downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir)
|
||||
return Path(downloaded)
|
||||
|
||||
|
||||
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
    """Unwrap common checkpoint containers and strip uniform key prefixes.

    Descends through "model" / "module" / "state_dict" wrapper dicts (in that
    order), then drops a leading "module." and subsequently "decoder." prefix
    when *every* key carries it.
    """
    inner = maybe_wrapped
    for wrapper_key in ("model", "module", "state_dict"):
        if isinstance(inner, dict) and isinstance(inner.get(wrapper_key), dict):
            inner = inner[wrapper_key]

    result = dict(inner)
    for prefix in ("module.", "decoder."):
        if result and all(name.startswith(prefix) for name in result):
            result = {name[len(prefix):]: tensor for name, tensor in result.items()}
    return result
|
||||
|
||||
|
||||
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Map official RAE decoder attention key layout to the diffusers Attention
    layout used by the AutoencoderRAE decoder.

    Example mappings:
    - `...attention.attention.query.*` -> `...attention.to_q.*`
    - `...attention.attention.key.*` -> `...attention.to_k.*`
    - `...attention.attention.value.*` -> `...attention.to_v.*`
    - `...attention.output.dense.*` -> `...attention.to_out.0.*`
    """
    substitutions = (
        (".attention.attention.query.", ".attention.to_q."),
        (".attention.attention.key.", ".attention.to_k."),
        (".attention.attention.value.", ".attention.to_v."),
        (".attention.output.dense.", ".attention.to_out.0."),
    )
    remapped: dict[str, Any] = {}
    for key, value in state_dict.items():
        for old_fragment, new_fragment in substitutions:
            key = key.replace(old_fragment, new_fragment)
        remapped[key] = value
    return remapped
|
||||
|
||||
|
||||
def resolve_decoder_file(
    accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
) -> str:
    """Resolve the repo-relative decoder checkpoint path, honoring an explicit override.

    Raises FileNotFoundError when the override (or every default candidate) is absent.
    """
    if decoder_checkpoint is not None:
        if not accessor.exists(decoder_checkpoint):
            raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
        return decoder_checkpoint

    base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
    for filename in DECODER_FILE_CANDIDATES:
        candidate = f"{base}/{filename}"
        if accessor.exists(candidate):
            return candidate

    raise FileNotFoundError(
        f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
    )
|
||||
|
||||
|
||||
def resolve_stats_file(
    accessor: RepoAccessor,
    encoder_type: str,
    dataset_name: str,
    stats_checkpoint: str | None,
) -> str | None:
    """Resolve the repo-relative latent-stats path, or None when none can be found.

    Raises FileNotFoundError only when an explicit override path is missing.
    """
    if stats_checkpoint is not None:
        if not accessor.exists(stats_checkpoint):
            raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
        return stats_checkpoint

    base = DEFAULT_STATS_SUBDIR[encoder_type]
    for dataset in dataset_case_candidates(dataset_name):
        for filename in STATS_FILE_CANDIDATES:
            candidate = f"{base}/{dataset}/{filename}"
            if accessor.exists(candidate):
                return candidate

    # Stats are optional: callers proceed without latent normalization if absent.
    return None
|
||||
|
||||
|
||||
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
|
||||
if not isinstance(stats_obj, dict):
|
||||
return None, None
|
||||
|
||||
if "latents_mean" in stats_obj or "latents_std" in stats_obj:
|
||||
return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
|
||||
|
||||
mean = stats_obj.get("mean", None)
|
||||
var = stats_obj.get("var", None)
|
||||
if mean is None and var is None:
|
||||
return None, None
|
||||
|
||||
latents_std = None
|
||||
if var is not None:
|
||||
if isinstance(var, torch.Tensor):
|
||||
latents_std = torch.sqrt(var + 1e-5)
|
||||
else:
|
||||
latents_std = torch.sqrt(torch.tensor(var) + 1e-5)
|
||||
return mean, latents_std
|
||||
|
||||
|
||||
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
|
||||
"""Remove final layernorm weight/bias from encoder state dict.
|
||||
|
||||
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
|
||||
Stripping these keys means the model keeps its default init values, which
|
||||
is functionally equivalent to setting elementwise_affine=False.
|
||||
"""
|
||||
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
|
||||
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
|
||||
|
||||
|
||||
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
|
||||
"""Download the HF encoder and extract the state dict for the inner model."""
|
||||
if encoder_type == "dinov2":
|
||||
from transformers import Dinov2WithRegistersModel
|
||||
|
||||
hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
|
||||
sd = hf_model.state_dict()
|
||||
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
|
||||
elif encoder_type == "siglip2":
|
||||
from transformers import SiglipModel
|
||||
|
||||
# SiglipModel.vision_model is a SiglipVisionTransformer.
|
||||
# Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
|
||||
# under .vision_model, so we add the prefix to match the diffusers key layout.
|
||||
hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
|
||||
sd = {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()}
|
||||
return _strip_final_layernorm_affine(sd, prefix="vision_model.post_layernorm.")
|
||||
elif encoder_type == "mae":
|
||||
from transformers import ViTMAEForPreTraining
|
||||
|
||||
hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
|
||||
sd = hf_model.state_dict()
|
||||
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder_type: {encoder_type}")
|
||||
|
||||
|
||||
def convert(args: argparse.Namespace) -> None:
    """Convert an official RAE checkpoint into a diffusers `AutoencoderRAE` and save it.

    Steps: resolve the decoder/stats files inside the source repo, load and
    remap the decoder weights, pull encoder weights and normalization stats
    from the HF encoder, assemble a full state dict, load it into a
    meta-initialized model with `assign=True`, and `save_pretrained` the
    result (optionally verifying a round-trip load).
    """
    accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
    encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]

    # Resolve repo-relative file paths before downloading anything.
    decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
    stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)

    print(f"Using decoder checkpoint: {decoder_relpath}")
    if stats_relpath is not None:
        print(f"Using stats checkpoint: {stats_relpath}")
    else:
        print("No stats checkpoint found; conversion will proceed without latent stats.")

    # --dry_run only resolves and reports the selected files.
    if args.dry_run:
        return

    # Load the decoder checkpoint and normalize its key layout for diffusers.
    decoder_path = accessor.fetch(decoder_relpath)
    decoder_obj = torch.load(decoder_path, map_location="cpu")
    decoder_state_dict = unwrap_state_dict(decoder_obj)
    decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)

    # Latent statistics are optional.
    latents_mean, latents_std = None, None
    if stats_relpath is not None:
        stats_path = accessor.fetch(stats_relpath)
        stats_obj = torch.load(stats_path, map_location="cpu")
        latents_mean, latents_std = extract_latent_stats(stats_obj)

    decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]

    # Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
    from transformers import AutoConfig, AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
    encoder_norm_mean = list(proc.image_mean)
    encoder_norm_std = list(proc.image_std)

    # Read encoder hidden size and patch size from HF config
    encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
    encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
    try:
        hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
        # For models like SigLIP that nest vision config
        if hasattr(hf_config, "vision_config"):
            hf_config = hf_config.vision_config
        encoder_hidden_size = hf_config.hidden_size
        encoder_patch_size = hf_config.patch_size
    except Exception:
        # Best-effort: fall back to the hard-coded per-encoder defaults above.
        pass

    # Load the actual encoder weights from HF to include in the saved model
    encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)

    # Build model on meta device to avoid double init overhead
    with torch.device("meta"):
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=encoder_hidden_size,
            encoder_patch_size=encoder_patch_size,
            encoder_input_size=args.encoder_input_size,
            patch_size=args.patch_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            encoder_norm_mean=encoder_norm_mean,
            encoder_norm_std=encoder_norm_std,
            decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
            decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
            decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
            decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
            latents_mean=latents_mean,
            latents_std=latents_std,
            scaling_factor=args.scaling_factor,
        )

    # Assemble full state dict and load with assign=True
    full_state_dict = {}

    # Encoder weights (prefixed with "encoder.")
    for k, v in encoder_state_dict.items():
        full_state_dict[f"encoder.{k}"] = v

    # Decoder weights (prefixed with "decoder.")
    for k, v in decoder_state_dict.items():
        full_state_dict[f"decoder.{k}"] = v

    # Buffers from config
    full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
    full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
    if latents_mean is not None:
        latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
        full_state_dict["_latents_mean"] = latents_mean_t
    else:
        full_state_dict["_latents_mean"] = torch.zeros(1)
    if latents_std is not None:
        latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
        full_state_dict["_latents_std"] = latents_std_t
    else:
        full_state_dict["_latents_std"] = torch.ones(1)

    # assign=True moves the meta-device parameters onto the loaded tensors.
    model.load_state_dict(full_state_dict, strict=False, assign=True)

    # Verify no critical keys are missing
    model_keys = {name for name, _ in model.named_parameters()}
    model_keys |= {name for name, _ in model.named_buffers()}
    loaded_keys = set(full_state_dict.keys())
    missing = model_keys - loaded_keys
    # trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
    allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
    if missing - allowed_missing:
        print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")

    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_path)

    if args.verify_load:
        print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
        loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
        if not isinstance(loaded_model, AutoencoderRAE):
            raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
        print("Verification passed.")

    print(f"Saved converted AutoencoderRAE to: {output_path}")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the RAE conversion script."""
    parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")
    # Source / destination locations.
    parser.add_argument(
        "--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
    )
    parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")

    # Encoder selection.
    parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
    parser.add_argument(
        "--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
    )

    # Which decoder variant / stats dataset folder to probe inside the repo.
    parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
    parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")

    # Explicit file overrides (skip the default directory probing).
    parser.add_argument(
        "--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
    )
    parser.add_argument(
        "--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
    )

    # Model geometry / config knobs.
    parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL")
    parser.add_argument("--encoder_input_size", type=int, default=224)
    parser.add_argument("--patch_size", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=None)
    parser.add_argument("--num_channels", type=int, default=3)
    parser.add_argument("--scaling_factor", type=float, default=1.0)

    # Misc behavior flags.
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
    parser.add_argument(
        "--verify_load",
        action="store_true",
        help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
    )

    return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments and run the conversion."""
    convert(parse_args())
|
||||
|
||||
|
||||
# Entry point: run the conversion when executed as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -202,7 +202,6 @@ else:
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderRAE",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"BriaFiboTransformer2DModel",
|
||||
@@ -228,7 +227,6 @@ else:
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"GlmImageTransformer2DModel",
|
||||
"HeliosTransformer3DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -361,8 +359,6 @@ else:
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
"FlowMatchHeunDiscreteScheduler",
|
||||
"FlowMatchLCMScheduler",
|
||||
"HeliosDMDScheduler",
|
||||
"HeliosScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
@@ -519,8 +515,6 @@ else:
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"GlmImagePipeline",
|
||||
"HeliosPipeline",
|
||||
"HeliosPyramidPipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -975,7 +969,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderRAE,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
@@ -1001,7 +994,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HeliosTransformer3DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -1130,8 +1122,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FlowMatchHeunDiscreteScheduler,
|
||||
FlowMatchLCMScheduler,
|
||||
HeliosDMDScheduler,
|
||||
HeliosScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
@@ -1267,8 +1257,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
GlmImagePipeline,
|
||||
HeliosPipeline,
|
||||
HeliosPyramidPipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
|
||||
@@ -89,6 +89,8 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
|
||||
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
||||
# with open(CONFIG, "w") as f:
|
||||
# json.dump(automap, f)
|
||||
with open("requirements.txt", "w") as f:
|
||||
f.write("")
|
||||
|
||||
def _choose_block(self, candidates, chosen=None):
|
||||
for cls, base in candidates:
|
||||
|
||||
@@ -107,38 +107,6 @@ class ConfigMixin:
|
||||
has_compatibles = False
|
||||
|
||||
_deprecated_kwargs = []
|
||||
_auto_class = None
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoModel"):
|
||||
"""
|
||||
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
|
||||
trust_remote_code=True)`.
|
||||
|
||||
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
|
||||
to this class's module and class name.
|
||||
|
||||
Args:
|
||||
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
|
||||
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
|
||||
Currently only `"AutoModel"` is supported.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin): ...
|
||||
|
||||
|
||||
MyCustomModel.register_for_auto_class("AutoModel")
|
||||
```
|
||||
"""
|
||||
if auto_class != "AutoModel":
|
||||
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
@@ -653,12 +621,6 @@ class ConfigMixin:
|
||||
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||
_ = config_dict.pop("_pre_quantization_dtype", None)
|
||||
|
||||
if getattr(self, "_auto_class", None) is not None:
|
||||
module = self.__class__.__module__.split(".")[-1]
|
||||
auto_map = config_dict.get("auto_map", {})
|
||||
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
|
||||
config_dict["auto_map"] = auto_map
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: str | os.PathLike):
|
||||
|
||||
@@ -307,17 +307,6 @@ class GroupOffloadingHook(ModelHook):
|
||||
if self.group.onload_leader == module:
|
||||
if self.group.onload_self:
|
||||
self.group.onload_()
|
||||
else:
|
||||
# onload_self=False means this group relies on prefetching from a previous group.
|
||||
# However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
|
||||
# the prefetch chain may not cover them if they were absent during the first forward pass
|
||||
# when the execution order was traced. In that case, their weights remain on offload_device,
|
||||
# so we fall back to a synchronous onload here.
|
||||
params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters)
|
||||
if params and params[0].device == self.group.offload_device:
|
||||
self.group.onload_()
|
||||
if self.group.stream is not None:
|
||||
self.group.stream.synchronize()
|
||||
|
||||
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
||||
if should_onload_next_group:
|
||||
|
||||
@@ -78,7 +78,6 @@ if is_torch_available():
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
"HeliosLoraLoaderMixin",
|
||||
"KandinskyLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
@@ -119,7 +118,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4LoraLoaderMixin,
|
||||
Flux2LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HeliosLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
KandinskyLoraLoaderMixin,
|
||||
|
||||
@@ -2519,13 +2519,6 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
# Normalize ZImage-specific dot-separated module names to underscore form so they
|
||||
# match the diffusers model parameter names (context_refiner, noise_refiner).
|
||||
state_dict = {
|
||||
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
@@ -2536,18 +2529,19 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
|
||||
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
if has_non_diffusers_lora_id:
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for k in all_keys:
|
||||
if k.endswith(down_key):
|
||||
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
|
||||
@@ -2560,69 +2554,13 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[diffusers_down_key] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
# Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
|
||||
# Already in diffusers format (lora_A/lora_B), just pop
|
||||
elif has_diffusers_lora_id:
|
||||
for k in all_keys:
|
||||
if k.endswith(a_key):
|
||||
diffusers_up_key = k.replace(a_key, b_key)
|
||||
alpha_key = k.replace(a_key, ".alpha")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(diffusers_up_key)
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[k] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
# Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
|
||||
# Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
|
||||
# lora weight names and also include redundant keys:
|
||||
# - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
|
||||
# - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
|
||||
# - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
|
||||
lora_dot_down_key = ".lora.down.weight"
|
||||
lora_dot_up_key = ".lora.up.weight"
|
||||
has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)
|
||||
|
||||
if has_lora_dot_format:
|
||||
dot_keys = list(state_dict.keys())
|
||||
for k in dot_keys:
|
||||
if lora_dot_down_key not in k:
|
||||
continue
|
||||
if k not in state_dict:
|
||||
continue # already popped by a prior iteration
|
||||
|
||||
base = k[: -len(lora_dot_down_key)]
|
||||
|
||||
# Skip combined "qkv" projection — individual to.q/k/v keys are also present.
|
||||
if base.endswith(".qkv"):
|
||||
if a_key in k or b_key in k:
|
||||
converted_state_dict[k] = state_dict.pop(k)
|
||||
elif ".alpha" in k:
|
||||
state_dict.pop(k)
|
||||
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
|
||||
state_dict.pop(base + ".alpha", None)
|
||||
continue
|
||||
|
||||
# Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
|
||||
if re.search(r"\.out$", base) and ".to_out" not in base:
|
||||
state_dict.pop(k)
|
||||
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
|
||||
continue
|
||||
|
||||
# Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
|
||||
norm_k = re.sub(
|
||||
r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
|
||||
r".to_\1" + lora_dot_down_key,
|
||||
k,
|
||||
)
|
||||
norm_base = norm_k[: -len(lora_dot_down_key)]
|
||||
alpha_key = norm_base + ".alpha"
|
||||
|
||||
diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
|
||||
diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[diffusers_down] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up] = up_weight * scale_up
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
@@ -3440,207 +3440,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HeliosLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
if any(k.startswith("diffusion_model.") for k in state_dict):
|
||||
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
|
||||
elif any(k.startswith("lora_unet_") for k in state_dict):
|
||||
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: str | None = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: dict | None = None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
|
||||
|
||||
@@ -51,7 +51,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
|
||||
@@ -49,7 +49,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
@@ -101,7 +100,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
@@ -169,7 +167,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderRAE,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
@@ -215,7 +212,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HeliosTransformer3DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
|
||||
@@ -18,7 +18,6 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_rae import AutoencoderRAE
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .vq_model import VQModel
|
||||
|
||||
@@ -1,692 +0,0 @@
|
||||
# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import sqrt
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.import_utils import is_transformers_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import (
|
||||
Dinov2WithRegistersConfig,
|
||||
Dinov2WithRegistersModel,
|
||||
SiglipVisionConfig,
|
||||
SiglipVisionModel,
|
||||
ViTMAEConfig,
|
||||
ViTMAEModel,
|
||||
)
|
||||
|
||||
from ..activations import get_activation
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import get_2d_sincos_pos_embed
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-encoder forward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
# Each function takes the raw transformers model + images and returns patch
|
||||
# tokens of shape (B, N, C), stripping CLS / register tokens as needed.
|
||||
|
||||
|
||||
def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
|
||||
outputs = model(images, output_hidden_states=True)
|
||||
unused_token_num = 5 # 1 CLS + 4 register tokens
|
||||
return outputs.last_hidden_state[:, unused_token_num:]
|
||||
|
||||
|
||||
def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
|
||||
outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
|
||||
def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
h, w = images.shape[2], images.shape[3]
|
||||
patch_num = int(h * w // patch_size**2)
|
||||
if patch_num * patch_size**2 != h * w:
|
||||
raise ValueError("Image size should be divisible by patch size.")
|
||||
noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
|
||||
outputs = model(images, noise, interpolate_pos_encoding=True)
|
||||
return outputs.last_hidden_state[:, 1:] # remove cls token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder construction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_encoder(
|
||||
encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
|
||||
) -> nn.Module:
|
||||
"""Build a frozen encoder from config (no pretrained download)."""
|
||||
num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64
|
||||
|
||||
if encoder_type == "dinov2":
|
||||
config = Dinov2WithRegistersConfig(
|
||||
hidden_size=hidden_size,
|
||||
patch_size=patch_size,
|
||||
image_size=518,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
)
|
||||
model = Dinov2WithRegistersModel(config)
|
||||
# RAE strips the final layernorm affine params (identity LN). Remove them from
|
||||
# the architecture so `from_pretrained` doesn't leave them on the meta device.
|
||||
model.layernorm.weight = None
|
||||
model.layernorm.bias = None
|
||||
elif encoder_type == "siglip2":
|
||||
config = SiglipVisionConfig(
|
||||
hidden_size=hidden_size,
|
||||
patch_size=patch_size,
|
||||
image_size=256,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
)
|
||||
model = SiglipVisionModel(config)
|
||||
# See dinov2 comment above.
|
||||
model.vision_model.post_layernorm.weight = None
|
||||
model.vision_model.post_layernorm.bias = None
|
||||
elif encoder_type == "mae":
|
||||
config = ViTMAEConfig(
|
||||
hidden_size=hidden_size,
|
||||
patch_size=patch_size,
|
||||
image_size=224,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
mask_ratio=0.0,
|
||||
)
|
||||
model = ViTMAEModel(config)
|
||||
# See dinov2 comment above.
|
||||
model.layernorm.weight = None
|
||||
model.layernorm.bias = None
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
_ENCODER_FORWARD_FNS = {
|
||||
"dinov2": _dinov2_encoder_forward,
|
||||
"siglip2": _siglip2_encoder_forward,
|
||||
"mae": _mae_encoder_forward,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAEDecoderOutput(BaseOutput):
|
||||
"""
|
||||
Output of `RAEDecoder`.
|
||||
|
||||
Args:
|
||||
logits (`torch.Tensor`):
|
||||
Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`.
|
||||
"""
|
||||
|
||||
logits: torch.Tensor
|
||||
|
||||
|
||||
class ViTMAEIntermediate(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(hidden_size, intermediate_size)
|
||||
self.intermediate_act_fn = get_activation(hidden_act)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTMAEOutput(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(intermediate_size, hidden_size)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = hidden_states + input_tensor
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTMAELayer(nn.Module):
|
||||
"""
|
||||
This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
intermediate_size: int,
|
||||
qkv_bias: bool = True,
|
||||
layer_norm_eps: float = 1e-12,
|
||||
hidden_dropout_prob: float = 0.0,
|
||||
attention_probs_dropout_prob: float = 0.0,
|
||||
hidden_act: str = "gelu",
|
||||
):
|
||||
super().__init__()
|
||||
if hidden_size % num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
|
||||
)
|
||||
self.attention = Attention(
|
||||
query_dim=hidden_size,
|
||||
heads=num_attention_heads,
|
||||
dim_head=hidden_size // num_attention_heads,
|
||||
dropout=attention_probs_dropout_prob,
|
||||
bias=qkv_bias,
|
||||
)
|
||||
self.intermediate = ViTMAEIntermediate(
|
||||
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
|
||||
)
|
||||
self.output = ViTMAEOutput(
|
||||
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
|
||||
)
|
||||
self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
attention_output = self.attention(self.layernorm_before(hidden_states))
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
layer_output = self.intermediate(layer_output)
|
||||
layer_output = self.output(layer_output, hidden_states)
|
||||
return layer_output
|
||||
|
||||
|
||||
class RAEDecoder(nn.Module):
|
||||
"""Lightweight RAE decoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
decoder_hidden_size: int = 512,
|
||||
decoder_num_hidden_layers: int = 8,
|
||||
decoder_num_attention_heads: int = 16,
|
||||
decoder_intermediate_size: int = 2048,
|
||||
num_patches: int = 256,
|
||||
patch_size: int = 16,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
layer_norm_eps: float = 1e-12,
|
||||
hidden_dropout_prob: float = 0.0,
|
||||
attention_probs_dropout_prob: float = 0.0,
|
||||
hidden_act: str = "gelu",
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_hidden_size = decoder_hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
|
||||
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
|
||||
|
||||
self.decoder_layers = nn.ModuleList(
|
||||
[
|
||||
ViTMAELayer(
|
||||
hidden_size=decoder_hidden_size,
|
||||
num_attention_heads=decoder_num_attention_heads,
|
||||
intermediate_size=decoder_intermediate_size,
|
||||
qkv_bias=qkv_bias,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
hidden_act=hidden_act,
|
||||
)
|
||||
for _ in range(decoder_num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)
|
||||
self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self._initialize_weights(num_patches)
|
||||
self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
|
||||
|
||||
def _initialize_weights(self, num_patches: int):
|
||||
# Skip initialization when parameters are on meta device (e.g. during
|
||||
# accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
|
||||
# The weights are initialized.
|
||||
if self.decoder_pos_embed.device.type == "meta":
|
||||
return
|
||||
|
||||
grid_size = int(num_patches**0.5)
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
self.decoder_pos_embed.shape[-1],
|
||||
grid_size,
|
||||
cls_token=True,
|
||||
extra_tokens=1,
|
||||
output_type="pt",
|
||||
device=self.decoder_pos_embed.device,
|
||||
)
|
||||
self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
|
||||
embeddings_positions = embeddings.shape[1] - 1
|
||||
num_positions = self.decoder_pos_embed.shape[1] - 1
|
||||
|
||||
class_pos_embed = self.decoder_pos_embed[:, 0, :]
|
||||
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
|
||||
dim = self.decoder_pos_embed.shape[-1]
|
||||
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
|
||||
patch_pos_embed = F.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(1, embeddings_positions / num_positions),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, l, c = x.shape
|
||||
if l == self.num_patches:
|
||||
return x
|
||||
h = w = int(l**0.5)
|
||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||
target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
|
||||
x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
|
||||
x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
|
||||
return x
|
||||
|
||||
def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
|
||||
patch_size, num_channels = self.patch_size, self.num_channels
|
||||
original_image_size = (
|
||||
original_image_size if original_image_size is not None else (self.image_size, self.image_size)
|
||||
)
|
||||
original_height, original_width = original_image_size
|
||||
num_patches_h = original_height // patch_size
|
||||
num_patches_w = original_width // patch_size
|
||||
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
|
||||
raise ValueError(
|
||||
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
|
||||
)
|
||||
|
||||
batch_size = patchified_pixel_values.shape[0]
|
||||
patchified_pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size,
|
||||
num_patches_h,
|
||||
num_patches_w,
|
||||
patch_size,
|
||||
patch_size,
|
||||
num_channels,
|
||||
)
|
||||
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
|
||||
pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size,
|
||||
num_channels,
|
||||
num_patches_h * patch_size,
|
||||
num_patches_w * patch_size,
|
||||
)
|
||||
return pixel_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
drop_cls_token: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> RAEDecoderOutput | tuple[torch.Tensor]:
|
||||
x = self.decoder_embed(hidden_states)
|
||||
if drop_cls_token:
|
||||
x_ = x[:, 1:, :]
|
||||
x_ = self.interpolate_latent(x_)
|
||||
else:
|
||||
x_ = self.interpolate_latent(x)
|
||||
|
||||
cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
|
||||
x = torch.cat([cls_token, x_], dim=1)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
if not drop_cls_token:
|
||||
raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
|
||||
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
||||
else:
|
||||
decoder_pos_embed = self.decoder_pos_embed
|
||||
|
||||
hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
for layer_module in self.decoder_layers:
|
||||
hidden_states = layer_module(hidden_states)
|
||||
|
||||
hidden_states = self.decoder_norm(hidden_states)
|
||||
logits = self.decoder_pred(hidden_states)
|
||||
logits = logits[:, 1:, :]
|
||||
|
||||
if not return_dict:
|
||||
return (logits,)
|
||||
return RAEDecoderOutput(logits=logits)
|
||||
|
||||
|
||||
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.
|
||||
|
||||
This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
|
||||
images from learned representations.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Args:
|
||||
encoder_type (`str`, *optional*, defaults to `"dinov2"`):
|
||||
Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
|
||||
encoder_hidden_size (`int`, *optional*, defaults to `768`):
|
||||
Hidden size of the encoder model.
|
||||
encoder_patch_size (`int`, *optional*, defaults to `14`):
|
||||
Patch size of the encoder model.
|
||||
encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
|
||||
Number of hidden layers in the encoder model.
|
||||
patch_size (`int`, *optional*, defaults to `16`):
|
||||
Decoder patch size (used for unpatchify and decoder head).
|
||||
encoder_input_size (`int`, *optional*, defaults to `224`):
|
||||
Input size expected by the encoder.
|
||||
image_size (`int`, *optional*):
|
||||
Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
|
||||
RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
|
||||
encoder_patch_size) ** 2`.
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of input/output channels.
|
||||
encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
||||
Channel-wise mean for encoder input normalization (ImageNet defaults).
|
||||
encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
||||
Channel-wise std for encoder input normalization (ImageNet defaults).
|
||||
latents_mean (`list` or `tuple`, *optional*):
|
||||
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
|
||||
lists.
|
||||
latents_std (`list` or `tuple`, *optional*):
|
||||
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
|
||||
config-serializable lists.
|
||||
noise_tau (`float`, *optional*, defaults to `0.0`):
|
||||
Noise level for training (adds noise to latents during training).
|
||||
reshape_to_2d (`bool`, *optional*, defaults to `True`):
|
||||
Whether to reshape latents to 2D (B, C, H, W) format.
|
||||
use_encoder_loss (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use encoder hidden states in the loss (for advanced training).
|
||||
"""
|
||||
|
||||
# NOTE: gradient checkpointing is not wired up for this model yet.
|
||||
_supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["ViTMAELayer"]
|
||||
|
||||
@register_to_config
def __init__(
    self,
    encoder_type: str = "dinov2",
    encoder_hidden_size: int = 768,
    encoder_patch_size: int = 14,
    encoder_num_hidden_layers: int = 12,
    decoder_hidden_size: int = 512,
    decoder_num_hidden_layers: int = 8,
    decoder_num_attention_heads: int = 16,
    decoder_intermediate_size: int = 2048,
    patch_size: int = 16,
    encoder_input_size: int = 224,
    image_size: int | None = None,
    num_channels: int = 3,
    encoder_norm_mean: list | None = None,
    encoder_norm_std: list | None = None,
    latents_mean: list | tuple | torch.Tensor | None = None,
    latents_std: list | tuple | torch.Tensor | None = None,
    noise_tau: float = 0.0,
    reshape_to_2d: bool = True,
    use_encoder_loss: bool = False,
    scaling_factor: float = 1.0,
):
    super().__init__()

    # Validate the requested encoder family against the registered forward functions.
    if encoder_type not in _ENCODER_FORWARD_FNS:
        raise ValueError(
            f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
        )

    if encoder_input_size % encoder_patch_size != 0:
        raise ValueError(
            f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
        )

    # `patch_size` in the config is the *decoder* patch size (used by unpatchify).
    decoder_patch_size = patch_size
    if decoder_patch_size <= 0:
        raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")

    # The encoder emits a square grid of tokens; derive the grid side length.
    num_patches = (encoder_input_size // encoder_patch_size) ** 2
    grid = int(sqrt(num_patches))
    if grid * grid != num_patches:
        raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")

    # Decoder output resolution is fully determined by the token grid; an explicit
    # `image_size` is accepted only when it matches the derived value.
    derived_image_size = decoder_patch_size * grid
    if image_size is None:
        image_size = derived_image_size
    else:
        image_size = int(image_size)
        if image_size != derived_image_size:
            raise ValueError(
                f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
                f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
            )

    # Recursively convert tensors/tuples to plain lists so they can be JSON-serialized.
    def _to_config_compatible(value: Any) -> Any:
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().tolist()
        if isinstance(value, tuple):
            return [_to_config_compatible(v) for v in value]
        if isinstance(value, list):
            return [_to_config_compatible(v) for v in value]
        return value

    # Normalize a list/tuple/tensor input to a detached float tensor (or None).
    def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value, dtype=torch.float32)

    latents_std_tensor = _as_optional_tensor(latents_std)

    # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
    self.register_to_config(
        latents_mean=_to_config_compatible(latents_mean),
        latents_std=_to_config_compatible(latents_std),
    )

    # Frozen representation encoder (built from config, no downloads)
    self.encoder: nn.Module = _build_encoder(
        encoder_type=encoder_type,
        hidden_size=encoder_hidden_size,
        patch_size=encoder_patch_size,
        num_hidden_layers=encoder_num_hidden_layers,
    )
    self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
    # NOTE(review): duplicates the num_patches computation above; the value is identical.
    num_patches = (encoder_input_size // encoder_patch_size) ** 2

    # Encoder input normalization stats (ImageNet defaults)
    if encoder_norm_mean is None:
        encoder_norm_mean = [0.485, 0.456, 0.406]
    if encoder_norm_std is None:
        encoder_norm_std = [0.229, 0.224, 0.225]
    encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
    encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)

    # Persistent so checkpoints carry the exact preprocessing statistics.
    self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
    self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)

    # Latent normalization buffers (defaults are no-ops; actual values come from checkpoint)
    latents_mean_tensor = _as_optional_tensor(latents_mean)
    if latents_mean_tensor is None:
        latents_mean_tensor = torch.zeros(1)
    self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)

    if latents_std_tensor is None:
        latents_std_tensor = torch.ones(1)
    self.register_buffer("_latents_std", latents_std_tensor, persistent=True)

    # ViT-MAE style decoder
    self.decoder = RAEDecoder(
        hidden_size=int(encoder_hidden_size),
        decoder_hidden_size=int(decoder_hidden_size),
        decoder_num_hidden_layers=int(decoder_num_hidden_layers),
        decoder_num_attention_heads=int(decoder_num_attention_heads),
        decoder_intermediate_size=int(decoder_intermediate_size),
        num_patches=int(num_patches),
        patch_size=int(decoder_patch_size),
        num_channels=int(num_channels),
        image_size=int(image_size),
    )

    # Cached ints for fast access in the encode/decode paths.
    self.num_patches = int(num_patches)
    self.decoder_patch_size = int(decoder_patch_size)
    self.decoder_image_size = int(image_size)

    # Slicing support (batch dimension) similar to other diffusers autoencoders
    self.use_slicing = False
def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
    """Add Gaussian noise with a per-sample random magnitude in [0, noise_tau]."""
    # One sigma per batch element, broadcastable over all remaining dims.
    sigma_shape = (x.size(0),) + (1,) * (x.ndim - 1)
    sigma = torch.rand(sigma_shape, device=x.device, dtype=x.dtype, generator=generator)
    sigma = self.config.noise_tau * sigma
    noise = randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
    return x + sigma * noise
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
    """Resize input to the encoder's expected square resolution and normalize channels."""
    target = self.config.encoder_input_size
    height, width = x.shape[-2], x.shape[-1]
    if (height, width) != (target, target):
        # Bicubic resize matches the reference RAE preprocessing.
        x = F.interpolate(x, size=(target, target), mode="bicubic", align_corners=False)
    mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
    std = self.encoder_std.to(device=x.device, dtype=x.dtype)
    return (x - mean) / std
def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
    """Invert the encoder input normalization: x * std + mean."""
    std = self.encoder_std.to(device=x.device, dtype=x.dtype)
    mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
    return x.mul(std).add(mean)
def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
    """Standardize latents with the checkpoint-provided mean/std buffers."""
    mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
    std = self._latents_std.to(device=z.device, dtype=z.dtype)
    # Epsilon guards against division by a (near-)zero std.
    return (z - mean) / (std + 1e-5)
def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
    """Invert `_normalize_latents`, using the same epsilon-padded std."""
    mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
    std = self._latents_std.to(device=z.device, dtype=z.dtype)
    return z * (std + 1e-5) + mean
def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
    """Encode images into (optionally 2D-reshaped, normalized, scaled) latents."""
    x = self._resize_and_normalize(x)

    # MAE's forward additionally needs the patch size; the other encoders do not.
    if self.config.encoder_type == "mae":
        tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
    else:
        tokens = self._encoder_forward_fn(self.encoder, x)  # (B, N, C)

    # Training-time latent noising (no-op at inference or when noise_tau == 0).
    if self.training and self.config.noise_tau > 0:
        tokens = self._noising(tokens, generator=generator)

    if self.config.reshape_to_2d:
        # (B, N, C) -> (B, C, sqrt(N), sqrt(N)): requires a square token grid.
        b, n, c = tokens.shape
        side = int(sqrt(n))
        if side * side != n:
            raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
        z = tokens.transpose(1, 2).contiguous().view(b, c, side, side)  # (B, C, h, w)
    else:
        z = tokens

    z = self._normalize_latents(z)

    # Follow diffusers convention: optionally scale latents for diffusion
    if self.config.scaling_factor != 1.0:
        z = z * self.config.scaling_factor

    return z
@apply_forward_hook
def encode(
    self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> EncoderOutput | tuple[torch.Tensor]:
    """Encode a batch of images into latents, optionally one sample at a time."""
    if self.use_slicing and x.shape[0] > 1:
        # Slice along the batch dimension to bound peak memory.
        per_sample = [self._encode(part, generator=generator) for part in x.split(1)]
        latents = torch.cat(per_sample, dim=0)
    else:
        latents = self._encode(x, generator=generator)

    if return_dict:
        return EncoderOutput(latent=latents)
    return (latents,)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
    """Map latents back to image space through the ViT-MAE style decoder."""
    # Undo the scaling factor if it was applied at encode time.
    scaling = self.config.scaling_factor
    if scaling != 1.0:
        z = z / scaling

    z = self._denormalize_latents(z)

    if self.config.reshape_to_2d:
        # (B, C, H, W) -> (B, H*W, C) token sequence expected by the decoder.
        batch, channels = z.shape[0], z.shape[1]
        tokens = z.reshape(batch, channels, -1).transpose(1, 2).contiguous()
    else:
        tokens = z

    logits = self.decoder(tokens, return_dict=True).logits
    reconstruction = self.decoder.unpatchify(logits)
    return self._denormalize_image(reconstruction)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
    """Decode latents to images, optionally one sample at a time."""
    if self.use_slicing and z.shape[0] > 1:
        # Slice along the batch dimension to bound peak memory.
        per_sample = [self._decode(part) for part in z.split(1)]
        decoded = torch.cat(per_sample, dim=0)
    else:
        decoded = self._decode(z)

    if return_dict:
        return DecoderOutput(sample=decoded)
    return (decoded,)
def forward(
    self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> DecoderOutput | tuple[torch.Tensor]:
    """Full round trip: encode `sample` to latents, then decode back to images."""
    latents = self.encode(sample, return_dict=False, generator=generator)[0]
    reconstruction = self.decode(latents, return_dict=False)[0]
    if return_dict:
        return DecoderOutput(sample=reconstruction)
    return (reconstruction,)
@@ -28,7 +28,6 @@ if is_torch_available():
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_glm_image import GlmImageTransformer2DModel
|
||||
from .transformer_helios import HeliosTransformer3DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
|
||||
|
||||
@@ -1,814 +0,0 @@
|
||||
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def pad_for_3d_conv(x, kernel_size):
    """Right-pad a (B, C, T, H, W) tensor so T/H/W are divisible by the kernel dims.

    Args:
        x: 5D tensor laid out as (batch, channels, time, height, width).
        kernel_size: triple (pt, ph, pw) of divisors for the trailing three axes.

    Returns:
        The padded tensor; padding uses "replicate" mode and is applied only at
        the *end* of each axis, so existing values keep their positions.
    """
    pt, ph, pw = kernel_size
    # Only the last three dims matter; the original unpacked (and ignored) b, c too.
    t, h, w = x.shape[-3:]
    # `-t % pt` is the amount needed to reach the next multiple of pt (0 if aligned).
    pad_t = -t % pt
    pad_h = -h % ph
    pad_w = -w % pw
    # F.pad's tuple is ordered from the last dim backwards: (w_left, w_right, h..., t...).
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
def center_down_sample_3d(x, kernel_size):
    """Downsample a 5D tensor with non-overlapping 3D average pooling."""
    pooled = torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
    return pooled
def apply_rotary_emb_transposed(
    hidden_states: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply rotary position embeddings with an interleaved (even/odd) channel layout.

    `freqs_cis` holds [cos | sin] concatenated along the last dim; a singleton head
    axis is inserted just before the channel axis so frequencies broadcast over heads.
    """
    even, odd = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    rotated = torch.empty_like(hidden_states)
    # Standard 2D rotation applied pairwise to (even, odd) channel pairs.
    rotated[..., 0::2] = even * cos[..., 0::2] - odd * sin[..., 1::2]
    rotated[..., 1::2] = even * sin[..., 1::2] + odd * cos[..., 0::2]
    return rotated.type_as(hidden_states)
def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
    """Compute Q/K/V projections, honoring fused projection layers when present.

    For self-attention (`encoder_hidden_states is None`) keys and values come from
    `hidden_states`; for cross-attention they come from `encoder_hidden_states`.
    """
    kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states

    if not attn.fused_projections:
        # Plain per-tensor linears.
        return attn.to_q(hidden_states), attn.to_k(kv_source), attn.to_v(kv_source)

    if attn.is_cross_attention:
        # Cross-attention: only K and V share a fused linear.
        query = attn.to_q(hidden_states)
        key, value = attn.to_kv(kv_source).chunk(2, dim=-1)
        return query, key, value

    # Self-attention: a single fused linear produces Q, K and V.
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
    return query, key, value
class HeliosOutputNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int):
|
||||
temb = temb[:, -original_context_length:, :]
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device)
|
||||
hidden_states = hidden_states[:, -original_context_length:, :]
|
||||
hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HeliosAttnProcessor:
    """Default attention processor for `HeliosAttention`.

    Runs scaled dot-product attention through `dispatch_attention_fn`, with optional
    rotary embeddings and an optional learned amplification of history keys.
    """

    # Backend/parallel settings are injected externally (class-level hooks).
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # dispatch_attention_fn ultimately relies on torch 2.0's SDPA kernel.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: "HeliosAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        original_context_length: int | None = None,
    ) -> torch.Tensor:
        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

        # QK normalization (applied across all heads at once, before head split).
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # (B, S, H*D) -> (B, S, H, D)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:
            query = apply_rotary_emb_transposed(query, rotary_emb)
            key = apply_rotary_emb_transposed(key, rotary_emb)

        if not attn.is_cross_attention and attn.is_amplify_history:
            # Tokens before the current context window are treated as history.
            history_seq_len = hidden_states.shape[1] - original_context_length

            if history_seq_len > 0:
                # Learned, sigmoid-bounded key gain in (1, max_scale), history keys only.
                scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0)
                if attn.history_scale_mode == "per_head":
                    scale_key = scale_key.view(1, 1, -1, 1)
                key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            # Reference: https://github.com/huggingface/diffusers/pull/12909
            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
        )
        # (B, S, H, D) -> (B, S, H*D), cast back to the query dtype.
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class HeliosAttention(torch.nn.Module, AttentionModuleMixin):
    """Multi-head attention used by Helios transformer blocks.

    Supports self- and cross-attention, optional fused QKV/KV projections, and an
    optional learned amplification of history keys (see `HeliosAttnProcessor`).
    """

    _default_processor_cls = HeliosAttnProcessor
    _available_processors = [HeliosAttnProcessor]

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        dropout: float = 0.0,
        added_kv_proj_dim: int | None = None,
        cross_attention_dim_head: int | None = None,
        processor=None,
        is_cross_attention=None,
        is_amplify_history=False,
        history_scale_mode="per_head",  # [scalar, per_head]
    ):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.cross_attention_dim_head = cross_attention_dim_head
        # K/V projection width may differ from Q's in cross-attention.
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList(
            [
                torch.nn.Linear(self.inner_dim, dim, bias=True),
                torch.nn.Dropout(dropout),
            ]
        )
        # QK RMSNorm across the full (heads * dim_head) width.
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)

        # Optional extra K/V projections for an additional conditioning stream.
        self.add_k_proj = self.add_v_proj = None
        if added_kv_proj_dim is not None:
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)

        if is_cross_attention is not None:
            self.is_cross_attention = is_cross_attention
        else:
            # Infer from whether a distinct cross-attention head dim was given.
            self.is_cross_attention = cross_attention_dim_head is not None

        self.set_processor(processor)

        self.is_amplify_history = is_amplify_history
        if is_amplify_history:
            # Learnable gain for history keys: a single scalar, or one value per head.
            if history_scale_mode == "scalar":
                self.history_key_scale = nn.Parameter(torch.ones(1))
            elif history_scale_mode == "per_head":
                self.history_key_scale = nn.Parameter(torch.ones(heads))
            else:
                raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}")
            self.history_scale_mode = history_scale_mode
            # Upper bound of the sigmoid-gated key gain used by the processor.
            self.max_scale = 10.0

    def fuse_projections(self):
        """Concatenate Q/K/V (self-attn) or K/V (cross-attn) linears into one fused linear."""
        if getattr(self, "fused_projections", False):
            return  # already fused

        if not self.is_cross_attention:
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            # Create on meta to avoid allocating, then assign real tensors in place.
            with torch.device("meta"):
                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
            self.to_qkv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        else:
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        if self.added_kv_proj_dim is not None:
            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_added_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """Drop fused projection layers; the original per-tensor linears remain intact."""
        if not getattr(self, "fused_projections", False):
            return  # nothing to undo

        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")
        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")
        if hasattr(self, "to_added_kv"):
            delattr(self, "to_added_kv")

        self.fused_projections = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        original_context_length: int | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Delegate the actual attention computation to the configured processor.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            rotary_emb,
            original_context_length,
            **kwargs,
        )
class HeliosTimeTextEmbedding(nn.Module):
    """Joint timestep + text-conditioning embedder.

    Produces the base time embedding `temb`, a per-block modulation projection
    `timestep_proj`, and (optionally) projected text embeddings.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
    ):
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        is_return_encoder_hidden_states: bool = True,
    ):
        # Sinusoidal features for the timestep.
        timestep = self.timesteps_proj(timestep)

        # Match the embedder's parameter dtype; int8 is excluded (quantized weights).
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        # Project text embeddings only when requested and available.
        if encoder_hidden_states is not None and is_return_encoder_hidden_states:
            encoder_hidden_states = self.text_embedder(encoder_hidden_states)

        return temb, timestep_proj, encoder_hidden_states
|
||||
class HeliosRotaryPosEmbed(nn.Module):
    """3D (time/height/width) rotary position embedding for Helios.

    Produces concatenated [cos_t, cos_y, cos_x, sin_t, sin_y, sin_x] frequency maps
    laid out as (batch, channels, frames, height, width).
    """

    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta
        self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False)
        self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False)
        self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False)
        # Cache of spatial meshgrids keyed by (height, width, device_str).
        # NOTE: replaces the previous functools.lru_cache on the bound method, which
        # keyed on `self` and kept every module instance alive (ruff B019).
        self._meshgrid_cache = {}

    def _get_freqs_base(self, dim):
        # Standard RoPE inverse-frequency schedule: theta^(-2i/dim).
        return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))

    @torch.no_grad()
    def get_frequency_batched(self, freqs_base, pos):
        # (D/2,) x (B, T, H, W) -> (D/2, B, T, H, W); duplicate each frequency so the
        # channel axis matches the interleaved even/odd rotary layout.
        freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos)
        freqs = freqs.repeat_interleave(2, dim=0)
        return freqs.cos(), freqs.sin()

    @torch.no_grad()
    def _get_spatial_meshgrid(self, height, width, device_str):
        # Bounded per-instance cache (same capacity as the old lru_cache, maxsize=32),
        # released together with the module instead of leaking via a global cache.
        key = (height, width, device_str)
        cached = self._meshgrid_cache.get(key)
        if cached is None:
            device = torch.device(device_str)
            grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
            grid_x_coords = torch.arange(width, device=device, dtype=torch.float32)
            cached = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij")
            if len(self._meshgrid_cache) >= 32:
                # Simple FIFO eviction keeps the cache bounded.
                self._meshgrid_cache.pop(next(iter(self._meshgrid_cache)))
            self._meshgrid_cache[key] = cached
        return cached

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        batch_size = frame_indices.shape[0]
        num_frames = frame_indices.shape[1]

        frame_indices = frame_indices.to(device=device, dtype=torch.float32)
        grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device))

        # Broadcast temporal indices and spatial coordinates to (B, T, H, W).
        grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width)
        grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1)
        grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1)

        freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t)
        freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch)
        freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch)

        # Channel-first concat, then move batch to the front: (B, 2*(DT+DY+DX), T, H, W).
        result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0)

        return result.permute(1, 0, 2, 3, 4)
@maybe_allow_in_graph
class HeliosTransformerBlock(nn.Module):
    """Single Helios transformer block.

    AdaLN-modulated self-attention with rotary embeddings, text cross-attention
    (optionally excluding history tokens), and a gated feed-forward network.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: int | None = None,
        guidance_cross_attn: bool = False,
        is_amplify_history: bool = False,
        history_scale_mode: str = "per_head",  # [scalar, per_head]
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = HeliosAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            processor=HeliosAttnProcessor(),
            is_amplify_history=is_amplify_history,
            history_scale_mode=history_scale_mode,
        )

        # 2. Cross-attention
        self.attn2 = HeliosAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            cross_attention_dim_head=dim // num_heads,
            processor=HeliosAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # Learned base values for the 6 AdaLN modulation terms (shift/scale/gate x 2).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        # 4. Guidance cross-attention
        # When enabled, history tokens are split off and skip cross-attention.
        self.guidance_cross_attn = guidance_cross_attn

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        original_context_length: int | None = None,
    ) -> torch.Tensor:
        # temb is either per-token (4D: B, S, 6, D) or per-sample (3D: B, 6, D).
        if temb.ndim == 4:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb.float()
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)

        # 1. Self-attention (modulated pre-norm, gated residual; norm math in fp32)
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(
            norm_hidden_states,
            None,
            None,
            rotary_emb,
            original_context_length,
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        if self.guidance_cross_attn:
            # Split off history tokens so only the current context attends to text.
            history_seq_len = hidden_states.shape[1] - original_context_length

            history_hidden_states, hidden_states = torch.split(
                hidden_states, [history_seq_len, original_context_length], dim=1
            )
            norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states,
                None,
                None,
                original_context_length,
            )
            hidden_states = hidden_states + attn_output
            # Re-attach the untouched history tokens in their original position.
            hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1)
        else:
            norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states,
                None,
                None,
                original_context_length,
            )
            hidden_states = hidden_states + attn_output

        # 3. Feed-forward (modulated pre-norm, gated residual)
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states
class HeliosTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Helios model.

    Args:
        patch_size (`tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads in each transformer block.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str` or `None`, defaults to `"rms_norm_across_heads"`):
            Query/key normalization mode passed to each transformer block, or `None` to disable it.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_dim (`tuple[int]`, defaults to `(44, 42, 42)`):
            Per-axis (temporal, height, width) channel split for the rotary position embedding.
        rope_theta (`float`, defaults to `10000.0`):
            Base frequency for the rotary position embedding.
        guidance_cross_attn (`bool`, defaults to `True`):
            When `True`, transformer blocks split history tokens off before cross-attention so that only the
            current (non-history) tokens attend to the text embeddings.
        zero_history_timestep (`bool`, defaults to `True`):
            When `True`, history tokens receive conditioning computed at timestep 0 instead of the sampled
            timestep.
        has_multi_term_memory_patch (`bool`, defaults to `True`):
            Whether to create the short/mid/long history patch embedders (`patch_short`, `patch_mid`,
            `patch_long`).
        is_amplify_history (`bool`, defaults to `False`):
            Forwarded to each transformer block to control history amplification.
        history_scale_mode (`str`, defaults to `"per_head"`):
            One of `"scalar"` or `"per_head"`; forwarded to each transformer block.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = [
        "patch_embedding",
        "patch_short",
        "patch_mid",
        "patch_long",
        "condition_embedder",
        "norm",
    ]
    _no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"]
    _keep_in_fp32_modules = [
        "time_embedder",
        "scale_shift_table",
        "norm1",
        "norm2",
        "norm3",
        "history_key_scale",
    ]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["HeliosTransformerBlock"]
    _cp_plan = {
        "blocks.0": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.*": {
            "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
            "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(
        self,
        patch_size: tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: str | None = "rms_norm_across_heads",
        eps: float = 1e-6,
        added_kv_proj_dim: int | None = None,
        rope_dim: tuple[int, ...] = (44, 42, 42),
        rope_theta: float = 10000.0,
        guidance_cross_attn: bool = True,
        zero_history_timestep: bool = True,
        has_multi_term_memory_patch: bool = True,
        is_amplify_history: bool = False,
        history_scale_mode: str = "per_head",  # [scalar, per_head]
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Initial Multi Term Memory Patch
        # Mid/long embedders use 2x/4x coarser patches so older history is compressed more aggressively.
        self.zero_history_timestep = zero_history_timestep
        if has_multi_term_memory_patch:
            self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
            self.patch_mid = nn.Conv3d(
                in_channels,
                inner_dim,
                kernel_size=tuple(2 * p for p in patch_size),
                stride=tuple(2 * p for p in patch_size),
            )
            self.patch_long = nn.Conv3d(
                in_channels,
                inner_dim,
                kernel_size=tuple(4 * p for p in patch_size),
                stride=tuple(4 * p for p in patch_size),
            )

        # 3. Condition embeddings
        # time_proj_dim is 6 * inner_dim: one (shift, scale, gate) triple each for self-attention and FFN.
        self.condition_embedder = HeliosTimeTextEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
        )

        # 4. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                HeliosTransformerBlock(
                    inner_dim,
                    ffn_dim,
                    num_attention_heads,
                    qk_norm,
                    cross_attn_norm,
                    eps,
                    added_kv_proj_dim,
                    guidance_cross_attn=guidance_cross_attn,
                    is_amplify_history=is_amplify_history,
                    history_scale_mode=history_scale_mode,
                )
                for _ in range(num_layers)
            ]
        )

        # 5. Output norm & projection
        self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))

        self.gradient_checkpointing = False

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        # ------------ Stage 1 ------------
        indices_hidden_states=None,
        indices_latents_history_short=None,
        indices_latents_history_mid=None,
        indices_latents_history_long=None,
        latents_history_short=None,
        latents_history_mid=None,
        latents_history_long=None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        r"""
        Run the transformer over noisy video latents, optionally prepending short/mid/long history latents.

        Args:
            hidden_states (`torch.Tensor`):
                Noisy input latents of shape `(batch, in_channels, frames, height, width)`.
            timestep (`torch.LongTensor`):
                Diffusion timestep(s) used for conditioning.
            encoder_hidden_states (`torch.Tensor`):
                Text embeddings of dimension `text_dim`.
            indices_hidden_states / indices_latents_history_* (`torch.Tensor`, *optional*):
                Frame indices used to build rotary position embeddings for the corresponding latents.
            latents_history_short / latents_history_mid / latents_history_long (`torch.Tensor`, *optional*):
                History latents prepended (long, mid, short, current order along the sequence dimension).
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                Consumed by `apply_lora_scale` for LoRA scaling; not forwarded to the blocks here.

        Returns:
            `Transformer2DModelOutput` or `tuple`: the denoised latents of shape
            `(batch, out_channels, frames, height, width)`.
        """
        # 1. Input
        batch_size = hidden_states.shape[0]
        p_t, p_h, p_w = self.config.patch_size

        # 2. Process noisy latents
        hidden_states = self.patch_embedding(hidden_states)
        _, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape

        if indices_hidden_states is None:
            indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1)

        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        rotary_emb = self.rope(
            frame_indices=indices_hidden_states,
            height=post_patch_height,
            width=post_patch_width,
            device=hidden_states.device,
        )
        rotary_emb = rotary_emb.flatten(2).transpose(1, 2)
        original_context_length = hidden_states.shape[1]

        # 3. Process short history latents
        if latents_history_short is not None and indices_latents_history_short is not None:
            latents_history_short = self.patch_short(latents_history_short)
            _, _, _, H1, W1 = latents_history_short.shape
            latents_history_short = latents_history_short.flatten(2).transpose(1, 2)

            rotary_emb_history_short = self.rope(
                frame_indices=indices_latents_history_short,
                height=H1,
                width=W1,
                device=latents_history_short.device,
            )
            rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
            rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)

        # 4. Process mid history latents
        # NOTE(review): the mid/long branches reuse H1/W1 from the short-history branch; passing mid/long
        # history without short history would raise NameError — confirm callers always provide short history.
        if latents_history_mid is not None and indices_latents_history_mid is not None:
            latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
            latents_history_mid = self.patch_mid(latents_history_mid)
            latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)

            rotary_emb_history_mid = self.rope(
                frame_indices=indices_latents_history_mid,
                height=H1,
                width=W1,
                device=latents_history_mid.device,
            )
            rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))
            rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2))
            rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
            rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)

        # 5. Process long history latents
        if latents_history_long is not None and indices_latents_history_long is not None:
            latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
            latents_history_long = self.patch_long(latents_history_long)
            latents_history_long = latents_history_long.flatten(2).transpose(1, 2)

            rotary_emb_history_long = self.rope(
                frame_indices=indices_latents_history_long,
                height=H1,
                width=W1,
                device=latents_history_long.device,
            )
            rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4))
            rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4))
            rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
            rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)

        history_context_length = hidden_states.shape[1] - original_context_length

        # History tokens are conditioned at timestep 0 so they are treated as clean context.
        if indices_hidden_states is not None and self.zero_history_timestep:
            timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device)
            temb_t0, timestep_proj_t0, _ = self.condition_embedder(
                timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
            )
            temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
            timestep_proj_t0 = (
                timestep_proj_t0.unflatten(-1, (6, -1))
                .view(1, 6, 1, -1)
                .expand(batch_size, -1, history_context_length, -1)
            )

        temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))

        if indices_hidden_states is not None and not self.zero_history_timestep:
            main_repeat_size = hidden_states.shape[1]
        else:
            main_repeat_size = original_context_length
        temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
        timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)

        if indices_hidden_states is not None and self.zero_history_timestep:
            temb = torch.cat([temb_t0, temb], dim=1)
            timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)

        # Blocks expect the per-token conditioning laid out as (batch, seq, 6, dim).
        if timestep_proj.ndim == 4:
            timestep_proj = timestep_proj.permute(0, 2, 1, 3)

        # 6. Transformer blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        rotary_emb = rotary_emb.contiguous()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    original_context_length,
                )
        else:
            for block in self.blocks:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    original_context_length,
                )

        # 7. Normalization
        hidden_states = self.norm_out(hidden_states, temb, original_context_length)
        hidden_states = self.proj_out(hidden_states)

        # 8. Unpatchify
        # norm_out keeps only the current-context tokens, so the reshape below drops the history prefix.
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
||||
@@ -47,7 +47,6 @@ from .modular_pipeline_utils import (
|
||||
InputParam,
|
||||
InsertableDict,
|
||||
OutputParam,
|
||||
_validate_requirements,
|
||||
combine_inputs,
|
||||
combine_outputs,
|
||||
format_components,
|
||||
@@ -298,7 +297,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "modular_config.json"
|
||||
model_name = None
|
||||
_requirements: dict[str, str] | None = None
|
||||
_workflow_map = None
|
||||
|
||||
@classmethod
|
||||
@@ -413,9 +411,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
if "requirements" in config and config["requirements"] is not None:
|
||||
_ = _validate_requirements(config["requirements"])
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
@@ -440,13 +435,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||
|
||||
self.register_to_config(auto_map=auto_map)
|
||||
|
||||
# resolve requirements
|
||||
requirements = _validate_requirements(getattr(self, "_requirements", None))
|
||||
if requirements:
|
||||
self.register_to_config(requirements=requirements)
|
||||
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
config = dict(self.config)
|
||||
self._internal_dict = FrozenDict(config)
|
||||
@@ -668,15 +658,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
|
||||
def _requirements(self) -> dict[str, str]:
|
||||
requirements = {}
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if getattr(block, "_requirements", None):
|
||||
requirements[block_name] = block._requirements
|
||||
return requirements
|
||||
|
||||
# used for `__repr__`
|
||||
def _get_trigger_inputs(self) -> set:
|
||||
"""
|
||||
@@ -1266,14 +1247,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
@property
|
||||
def _requirements(self) -> dict[str, str]:
|
||||
requirements = {}
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if getattr(block, "_requirements", None):
|
||||
requirements[block_name] = block._requirements
|
||||
return requirements
|
||||
|
||||
|
||||
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
@@ -1412,15 +1385,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def outputs(self) -> list[str]:
|
||||
return next(reversed(self.sub_blocks.values())).intermediate_outputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
|
||||
def _requirements(self) -> dict[str, str]:
|
||||
requirements = {}
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if getattr(block, "_requirements", None):
|
||||
requirements[block_name] = block._requirements
|
||||
return requirements
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
for block_name, block in zip(self.block_names, self.block_classes):
|
||||
@@ -1743,8 +1707,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
|
||||
)
|
||||
|
||||
self._pretrained_model_name_or_path = pretrained_model_name_or_path
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -1921,7 +1883,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
update_model_card = kwargs.pop("update_model_card", False)
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
for component_name, component_spec in self._component_specs.items():
|
||||
@@ -1996,7 +1957,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
is_pipeline=True,
|
||||
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
|
||||
is_modular=True,
|
||||
update_model_card=update_model_card,
|
||||
)
|
||||
model_card = populate_model_card(model_card, tags=card_content["tags"])
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
@@ -2292,11 +2252,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
new_component_spec = current_component_spec
|
||||
if hasattr(self, name) and getattr(self, name) is not None:
|
||||
logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
|
||||
elif (
|
||||
current_component_spec.default_creation_method == "from_pretrained"
|
||||
and getattr(component, "_diffusers_load_id", None) is None
|
||||
):
|
||||
new_component_spec = ComponentSpec(name=name, type_hint=type(component))
|
||||
else:
|
||||
new_component_spec = ComponentSpec.from_component(name, component)
|
||||
|
||||
@@ -2368,49 +2323,17 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
elif "default" in value:
|
||||
# check if the default is specified
|
||||
component_load_kwargs[key] = value["default"]
|
||||
# Only pass trust_remote_code to components from the same repo as the pipeline.
|
||||
# When a user passes trust_remote_code=True, they intend to trust code from the
|
||||
# pipeline's repo, not from external repos referenced in modular_model_index.json.
|
||||
trust_remote_code_stripped = False
|
||||
if (
|
||||
"trust_remote_code" in component_load_kwargs
|
||||
and self._pretrained_model_name_or_path is not None
|
||||
and spec.pretrained_model_name_or_path != self._pretrained_model_name_or_path
|
||||
):
|
||||
component_load_kwargs.pop("trust_remote_code")
|
||||
trust_remote_code_stripped = True
|
||||
|
||||
if not spec.pretrained_model_name_or_path:
|
||||
logger.info(f"Skipping component `{name}`: no pretrained model path specified.")
|
||||
continue
|
||||
|
||||
try:
|
||||
components_to_register[name] = spec.load(**component_load_kwargs)
|
||||
except Exception:
|
||||
tb = traceback.format_exc()
|
||||
if trust_remote_code_stripped and "trust_remote_code" in tb:
|
||||
warning_msg = (
|
||||
f"Failed to load component `{name}` from external repository "
|
||||
f"`{spec.pretrained_model_name_or_path}`.\n\n"
|
||||
f"`trust_remote_code=True` was not forwarded to `{name}` because it comes from "
|
||||
f"a different repository than the pipeline (`{self._pretrained_model_name_or_path}`). "
|
||||
f"For safety, `trust_remote_code` is only forwarded to components from the same "
|
||||
f"repository as the pipeline.\n\n"
|
||||
f"You need to load this component manually with `trust_remote_code=True` and pass it "
|
||||
f"to the pipeline via `pipe.update_components()`. For example, if it is a custom model:\n\n"
|
||||
f' {name} = AutoModel.from_pretrained("{spec.pretrained_model_name_or_path}", trust_remote_code=True)\n'
|
||||
f" pipe.update_components({name}={name})\n"
|
||||
)
|
||||
else:
|
||||
warning_msg = (
|
||||
f"Failed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"{tb}"
|
||||
)
|
||||
logger.warning(warning_msg)
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
# Register all components at once
|
||||
self.register_components(**components_to_register)
|
||||
|
||||
@@ -22,12 +22,10 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
from ..utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -52,7 +50,11 @@ This modular pipeline is composed of the following blocks:
|
||||
|
||||
{components_description} {configs_section}
|
||||
|
||||
{io_specification_section}
|
||||
## Input/Output Specification
|
||||
|
||||
### Inputs {inputs_description}
|
||||
|
||||
### Outputs {outputs_description}
|
||||
"""
|
||||
|
||||
|
||||
@@ -809,46 +811,6 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
|
||||
return format_params(output_params, "Outputs", indent_level, max_line_length)
|
||||
|
||||
|
||||
def format_params_markdown(params, header="Inputs"):
|
||||
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
|
||||
|
||||
Suitable for model cards rendered on Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
params: list of InputParam or OutputParam objects to format
|
||||
header: Header text (e.g. "Inputs" or "Outputs")
|
||||
|
||||
Returns:
|
||||
A formatted markdown string, or empty string if params is empty.
|
||||
"""
|
||||
if not params:
|
||||
return ""
|
||||
|
||||
def get_type_str(type_hint):
|
||||
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
|
||||
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
|
||||
return " | ".join(type_strs)
|
||||
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
|
||||
|
||||
lines = [f"**{header}:**\n"] if header else []
|
||||
for param in params:
|
||||
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
||||
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
|
||||
param_str = f"- `{name}` (`{type_str}`"
|
||||
|
||||
if hasattr(param, "required") and not param.required:
|
||||
param_str += ", *optional*"
|
||||
if param.default is not None:
|
||||
param_str += f", defaults to `{param.default}`"
|
||||
param_str += ")"
|
||||
|
||||
desc = param.description if param.description else "No description provided"
|
||||
param_str += f": {desc}"
|
||||
lines.append(param_str)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
|
||||
"""Format a list of ComponentSpec objects into a readable string representation.
|
||||
|
||||
@@ -1022,89 +984,6 @@ def make_doc_string(
|
||||
return output
|
||||
|
||||
|
||||
def _validate_requirements(reqs):
|
||||
if reqs is None:
|
||||
normalized_reqs = {}
|
||||
else:
|
||||
if not isinstance(reqs, dict):
|
||||
raise ValueError(
|
||||
"Requirements must be provided as a dictionary mapping package names to version specifiers."
|
||||
)
|
||||
normalized_reqs = _normalize_requirements(reqs)
|
||||
|
||||
if not normalized_reqs:
|
||||
return {}
|
||||
|
||||
final: dict[str, str] = {}
|
||||
for req, specified_ver in normalized_reqs.items():
|
||||
req_available, req_actual_ver = _is_package_available(req)
|
||||
if not req_available:
|
||||
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
|
||||
|
||||
if specified_ver:
|
||||
try:
|
||||
specifier = SpecifierSet(specified_ver)
|
||||
except InvalidSpecifier as err:
|
||||
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
|
||||
|
||||
if req_actual_ver == "N/A":
|
||||
logger.warning(
|
||||
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
|
||||
)
|
||||
elif not specifier.contains(req_actual_ver, prereleases=True):
|
||||
logger.warning(
|
||||
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
|
||||
)
|
||||
|
||||
final[req] = specified_ver
|
||||
|
||||
return final
|
||||
|
||||
|
||||
def _normalize_requirements(reqs):
|
||||
if not reqs:
|
||||
return {}
|
||||
|
||||
normalized: "OrderedDict[str, str]" = OrderedDict()
|
||||
|
||||
def _accumulate(mapping: dict[str, Any]):
|
||||
for pkg, spec in mapping.items():
|
||||
if isinstance(spec, dict):
|
||||
# This is recursive because blocks are composable. This way, we can merge requirements
|
||||
# from multiple blocks.
|
||||
_accumulate(spec)
|
||||
continue
|
||||
|
||||
pkg_name = str(pkg).strip()
|
||||
if not pkg_name:
|
||||
raise ValueError("Requirement package name cannot be empty.")
|
||||
|
||||
spec_str = "" if spec is None else str(spec).strip()
|
||||
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
|
||||
spec_str = f"=={spec_str}"
|
||||
|
||||
existing_spec = normalized.get(pkg_name)
|
||||
if existing_spec is not None:
|
||||
if not existing_spec and spec_str:
|
||||
normalized[pkg_name] = spec_str
|
||||
elif existing_spec and spec_str and existing_spec != spec_str:
|
||||
try:
|
||||
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
|
||||
except InvalidSpecifier:
|
||||
logger.warning(
|
||||
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
|
||||
)
|
||||
else:
|
||||
normalized[pkg_name] = str(combined_spec)
|
||||
continue
|
||||
|
||||
normalized[pkg_name] = spec_str
|
||||
|
||||
_accumulate(reqs)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
|
||||
"""
|
||||
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
|
||||
@@ -1188,7 +1067,8 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
- blocks_description: Detailed architecture of blocks
|
||||
- components_description: List of required components
|
||||
- configs_section: Configuration parameters section
|
||||
- io_specification_section: Input/Output specification (per-workflow or unified)
|
||||
- inputs_description: Input parameters specification
|
||||
- outputs_description: Output parameters specification
|
||||
- trigger_inputs_section: Conditional execution information
|
||||
- tags: List of relevant tags for the model card
|
||||
"""
|
||||
@@ -1207,6 +1087,15 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
if block_desc:
|
||||
blocks_desc_parts.append(f" - {block_desc}")
|
||||
|
||||
# add sub-blocks if any
|
||||
if hasattr(block, "sub_blocks") and block.sub_blocks:
|
||||
for sub_name, sub_block in block.sub_blocks.items():
|
||||
sub_class = sub_block.__class__.__name__
|
||||
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
|
||||
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
|
||||
if sub_desc:
|
||||
blocks_desc_parts.append(f" - {sub_desc}")
|
||||
|
||||
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
|
||||
|
||||
components = getattr(blocks, "expected_components", [])
|
||||
@@ -1232,76 +1121,63 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
if configs_description:
|
||||
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
|
||||
|
||||
# Branch on whether workflows are defined
|
||||
has_workflows = getattr(blocks, "_workflow_map", None) is not None
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
|
||||
if has_workflows:
|
||||
workflow_map = blocks._workflow_map
|
||||
parts = []
|
||||
# format inputs as markdown list
|
||||
inputs_parts = []
|
||||
required_inputs = [inp for inp in inputs if inp.required]
|
||||
optional_inputs = [inp for inp in inputs if not inp.required]
|
||||
|
||||
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
|
||||
# use that as the shared output for all workflows
|
||||
blocks_outputs = blocks.outputs
|
||||
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
|
||||
shared_outputs = (
|
||||
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
|
||||
)
|
||||
if required_inputs:
|
||||
inputs_parts.append("**Required:**\n")
|
||||
for inp in required_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
|
||||
|
||||
parts.append("## Workflow Input Specification\n")
|
||||
if optional_inputs:
|
||||
if required_inputs:
|
||||
inputs_parts.append("")
|
||||
inputs_parts.append("**Optional:**\n")
|
||||
for inp in optional_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
|
||||
|
||||
# Per-workflow details: show trigger inputs with full param descriptions
|
||||
for wf_name, trigger_inputs in workflow_map.items():
|
||||
trigger_input_names = set(trigger_inputs.keys())
|
||||
try:
|
||||
workflow_blocks = blocks.get_workflow(wf_name)
|
||||
except Exception:
|
||||
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
|
||||
parts.append("*Could not resolve workflow blocks.*\n")
|
||||
parts.append("</details>\n")
|
||||
continue
|
||||
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
|
||||
|
||||
wf_inputs = workflow_blocks.inputs
|
||||
# Show only trigger inputs with full parameter descriptions
|
||||
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
|
||||
# format outputs as markdown list
|
||||
outputs_parts = []
|
||||
for out in outputs:
|
||||
if hasattr(out.type_hint, "__name__"):
|
||||
type_str = out.type_hint.__name__
|
||||
elif out.type_hint is not None:
|
||||
type_str = str(out.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = out.description or "No description provided"
|
||||
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
|
||||
|
||||
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
|
||||
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
|
||||
|
||||
inputs_str = format_params_markdown(trigger_params, header=None)
|
||||
parts.append(inputs_str if inputs_str else "No additional inputs required.")
|
||||
parts.append("")
|
||||
|
||||
parts.append("</details>\n")
|
||||
|
||||
# Common Inputs & Outputs section (like non-workflow pipelines)
|
||||
all_inputs = blocks.inputs
|
||||
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
|
||||
|
||||
inputs_str = format_params_markdown(all_inputs, "Inputs")
|
||||
outputs_str = format_params_markdown(all_outputs, "Outputs")
|
||||
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
|
||||
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
|
||||
|
||||
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
|
||||
|
||||
io_specification_section = "\n".join(parts)
|
||||
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
|
||||
trigger_inputs_section = ""
|
||||
else:
|
||||
# Unified I/O section (original behavior)
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
inputs_str = format_params_markdown(inputs, "Inputs")
|
||||
outputs_str = format_params_markdown(outputs, "Outputs")
|
||||
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
|
||||
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
|
||||
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
|
||||
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
### Conditional Execution
|
||||
|
||||
This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
@@ -1314,18 +1190,7 @@ This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
if hasattr(blocks, "model_name") and blocks.model_name:
|
||||
tags.append(blocks.model_name)
|
||||
|
||||
if has_workflows:
|
||||
# Derive tags from workflow names
|
||||
workflow_names = set(blocks._workflow_map.keys())
|
||||
if any("inpainting" in wf for wf in workflow_names):
|
||||
tags.append("inpainting")
|
||||
if any("image2image" in wf for wf in workflow_names):
|
||||
tags.append("image-to-image")
|
||||
if any("controlnet" in wf for wf in workflow_names):
|
||||
tags.append("controlnet")
|
||||
if any("text2image" in wf for wf in workflow_names):
|
||||
tags.append("text-to-image")
|
||||
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
triggers = blocks.trigger_inputs
|
||||
if any(t in triggers for t in ["mask", "mask_image"]):
|
||||
tags.append("inpainting")
|
||||
@@ -1353,7 +1218,8 @@ This pipeline uses a {block_count}-block architecture that can be customized and
|
||||
"blocks_description": blocks_description,
|
||||
"components_description": components_description,
|
||||
"configs_section": configs_section,
|
||||
"io_specification_section": io_specification_section,
|
||||
"inputs_description": inputs_description,
|
||||
"outputs_description": outputs_description,
|
||||
"trigger_inputs_section": trigger_inputs_section,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
@@ -237,7 +237,6 @@ else:
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
]
|
||||
_import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"]
|
||||
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
|
||||
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
|
||||
_import_structure["hunyuan_video"] = [
|
||||
@@ -668,7 +667,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
|
||||
@@ -54,7 +54,6 @@ from .flux import (
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
@@ -175,8 +174,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("glm_image", GlmImagePipeline),
|
||||
("helios", HeliosPipeline),
|
||||
("helios-pyramid", HeliosPyramidPipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_helios"] = ["HeliosPipeline"]
|
||||
_import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_helios import HeliosPipeline
|
||||
from .pipeline_helios_pyramid import HeliosPyramidPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,916 +0,0 @@
|
||||
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import HeliosLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, HeliosTransformer3DModel
|
||||
from ...schedulers import HeliosScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import HeliosPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from diffusers import AutoencoderKLWan, HeliosPipeline
|
||||
|
||||
>>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled
|
||||
>>> model_id = "BestWishYsh/Helios-Base"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = HeliosPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=384,
|
||||
... width=640,
|
||||
... num_frames=132,
|
||||
... guidance_scale=5.0,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=24)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video / image-to-video / video-to-video generation using Helios.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`HeliosTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`HeliosScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: HeliosScheduler,
|
||||
transformer: HeliosTransformer3DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: str | list[str] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, text_inputs.attention_mask.bool()
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, _ = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, _ = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
image=None,
|
||||
video=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if image is not None and video is not None:
|
||||
raise ValueError("image and video cannot be provided simultaneously")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 384,
|
||||
width: int = 640,
|
||||
num_frames: int = 33,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
latents_mean: torch.Tensor,
|
||||
latents_std: torch.Tensor,
|
||||
num_latent_frames_per_chunk: int,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
fake_latents: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
if latents is None:
|
||||
image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
|
||||
latents = self.vae.encode(image).latent_dist.sample(generator=generator)
|
||||
latents = (latents - latents_mean) * latents_std
|
||||
if fake_latents is None:
|
||||
min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
|
||||
fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype)
|
||||
fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator)
|
||||
fake_latents_full = (fake_latents_full - latents_mean) * latents_std
|
||||
fake_latents = fake_latents_full[:, :, -1:, :, :]
|
||||
return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype)
|
||||
|
||||
def prepare_video_latents(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
latents_mean: torch.Tensor,
|
||||
latents_std: torch.Tensor,
|
||||
num_latent_frames_per_chunk: int,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if latents is None:
|
||||
num_frames = video.shape[2]
|
||||
min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
|
||||
num_chunks = num_frames // min_frames
|
||||
if num_chunks == 0:
|
||||
raise ValueError(
|
||||
f"Video must have at least {min_frames} frames "
|
||||
f"(got {num_frames} frames). "
|
||||
f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}"
|
||||
)
|
||||
total_valid_frames = num_chunks * min_frames
|
||||
start_frame = num_frames - total_valid_frames
|
||||
|
||||
first_frame = video[:, :, 0:1, :, :]
|
||||
first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator)
|
||||
first_frame_latent = (first_frame_latent - latents_mean) * latents_std
|
||||
|
||||
latents_chunks = []
|
||||
for i in range(num_chunks):
|
||||
chunk_start = start_frame + i * min_frames
|
||||
chunk_end = chunk_start + min_frames
|
||||
video_chunk = video[:, :, chunk_start:chunk_end, :, :]
|
||||
chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator)
|
||||
chunk_latents = (chunk_latents - latents_mean) * latents_std
|
||||
latents_chunks.append(chunk_latents)
|
||||
latents = torch.cat(latents_chunks, dim=2)
|
||||
return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str | list[str] = None,
|
||||
negative_prompt: str | list[str] = None,
|
||||
height: int = 384,
|
||||
width: int = 640,
|
||||
num_frames: int = 132,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: list[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: int | None = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
output_type: str | None = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
# ------------ I2V ------------
|
||||
image: PipelineImageInput | None = None,
|
||||
image_latents: torch.Tensor | None = None,
|
||||
fake_image_latents: torch.Tensor | None = None,
|
||||
add_noise_to_image_latents: bool = True,
|
||||
image_noise_sigma_min: float = 0.111,
|
||||
image_noise_sigma_max: float = 0.135,
|
||||
# ------------ V2V ------------
|
||||
video: PipelineImageInput | None = None,
|
||||
video_latents: torch.Tensor | None = None,
|
||||
add_noise_to_video_latents: bool = True,
|
||||
video_noise_sigma_min: float = 0.111,
|
||||
video_noise_sigma_max: float = 0.135,
|
||||
# ------------ Stage 1 ------------
|
||||
history_sizes: list = [16, 2, 1],
|
||||
num_latent_frames_per_chunk: int = 9,
|
||||
keep_first_frame: bool = True,
|
||||
is_skip_first_chunk: bool = False,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
|
||||
negative_prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (`guidance_scale` < `1`).
|
||||
height (`int`, defaults to `384`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `640`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `132`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`list`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
||||
truncated. If the prompt is shorter, it will be padded to this length.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~HeliosPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
history_sizes = sorted(history_sizes, reverse=True) # From big to small
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
image,
|
||||
video,
|
||||
)
|
||||
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
vae_dtype = self.vae.dtype
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(device, self.vae.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
device, self.vae.dtype
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare image or video
|
||||
if image is not None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
image_latents, fake_image_latents = self.prepare_image_latents(
|
||||
image,
|
||||
latents_mean=latents_mean,
|
||||
latents_std=latents_std,
|
||||
num_latent_frames_per_chunk=num_latent_frames_per_chunk,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=image_latents,
|
||||
fake_latents=fake_image_latents,
|
||||
)
|
||||
|
||||
if image_latents is not None and add_noise_to_image_latents:
|
||||
image_noise_sigma = (
|
||||
torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
|
||||
+ image_noise_sigma_min
|
||||
)
|
||||
image_latents = (
|
||||
image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
|
||||
+ (1 - image_noise_sigma) * image_latents
|
||||
)
|
||||
fake_image_noise_sigma = (
|
||||
torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min)
|
||||
+ video_noise_sigma_min
|
||||
)
|
||||
fake_image_latents = (
|
||||
fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device)
|
||||
+ (1 - fake_image_noise_sigma) * fake_image_latents
|
||||
)
|
||||
|
||||
if video is not None:
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
image_latents, video_latents = self.prepare_video_latents(
|
||||
video,
|
||||
latents_mean=latents_mean,
|
||||
latents_std=latents_std,
|
||||
num_latent_frames_per_chunk=num_latent_frames_per_chunk,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=video_latents,
|
||||
)
|
||||
|
||||
if video_latents is not None and add_noise_to_video_latents:
|
||||
image_noise_sigma = (
|
||||
torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
|
||||
+ image_noise_sigma_min
|
||||
)
|
||||
image_latents = (
|
||||
image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
|
||||
+ (1 - image_noise_sigma) * image_latents
|
||||
)
|
||||
|
||||
noisy_latents_chunks = []
|
||||
num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk
|
||||
for i in range(num_latent_chunks):
|
||||
chunk_start = i * num_latent_frames_per_chunk
|
||||
chunk_end = chunk_start + num_latent_frames_per_chunk
|
||||
latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :]
|
||||
|
||||
chunk_frames = latent_chunk.shape[2]
|
||||
frame_sigmas = (
|
||||
torch.rand(chunk_frames, device=device, generator=generator)
|
||||
* (video_noise_sigma_max - video_noise_sigma_min)
|
||||
+ video_noise_sigma_min
|
||||
)
|
||||
frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1)
|
||||
|
||||
noisy_chunk = (
|
||||
frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device)
|
||||
+ (1 - frame_sigmas) * latent_chunk
|
||||
)
|
||||
noisy_latents_chunks.append(noisy_chunk)
|
||||
video_latents = torch.cat(noisy_latents_chunks, dim=2)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
|
||||
num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
|
||||
num_history_latent_frames = sum(history_sizes)
|
||||
history_video = None
|
||||
total_generated_latent_frames = 0
|
||||
|
||||
if not keep_first_frame:
|
||||
history_sizes[-1] = history_sizes[-1] + 1
|
||||
history_latents = torch.zeros(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_history_latent_frames,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
if fake_image_latents is not None:
|
||||
history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2)
|
||||
total_generated_latent_frames += 1
|
||||
if video_latents is not None:
|
||||
history_frames = history_latents.shape[2]
|
||||
video_frames = video_latents.shape[2]
|
||||
if video_frames < history_frames:
|
||||
keep_frames = history_frames - video_frames
|
||||
history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2)
|
||||
else:
|
||||
history_latents = video_latents
|
||||
total_generated_latent_frames += video_latents.shape[2]
|
||||
|
||||
if keep_first_frame:
|
||||
indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk]))
|
||||
(
|
||||
indices_prefix,
|
||||
indices_latents_history_long,
|
||||
indices_latents_history_mid,
|
||||
indices_latents_history_1x,
|
||||
indices_hidden_states,
|
||||
) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0)
|
||||
indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
|
||||
else:
|
||||
indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk]))
|
||||
(
|
||||
indices_latents_history_long,
|
||||
indices_latents_history_mid,
|
||||
indices_latents_history_short,
|
||||
indices_hidden_states,
|
||||
) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0)
|
||||
indices_hidden_states = indices_hidden_states.unsqueeze(0)
|
||||
indices_latents_history_short = indices_latents_history_short.unsqueeze(0)
|
||||
indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)
|
||||
indices_latents_history_long = indices_latents_history_long.unsqueeze(0)
|
||||
|
||||
# 6. Denoising loop
|
||||
patch_size = self.transformer.config.patch_size
|
||||
image_seq_len = (
|
||||
num_latent_frames_per_chunk
|
||||
* (height // self.vae_scale_factor_spatial)
|
||||
* (width // self.vae_scale_factor_spatial)
|
||||
// (patch_size[0] * patch_size[1] * patch_size[2])
|
||||
)
|
||||
sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
|
||||
for k in range(num_latent_chunk):
|
||||
is_first_chunk = k == 0
|
||||
is_second_chunk = k == 1
|
||||
if keep_first_frame:
|
||||
latents_history_long, latents_history_mid, latents_history_1x = history_latents[
|
||||
:, :, -num_history_latent_frames:
|
||||
].split(history_sizes, dim=2)
|
||||
if image_latents is None and is_first_chunk:
|
||||
latents_prefix = torch.zeros(
|
||||
(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
1,
|
||||
latents_history_1x.shape[-2],
|
||||
latents_history_1x.shape[-1],
|
||||
),
|
||||
device=device,
|
||||
dtype=latents_history_1x.dtype,
|
||||
)
|
||||
else:
|
||||
latents_prefix = image_latents
|
||||
latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2)
|
||||
else:
|
||||
latents_history_long, latents_history_mid, latents_history_short = history_latents[
|
||||
:, :, -num_history_latent_frames:
|
||||
].split(history_sizes, dim=2)
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
window_num_frames,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=None,
|
||||
)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
latents_history_short = latents_history_short.to(transformer_dtype)
|
||||
latents_history_mid = latents_history_mid.to(transformer_dtype)
|
||||
latents_history_long = latents_history_long.to(transformer_dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
indices_hidden_states=indices_hidden_states,
|
||||
indices_latents_history_short=indices_latents_history_short,
|
||||
indices_latents_history_mid=indices_latents_history_mid,
|
||||
indices_latents_history_long=indices_latents_history_long,
|
||||
latents_history_short=latents_history_short,
|
||||
latents_history_mid=latents_history_mid,
|
||||
latents_history_long=latents_history_long,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
indices_hidden_states=indices_hidden_states,
|
||||
indices_latents_history_short=indices_latents_history_short,
|
||||
indices_latents_history_mid=indices_latents_history_mid,
|
||||
indices_latents_history_long=indices_latents_history_long,
|
||||
latents_history_short=latents_history_short,
|
||||
latents_history_mid=latents_history_mid,
|
||||
latents_history_long=latents_history_long,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
generator=generator,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if keep_first_frame and (
|
||||
(is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk)
|
||||
):
|
||||
image_latents = latents[:, :, 0:1, :, :]
|
||||
|
||||
total_generated_latent_frames += latents.shape[2]
|
||||
history_latents = torch.cat([history_latents, latents], dim=2)
|
||||
real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
|
||||
current_latents = (
|
||||
real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std
|
||||
+ latents_mean
|
||||
)
|
||||
current_video = self.vae.decode(current_latents, return_dict=False)[0]
|
||||
|
||||
if history_video is None:
|
||||
history_video = current_video
|
||||
else:
|
||||
history_video = torch.cat([history_video, current_video], dim=2)
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type != "latent":
|
||||
generated_frames = history_video.size(2)
|
||||
generated_frames = (
|
||||
generated_frames - 1
|
||||
) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
history_video = history_video[:, :, :generated_frames]
|
||||
video = self.video_processor.postprocess_video(history_video, output_type=output_type)
|
||||
else:
|
||||
video = real_history_latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return HeliosPipelineOutput(frames=video)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
class HeliosPipelineOutput(BaseOutput):
    r"""
    Output container returned by Helios video pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            The generated videos. Either a nested list of length `batch_size`, where each entry is a
            sequence of `num_frames` denoised PIL images, or a NumPy array / torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets import ZImageControlNetModel
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
@@ -185,7 +185,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageControlNetPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
@@ -365,7 +365,7 @@ class ZImageControlNetPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSin
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 0
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets import ZImageControlNetModel
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
@@ -185,7 +185,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageControlNetInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
@@ -372,7 +372,7 @@ class ZImageControlNetInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin,
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 0
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
|
||||
@@ -347,7 +347,7 @@ class ZImageImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingle
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 0
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
|
||||
@@ -462,7 +462,7 @@ class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingle
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 0
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
|
||||
@@ -339,7 +339,7 @@ class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFil
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 0
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
|
||||
@@ -61,8 +61,6 @@ else:
|
||||
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
|
||||
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
|
||||
_import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"]
|
||||
_import_structure["scheduling_helios"] = ["HeliosScheduler"]
|
||||
_import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"]
|
||||
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
|
||||
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
|
||||
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
|
||||
@@ -166,8 +164,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
|
||||
from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
|
||||
from .scheduling_helios import HeliosScheduler
|
||||
from .scheduling_helios_dmd import HeliosDMDScheduler
|
||||
from .scheduling_heun_discrete import HeunDiscreteScheduler
|
||||
from .scheduling_ipndm import IPNDMScheduler
|
||||
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
||||
|
||||
@@ -1,867 +0,0 @@
|
||||
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..schedulers.scheduling_utils import SchedulerMixin
|
||||
from ..utils import BaseOutput, deprecate
|
||||
|
||||
|
||||
@dataclass
class HeliosSchedulerOutput(BaseOutput):
    """
    Output of the `HeliosScheduler` step methods.

    Attributes:
        prev_sample (`torch.FloatTensor`):
            The sample advanced to the previous (next-in-denoising) timestep; input for the next
            scheduler step (returned by `step_euler`).
        model_outputs (`torch.FloatTensor`, *optional*):
            Presumably the buffered model outputs kept for the multi-step (UniPC) solver path —
            not populated by `step_euler`; verify against the UniPC step implementation.
        last_sample (`torch.FloatTensor`, *optional*):
            Presumably the previous sample kept for the UniPC corrector; verify against caller.
        this_order (`int`, *optional*):
            Presumably the effective solver order used for this step; verify against caller.
    """

    prev_sample: torch.FloatTensor
    model_outputs: torch.FloatTensor | None = None
    last_sample: torch.FloatTensor | None = None
    this_order: int | None = None
|
||||
|
||||
|
||||
class HeliosScheduler(SchedulerMixin, ConfigMixin):
|
||||
_compatibles = []
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    shift: float = 1.0,  # Following Stable diffusion 3,
    stages: int = 3,
    stage_range: list = [0, 1 / 3, 2 / 3, 1],
    gamma: float = 1 / 3,
    # For UniPC
    thresholding: bool = False,
    prediction_type: str = "flow_prediction",
    solver_order: int = 2,
    predict_x0: bool = True,
    solver_type: str = "bh2",
    lower_order_final: bool = True,
    disable_corrector: list[int] = [],
    solver_p: SchedulerMixin | None = None,
    use_flow_sigmas: bool = True,
    scheduler_type: str = "unipc",  # ["euler", "unipc"]
    use_dynamic_shifting: bool = False,
    time_shift_type: Literal["exponential", "linear"] = "exponential",
):
    """
    Multi-stage flow-matching scheduler combining an Euler path and a UniPC-style path.

    All constructor arguments are captured into `self.config` by `@register_to_config`.

    NOTE(review): `stage_range` and `disable_corrector` use mutable default arguments. They are
    not mutated in this method and `register_to_config` snapshots them, but sharing the literal
    across instances is fragile — confirm before refactoring to `None` defaults (it would change
    the serialized config).
    """
    # Per-stage schedule bookkeeping, populated by `init_sigmas_for_each_stage`.
    self.timestep_ratios = {}  # The timestep ratio for each stage
    self.timesteps_per_stage = {}  # The detailed timesteps per stage (fix max and min per stage)
    self.sigmas_per_stage = {}  # always uniform [1000, 0]
    self.start_sigmas = {}  # for start point / upsample renoise
    self.end_sigmas = {}  # for end point
    self.ori_start_sigmas = {}

    # self.init_sigmas()
    # Builds self.sigmas/self.timesteps (via init_sigmas) plus all per-stage tables.
    self.init_sigmas_for_each_stage()
    self.sigma_min = self.sigmas[-1].item()
    self.sigma_max = self.sigmas[0].item()
    self.gamma = gamma

    # Legacy solver-type aliases collapse to "bh2"; anything else is unsupported.
    if solver_type not in ["bh1", "bh2"]:
        if solver_type in ["midpoint", "heun", "logrho"]:
            self.register_to_config(solver_type="bh2")
        else:
            raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

    # UniPC multi-step state: ring buffers of the last `solver_order` model outputs/timesteps.
    self.predict_x0 = predict_x0
    self.model_outputs = [None] * solver_order
    self.timestep_list = [None] * solver_order
    self.lower_order_nums = 0
    self.disable_corrector = disable_corrector
    self.solver_p = solver_p
    self.last_sample = None
    self._step_index = None
    self._begin_index = None
|
||||
|
||||
def init_sigmas(self):
    """(Re)build the full training-resolution sigma and timestep tables.

    Resets the internal step/begin indices as a side effect.
    """
    n_train = self.config.num_train_timesteps
    shift = self.config.shift

    # Noise levels rise linearly from 0 toward 1, then get the SD3-style shift applied.
    raw_sigmas = 1.0 - np.linspace(1, 1 / n_train, n_train + 1)
    shifted = shift * raw_sigmas / (1 + (shift - 1) * raw_sigmas)
    # Reverse to descend from high noise to low, drop the trailing zero entry.
    sigma_table = torch.from_numpy(np.flip(shifted)[:-1].copy())

    self._step_index = None
    self._begin_index = None
    self.timesteps = (sigma_table * n_train).clone()
    self.sigmas = sigma_table.to("cpu")  # to avoid too much CPU/GPU communication
|
||||
|
||||
def init_sigmas_for_each_stage(self):
    """
    Initialize the per-stage timestep/sigma tables.

    Splits the global schedule (from `init_sigmas`) into `config.stages` segments according to
    `config.stage_range`, applies a gamma-based correction to each stage's start sigma, and
    fills `self.start_sigmas`, `self.end_sigmas`, `self.timestep_ratios`,
    `self.timesteps_per_stage` and `self.sigmas_per_stage`.
    """
    self.init_sigmas()

    stage_distance = []
    stages = self.config.stages
    training_steps = self.config.num_train_timesteps
    stage_range = self.config.stage_range

    # Init the start and end point of each stage
    for i_s in range(stages):
        # To decide the start and ends point (indices clamped into [0, training_steps])
        start_indice = int(stage_range[i_s] * training_steps)
        start_indice = max(start_indice, 0)
        end_indice = int(stage_range[i_s + 1] * training_steps)
        end_indice = min(end_indice, training_steps)
        start_sigma = self.sigmas[start_indice].item()
        end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
        self.ori_start_sigmas[i_s] = start_sigma

        if i_s != 0:
            # Renoise correction for non-first stages: shrink the clean fraction
            # (1 - start_sigma) by a gamma-dependent factor before re-deriving start_sigma.
            ori_sigma = 1 - start_sigma
            gamma = self.config.gamma
            corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
            # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
            start_sigma = 1 - corrected_sigma

        stage_distance.append(start_sigma - end_sigma)
        self.start_sigmas[i_s] = start_sigma
        self.end_sigmas[i_s] = end_sigma

    # Determine the ratio of each stage according to flow length
    tot_distance = sum(stage_distance)
    for i_s in range(stages):
        if i_s == 0:
            start_ratio = 0.0
        else:
            start_ratio = sum(stage_distance[:i_s]) / tot_distance
        if i_s == stages - 1:
            # Keep strictly below 1.0 so index lookups below never run past the table.
            end_ratio = 0.9999999999999999
        else:
            end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance

        self.timestep_ratios[i_s] = (start_ratio, end_ratio)

    # Determine the timesteps and sigmas for each stage
    for i_s in range(stages):
        timestep_ratio = self.timestep_ratios[i_s]
        # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
        # Cap at 999 so a stage never starts at the very first (highest) timestep.
        timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
        timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
        timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
        self.timesteps_per_stage[i_s] = (
            timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
        )
        # Per-stage sigmas are always a uniform ramp 0.999 -> 0, independent of the stage bounds.
        stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
        self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
|
||||
|
||||
@property
def step_index(self):
    """Zero-based counter of scheduler steps taken so far; incremented once per step call."""
    return self._step_index
|
||||
|
||||
@property
def begin_index(self):
    """Index of the first timestep to use; `None` until the pipeline calls `set_begin_index`."""
    return self._begin_index
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0):
    """
    Configure the index of the first timestep.

    Should be called by the pipeline before inference starts.

    Args:
        begin_index (`int`, defaults to 0):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
|
||||
|
||||
def _sigma_to_t(self, sigma):
    """Map a noise level in [0, 1] to its continuous training timestep."""
    return self.config.num_train_timesteps * sigma
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    stage_index: int | None = None,
    device: str | torch.device | None = None,
    sigmas: np.ndarray | None = None,
    mu: float | None = None,
    is_amplify_first_chunk: bool = False,
):
    """
    Set the discrete timesteps and sigmas used for inference.

    Args:
        num_inference_steps (`int`):
            Number of denoising steps. For the "dmd" scheduler type this is internally
            incremented (or doubled, see `is_amplify_first_chunk`) before use.
        stage_index (`int`, *optional*):
            Which stage's schedule to use; only consulted when `config.stages > 1`.
        device (`str` or `torch.device`, *optional*):
            Device to place the timestep/sigma tensors on.
        sigmas (`np.ndarray`, *optional*):
            Custom sigma schedule; only honored in the single-stage path.
        mu (`float`, *optional*):
            Shift parameter for dynamic time shifting; used when
            `config.use_dynamic_shifting` is enabled.
        is_amplify_first_chunk (`bool`, defaults to `False`):
            For "dmd": use `2 * num_inference_steps + 1` steps instead of
            `num_inference_steps + 1`.
    """
    if self.config.scheduler_type == "dmd":
        # DMD sampling needs one extra boundary step (two per step for the first chunk).
        if is_amplify_first_chunk:
            num_inference_steps = num_inference_steps * 2 + 1
        else:
            num_inference_steps = num_inference_steps + 1

    self.num_inference_steps = num_inference_steps
    self.init_sigmas()

    if self.config.stages == 1:
        if sigmas is None:
            sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
                np.float32
            )
        if self.config.shift != 1.0:
            # Static and dynamic shifting are mutually exclusive.
            assert not self.config.use_dynamic_shifting
            sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
        timesteps = (sigmas * self.config.num_train_timesteps).copy()
        sigmas = torch.from_numpy(sigmas)
    else:
        # Multi-stage path: interpolate within the precomputed per-stage tables.
        stage_timesteps = self.timesteps_per_stage[stage_index]
        timesteps = np.linspace(
            stage_timesteps[0].item(),
            stage_timesteps[-1].item(),
            num_inference_steps,
        )

        stage_sigmas = self.sigmas_per_stage[stage_index]
        ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
        sigmas = torch.from_numpy(ratios)

    self.timesteps = torch.from_numpy(timesteps).to(device=device)
    # Append a terminal zero sigma so `sigmas[step_index + 1]` is valid on the last step.
    self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)

    self._step_index = None
    # NOTE(review): `reset_scheduler_history` is defined elsewhere in this class —
    # presumably clears the UniPC model-output buffers; confirm.
    self.reset_scheduler_history()

    if self.config.scheduler_type == "dmd":
        # Drop the extra boundary entry added above, keeping the terminal zero sigma.
        self.timesteps = self.timesteps[:-1]
        self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])

    if self.config.use_dynamic_shifting:
        assert self.config.shift == 1.0
        self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
        if self.config.stages == 1:
            self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
        else:
            # Re-map shifted sigmas back into this stage's timestep range.
            self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
                self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
            )
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """
    Apply time shifting to the sigmas.

    Args:
        mu (`float`):
            The mu parameter for the time shift.
        sigma (`float`):
            The sigma parameter for the time shift.
        t (`torch.Tensor`):
            The input timesteps.

    Returns:
        `torch.Tensor`:
            The time-shifted timesteps.
    """
    # Dispatch is restricted by the config's Literal["exponential", "linear"];
    # any other value would fall through and return None.
    if self.config.time_shift_type == "exponential":
        return self._time_shift_exponential(mu, sigma, t)
    elif self.config.time_shift_type == "linear":
        return self._time_shift_linear(mu, sigma, t)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
def _time_shift_exponential(self, mu, sigma, t):
    # Sigmoid-like reshaping of t in (0, 1): exp(mu) / (exp(mu) + (1/t - 1) ** sigma).
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
def _time_shift_linear(self, mu, sigma, t):
    # Linear variant of the shift: mu / (mu + (1/t - 1) ** sigma).
    return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
# ---------------------------------- Euler ----------------------------------
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the schedule index corresponding to `timestep`.

    When the timestep occurs more than once (possible on the very first `step`,
    e.g. when starting mid-schedule for image-to-image), the second occurrence
    is chosen so no sigma is accidentally skipped.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    matches = (schedule_timesteps == timestep).nonzero()
    chosen = 1 if len(matches) > 1 else 0
    return matches[chosen].item()
|
||||
|
||||
def _init_step_index(self, timestep):
    """Initialize `_step_index` for the first `step` call.

    Uses the explicit begin index when one was set; otherwise looks the
    timestep up in the current schedule.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
|
||||
|
||||
def step_euler(
    self,
    model_output: torch.FloatTensor,
    timestep: float | torch.FloatTensor | None = None,
    sample: torch.FloatTensor = None,
    generator: torch.Generator | None = None,
    sigma: torch.FloatTensor | None = None,
    sigma_next: torch.FloatTensor | None = None,
    return_dict: bool = True,
) -> HeliosSchedulerOutput | tuple:
    """
    Advance `sample` by one Euler step of the flow ODE.

    Args:
        model_output (`torch.FloatTensor`):
            The flow/velocity prediction from the diffusion model.
        timestep (`float` or `torch.FloatTensor`, *optional*):
            The current timestep; only validated when `sigma`/`sigma_next` are not supplied.
        sample (`torch.FloatTensor`):
            The current latent sample.
        generator (`torch.Generator`, *optional*):
            Unused here; kept for API parity with other schedulers.
        sigma (`torch.FloatTensor`, *optional*):
            Explicit current noise level. Must be supplied together with `sigma_next`;
            when omitted, both are read from the internal schedule at `step_index`.
        sigma_next (`torch.FloatTensor`, *optional*):
            Explicit next noise level.
        return_dict (`bool`, defaults to `True`):
            Whether to return a `HeliosSchedulerOutput` instead of a one-element tuple.

    Returns:
        `HeliosSchedulerOutput` or `tuple`:
            The sample at the next timestep (`prev_sample`).

    Raises:
        ValueError: If only one of `sigma`/`sigma_next` is given, or if an integer index
            is passed as `timestep` when the internal schedule is used.
    """
    # Validate with an explicit exception rather than `assert` (asserts vanish under -O).
    if (sigma is None) != (sigma_next is None):
        raise ValueError("sigma and sigma_next must both be None or both be not None")

    use_schedule = sigma is None
    if use_schedule and isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        # Error message corrected to name this scheduler (was copied from EulerDiscreteScheduler).
        raise ValueError(
            "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
            " `HeliosScheduler.step_euler()` is not supported. Make sure to pass"
            " one of the `scheduler.timesteps` as a timestep."
        )

    if self.step_index is None:
        self._step_index = 0

    # Upcast to avoid precision issues when computing prev_sample.
    sample = sample.to(torch.float32)

    if use_schedule:
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

    # Flow-matching Euler update: x_next = x + (sigma_next - sigma) * v.
    prev_sample = sample + (sigma_next - sigma) * model_output

    # Cast sample back to model-compatible dtype.
    prev_sample = prev_sample.to(model_output.dtype)

    # Upon completion, increase step index by one.
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return HeliosSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# ---------------------------------- UniPC ----------------------------------
def _sigma_to_alpha_sigma_t(self, sigma):
    """Split a noise level into the (alpha_t, sigma_t) pair used by the UniPC update."""
    if not self.config.use_flow_sigmas:
        # Variance-preserving style parameterization.
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        return alpha_t, sigma * alpha_t
    # Flow-matching parameterization: alpha_t + sigma = 1, sigma_t clamped away from zero.
    return 1 - sigma, torch.clamp(sigma, min=1e-8)
|
||||
|
||||
def convert_model_output(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sigma: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Convert the model output to the corresponding type the UniPC algorithm needs.
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The converted model output.
|
||||
"""
|
||||
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyword argument")
|
||||
if timestep is not None:
|
||||
deprecate(
|
||||
"timesteps",
|
||||
"1.0.0",
|
||||
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
flag = False
|
||||
if sigma is None:
|
||||
flag = True
|
||||
sigma = self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
||||
|
||||
if self.predict_x0:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
elif self.config.prediction_type == "flow_prediction":
|
||||
if flag:
|
||||
sigma_t = self.sigmas[self.step_index]
|
||||
else:
|
||||
sigma_t = sigma
|
||||
x0_pred = sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
|
||||
"`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
x0_pred = self._threshold_sample(x0_pred)
|
||||
|
||||
return x0_pred
|
||||
else:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction` for the UniPCMultistepScheduler."
|
||||
)
|
||||
|
||||
def multistep_uni_p_bh_update(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
order: int = None,
|
||||
sigma: torch.Tensor = None,
|
||||
sigma_next: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model at the current timestep.
|
||||
prev_timestep (`int`):
|
||||
The previous discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
order (`int`):
|
||||
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
||||
if sample is None:
|
||||
if len(args) > 1:
|
||||
sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `sample` as a required keyword argument")
|
||||
if order is None:
|
||||
if len(args) > 2:
|
||||
order = args[2]
|
||||
else:
|
||||
raise ValueError("missing `order` as a required keyword argument")
|
||||
if prev_timestep is not None:
|
||||
deprecate(
|
||||
"prev_timestep",
|
||||
"1.0.0",
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
s0 = self.timestep_list[-1]
|
||||
m0 = model_output_list[-1]
|
||||
x = sample
|
||||
|
||||
if self.solver_p:
|
||||
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
||||
return x_t
|
||||
|
||||
if sigma_next is None and sigma is None:
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
else:
|
||||
sigma_t, sigma_s0 = sigma_next, sigma
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = sample.device
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - i
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = torch.tensor(rks, device=device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.config.solver_type == "bh1":
|
||||
B_h = hh
|
||||
elif self.config.solver_type == "bh2":
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= i + 1
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=device)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - alpha_t * B_h * pred_res
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - sigma_t * B_h * pred_res
|
||||
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def multistep_uni_c_bh_update(
|
||||
self,
|
||||
this_model_output: torch.Tensor,
|
||||
*args,
|
||||
last_sample: torch.Tensor = None,
|
||||
this_sample: torch.Tensor = None,
|
||||
order: int = None,
|
||||
sigma_before: torch.Tensor = None,
|
||||
sigma: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
One step for the UniC (B(h) version).
|
||||
|
||||
Args:
|
||||
this_model_output (`torch.Tensor`):
|
||||
The model outputs at `x_t`.
|
||||
this_timestep (`int`):
|
||||
The current timestep `t`.
|
||||
last_sample (`torch.Tensor`):
|
||||
The generated sample before the last predictor `x_{t-1}`.
|
||||
this_sample (`torch.Tensor`):
|
||||
The generated sample after the last predictor `x_{t}`.
|
||||
order (`int`):
|
||||
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The corrected sample tensor at the current timestep.
|
||||
"""
|
||||
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
||||
if last_sample is None:
|
||||
if len(args) > 1:
|
||||
last_sample = args[1]
|
||||
else:
|
||||
raise ValueError("missing `last_sample` as a required keyword argument")
|
||||
if this_sample is None:
|
||||
if len(args) > 2:
|
||||
this_sample = args[2]
|
||||
else:
|
||||
raise ValueError("missing `this_sample` as a required keyword argument")
|
||||
if order is None:
|
||||
if len(args) > 3:
|
||||
order = args[3]
|
||||
else:
|
||||
raise ValueError("missing `order` as a required keyword argument")
|
||||
if this_timestep is not None:
|
||||
deprecate(
|
||||
"this_timestep",
|
||||
"1.0.0",
|
||||
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
model_output_list = self.model_outputs
|
||||
|
||||
m0 = model_output_list[-1]
|
||||
x = last_sample
|
||||
x_t = this_sample
|
||||
model_t = this_model_output
|
||||
|
||||
if sigma_before is None and sigma is None:
|
||||
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
|
||||
else:
|
||||
sigma_t, sigma_s0 = sigma, sigma_before
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
||||
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
||||
|
||||
h = lambda_t - lambda_s0
|
||||
device = this_sample.device
|
||||
|
||||
rks = []
|
||||
D1s = []
|
||||
for i in range(1, order):
|
||||
si = self.step_index - (i + 1)
|
||||
mi = model_output_list[-(i + 1)]
|
||||
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
||||
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
||||
rk = (lambda_si - lambda_s0) / h
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = torch.tensor(rks, device=device)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
|
||||
hh = -h if self.predict_x0 else h
|
||||
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
||||
h_phi_k = h_phi_1 / hh - 1
|
||||
|
||||
factorial_i = 1
|
||||
|
||||
if self.config.solver_type == "bh1":
|
||||
B_h = hh
|
||||
elif self.config.solver_type == "bh2":
|
||||
B_h = torch.expm1(hh)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
for i in range(1, order + 1):
|
||||
R.append(torch.pow(rks, i - 1))
|
||||
b.append(h_phi_k * factorial_i / B_h)
|
||||
factorial_i *= i + 1
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=device)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1)
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
||||
x_t = x_t.to(x.dtype)
|
||||
return x_t
|
||||
|
||||
def step_unipc(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int | torch.Tensor = None,
|
||||
sample: torch.Tensor = None,
|
||||
return_dict: bool = True,
|
||||
model_outputs: list = None,
|
||||
timestep_list: list = None,
|
||||
sigma_before: torch.Tensor = None,
|
||||
sigma: torch.Tensor = None,
|
||||
sigma_next: torch.Tensor = None,
|
||||
cus_step_index: int = None,
|
||||
cus_lower_order_num: int = None,
|
||||
cus_this_order: int = None,
|
||||
cus_last_sample: torch.Tensor = None,
|
||||
) -> HeliosSchedulerOutput | tuple:
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
if cus_step_index is None:
|
||||
if self.step_index is None:
|
||||
self._step_index = 0
|
||||
else:
|
||||
self._step_index = cus_step_index
|
||||
|
||||
if cus_lower_order_num is not None:
|
||||
self.lower_order_nums = cus_lower_order_num
|
||||
|
||||
if cus_this_order is not None:
|
||||
self.this_order = cus_this_order
|
||||
|
||||
if cus_last_sample is not None:
|
||||
self.last_sample = cus_last_sample
|
||||
|
||||
use_corrector = (
|
||||
self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
|
||||
)
|
||||
|
||||
# Convert model output using the proper conversion method
|
||||
model_output_convert = self.convert_model_output(model_output, sample=sample, sigma=sigma)
|
||||
|
||||
if model_outputs is not None and timestep_list is not None:
|
||||
self.model_outputs = model_outputs[:-1]
|
||||
self.timestep_list = timestep_list[:-1]
|
||||
|
||||
if use_corrector:
|
||||
sample = self.multistep_uni_c_bh_update(
|
||||
this_model_output=model_output_convert,
|
||||
last_sample=self.last_sample,
|
||||
this_sample=sample,
|
||||
order=self.this_order,
|
||||
sigma_before=sigma_before,
|
||||
sigma=sigma,
|
||||
)
|
||||
|
||||
if model_outputs is not None and timestep_list is not None:
|
||||
model_outputs[-1] = model_output_convert
|
||||
self.model_outputs = model_outputs[1:]
|
||||
self.timestep_list = timestep_list[1:]
|
||||
else:
|
||||
for i in range(self.config.solver_order - 1):
|
||||
self.model_outputs[i] = self.model_outputs[i + 1]
|
||||
self.timestep_list[i] = self.timestep_list[i + 1]
|
||||
self.model_outputs[-1] = model_output_convert
|
||||
self.timestep_list[-1] = timestep
|
||||
|
||||
if self.config.lower_order_final:
|
||||
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
|
||||
else:
|
||||
this_order = self.config.solver_order
|
||||
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
||||
assert self.this_order > 0
|
||||
|
||||
self.last_sample = sample
|
||||
prev_sample = self.multistep_uni_p_bh_update(
|
||||
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
||||
sample=sample,
|
||||
order=self.this_order,
|
||||
sigma=sigma,
|
||||
sigma_next=sigma_next,
|
||||
)
|
||||
|
||||
if cus_lower_order_num is None:
|
||||
if self.lower_order_nums < self.config.solver_order:
|
||||
self.lower_order_nums += 1
|
||||
|
||||
# upon completion increase step index by one
|
||||
if cus_step_index is None:
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, model_outputs, self.last_sample, self.this_order)
|
||||
|
||||
return HeliosSchedulerOutput(
|
||||
prev_sample=prev_sample,
|
||||
model_outputs=model_outputs,
|
||||
last_sample=self.last_sample,
|
||||
this_order=self.this_order,
|
||||
)
|
||||
|
||||
# ---------------------------------- Merge ----------------------------------
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: float | torch.FloatTensor = None,
|
||||
sample: torch.FloatTensor = None,
|
||||
generator: torch.Generator | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> HeliosSchedulerOutput | tuple:
|
||||
if self.config.scheduler_type == "euler":
|
||||
return self.step_euler(
|
||||
model_output=model_output,
|
||||
timestep=timestep,
|
||||
sample=sample,
|
||||
generator=generator,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
elif self.config.scheduler_type == "unipc":
|
||||
return self.step_unipc(
|
||||
model_output=model_output,
|
||||
timestep=timestep,
|
||||
sample=sample,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_scheduler_history(self):
|
||||
self.model_outputs = [None] * self.config.solver_order
|
||||
self.timestep_list = [None] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
self.disable_corrector = self.config.disable_corrector
|
||||
self.solver_p = self.config.solver_p
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -1,331 +0,0 @@
|
||||
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..schedulers.scheduling_utils import SchedulerMixin
|
||||
from ..utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeliosDMDSchedulerOutput(BaseOutput):
|
||||
prev_sample: torch.FloatTensor
|
||||
model_outputs: torch.FloatTensor | None = None
|
||||
last_sample: torch.FloatTensor | None = None
|
||||
this_order: int | None = None
|
||||
|
||||
|
||||
class HeliosDMDScheduler(SchedulerMixin, ConfigMixin):
|
||||
_compatibles = []
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
shift: float = 1.0, # Following Stable diffusion 3,
|
||||
stages: int = 3,
|
||||
stage_range: list = [0, 1 / 3, 2 / 3, 1],
|
||||
gamma: float = 1 / 3,
|
||||
prediction_type: str = "flow_prediction",
|
||||
use_flow_sigmas: bool = True,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: Literal["exponential", "linear"] = "linear",
|
||||
):
|
||||
self.timestep_ratios = {} # The timestep ratio for each stage
|
||||
self.timesteps_per_stage = {} # The detailed timesteps per stage (fix max and min per stage)
|
||||
self.sigmas_per_stage = {} # always uniform [1000, 0]
|
||||
self.start_sigmas = {} # for start point / upsample renoise
|
||||
self.end_sigmas = {} # for end point
|
||||
self.ori_start_sigmas = {}
|
||||
|
||||
# self.init_sigmas()
|
||||
self.init_sigmas_for_each_stage()
|
||||
self.sigma_min = self.sigmas[-1].item()
|
||||
self.sigma_max = self.sigmas[0].item()
|
||||
self.gamma = gamma
|
||||
|
||||
self.last_sample = None
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def init_sigmas(self):
|
||||
"""
|
||||
initialize the global timesteps and sigmas
|
||||
"""
|
||||
num_train_timesteps = self.config.num_train_timesteps
|
||||
shift = self.config.shift
|
||||
|
||||
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
|
||||
sigmas = torch.from_numpy(sigmas)
|
||||
timesteps = (sigmas * num_train_timesteps).clone()
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.timesteps = timesteps
|
||||
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def init_sigmas_for_each_stage(self):
|
||||
"""
|
||||
Init the timesteps for each stage
|
||||
"""
|
||||
self.init_sigmas()
|
||||
|
||||
stage_distance = []
|
||||
stages = self.config.stages
|
||||
training_steps = self.config.num_train_timesteps
|
||||
stage_range = self.config.stage_range
|
||||
|
||||
# Init the start and end point of each stage
|
||||
for i_s in range(stages):
|
||||
# To decide the start and ends point
|
||||
start_indice = int(stage_range[i_s] * training_steps)
|
||||
start_indice = max(start_indice, 0)
|
||||
end_indice = int(stage_range[i_s + 1] * training_steps)
|
||||
end_indice = min(end_indice, training_steps)
|
||||
start_sigma = self.sigmas[start_indice].item()
|
||||
end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
|
||||
self.ori_start_sigmas[i_s] = start_sigma
|
||||
|
||||
if i_s != 0:
|
||||
ori_sigma = 1 - start_sigma
|
||||
gamma = self.config.gamma
|
||||
corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
|
||||
# corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
|
||||
start_sigma = 1 - corrected_sigma
|
||||
|
||||
stage_distance.append(start_sigma - end_sigma)
|
||||
self.start_sigmas[i_s] = start_sigma
|
||||
self.end_sigmas[i_s] = end_sigma
|
||||
|
||||
# Determine the ratio of each stage according to flow length
|
||||
tot_distance = sum(stage_distance)
|
||||
for i_s in range(stages):
|
||||
if i_s == 0:
|
||||
start_ratio = 0.0
|
||||
else:
|
||||
start_ratio = sum(stage_distance[:i_s]) / tot_distance
|
||||
if i_s == stages - 1:
|
||||
end_ratio = 0.9999999999999999
|
||||
else:
|
||||
end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance
|
||||
|
||||
self.timestep_ratios[i_s] = (start_ratio, end_ratio)
|
||||
|
||||
# Determine the timesteps and sigmas for each stage
|
||||
for i_s in range(stages):
|
||||
timestep_ratio = self.timestep_ratios[i_s]
|
||||
# timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
|
||||
timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
|
||||
timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
|
||||
timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
|
||||
self.timesteps_per_stage[i_s] = (
|
||||
timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
|
||||
)
|
||||
stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
|
||||
self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
stage_index: int | None = None,
|
||||
device: str | torch.device = None,
|
||||
sigmas: bool | None = None,
|
||||
mu: bool | None = None,
|
||||
is_amplify_first_chunk: bool = False,
|
||||
):
|
||||
"""
|
||||
Setting the timesteps and sigmas for each stage
|
||||
"""
|
||||
if is_amplify_first_chunk:
|
||||
num_inference_steps = num_inference_steps * 2 + 1
|
||||
else:
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.init_sigmas()
|
||||
|
||||
if self.config.stages == 1:
|
||||
if sigmas is None:
|
||||
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
|
||||
np.float32
|
||||
)
|
||||
if self.config.shift != 1.0:
|
||||
assert not self.config.use_dynamic_shifting
|
||||
sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
sigmas = torch.from_numpy(sigmas)
|
||||
else:
|
||||
stage_timesteps = self.timesteps_per_stage[stage_index]
|
||||
timesteps = np.linspace(
|
||||
stage_timesteps[0].item(),
|
||||
stage_timesteps[-1].item(),
|
||||
num_inference_steps,
|
||||
)
|
||||
|
||||
stage_sigmas = self.sigmas_per_stage[stage_index]
|
||||
ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
|
||||
sigmas = torch.from_numpy(ratios)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
self.reset_scheduler_history()
|
||||
|
||||
self.timesteps = self.timesteps[:-1]
|
||||
self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])
|
||||
|
||||
if self.config.use_dynamic_shifting:
|
||||
assert self.config.shift == 1.0
|
||||
self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
|
||||
if self.config.stages == 1:
|
||||
self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
|
||||
else:
|
||||
self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
|
||||
self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
|
||||
)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
"""
|
||||
Apply time shifting to the sigmas.
|
||||
|
||||
Args:
|
||||
mu (`float`):
|
||||
The mu parameter for the time shift.
|
||||
sigma (`float`):
|
||||
The sigma parameter for the time shift.
|
||||
t (`torch.Tensor`):
|
||||
The input timesteps.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The time-shifted timesteps.
|
||||
"""
|
||||
if self.config.time_shift_type == "exponential":
|
||||
return self._time_shift_exponential(mu, sigma, t)
|
||||
elif self.config.time_shift_type == "linear":
|
||||
return self._time_shift_linear(mu, sigma, t)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
|
||||
def _time_shift_exponential(self, mu, sigma, t):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
|
||||
def _time_shift_linear(self, mu, sigma, t):
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
# ---------------------------------- For DMD ----------------------------------
|
||||
def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):
|
||||
sigmas = sigmas.to(noise.device)
|
||||
timesteps = timesteps.to(noise.device)
|
||||
timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
|
||||
sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
|
||||
sample = (1 - sigma) * original_samples + sigma * noise
|
||||
return sample.type_as(noise)
|
||||
|
||||
def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps):
|
||||
# use higher precision for calculations
|
||||
original_dtype = flow_pred.dtype
|
||||
device = flow_pred.device
|
||||
flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps))
|
||||
|
||||
timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
|
||||
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
|
||||
x0_pred = xt - sigma_t * flow_pred
|
||||
return x0_pred.to(original_dtype)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: float | torch.FloatTensor = None,
|
||||
sample: torch.FloatTensor = None,
|
||||
generator: torch.Generator | None = None,
|
||||
return_dict: bool = True,
|
||||
cur_sampling_step: int = 0,
|
||||
dmd_noisy_tensor: torch.FloatTensor | None = None,
|
||||
dmd_sigmas: torch.FloatTensor | None = None,
|
||||
dmd_timesteps: torch.FloatTensor | None = None,
|
||||
all_timesteps: torch.FloatTensor | None = None,
|
||||
) -> HeliosDMDSchedulerOutput | tuple:
|
||||
pred_image_or_video = self.convert_flow_pred_to_x0(
|
||||
flow_pred=model_output,
|
||||
xt=sample,
|
||||
timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device),
|
||||
sigmas=dmd_sigmas,
|
||||
timesteps=dmd_timesteps,
|
||||
)
|
||||
if cur_sampling_step < len(all_timesteps) - 1:
|
||||
prev_sample = self.add_noise(
|
||||
pred_image_or_video,
|
||||
dmd_noisy_tensor,
|
||||
torch.full(
|
||||
(model_output.shape[0],),
|
||||
all_timesteps[cur_sampling_step + 1],
|
||||
dtype=torch.long,
|
||||
device=model_output.device,
|
||||
),
|
||||
sigmas=dmd_sigmas,
|
||||
timesteps=dmd_timesteps,
|
||||
)
|
||||
else:
|
||||
prev_sample = pred_image_or_video
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return HeliosDMDSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def reset_scheduler_history(self):
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
@@ -31,18 +31,14 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
trained_betas (`np.ndarray` or `List[float]`, *optional*):
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
trained_betas: np.ndarray | list[float] | None = None,
|
||||
):
|
||||
def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None):
|
||||
# set `betas`, `alphas`, `timesteps`
|
||||
self.set_timesteps(num_train_timesteps)
|
||||
|
||||
@@ -60,29 +56,21 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = None
|
||||
|
||||
@property
|
||||
def step_index(self) -> int | None:
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The index counter for current timestep.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self) -> int | None:
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The index for the first timestep.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
@@ -181,7 +169,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
@@ -240,30 +228,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_prev_sample(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep_index: int,
|
||||
prev_timestep_index: int,
|
||||
ets: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Predicts the previous sample based on the current sample, timestep indices, and running model outputs.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The current sample.
|
||||
timestep_index (`int`):
|
||||
Index of the current timestep in the schedule.
|
||||
prev_timestep_index (`int`):
|
||||
Index of the previous timestep in the schedule.
|
||||
ets (`torch.Tensor`):
|
||||
The running sequence of model outputs.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The predicted previous sample.
|
||||
"""
|
||||
def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
|
||||
alpha = self.alphas[timestep_index]
|
||||
sigma = self.betas[timestep_index]
|
||||
|
||||
@@ -275,5 +240,5 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return prev_sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -656,21 +656,6 @@ class AutoencoderOobleck(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderRAE(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderTiny(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1046,21 +1031,6 @@ class GlmImageTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeliosTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HiDreamImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -2773,36 +2743,6 @@ class FlowMatchLCMScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeliosDMDScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeliosScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeunDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1352,36 +1352,6 @@ class GlmImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosPyramidPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HiDreamImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -107,7 +107,6 @@ def load_or_create_model_card(
|
||||
widget: list[dict] | None = None,
|
||||
inference: bool | None = None,
|
||||
is_modular: bool = False,
|
||||
update_model_card: bool = False,
|
||||
) -> ModelCard:
|
||||
"""
|
||||
Loads or creates a model card.
|
||||
@@ -134,9 +133,6 @@ def load_or_create_model_card(
|
||||
`load_or_create_model_card` from a training script.
|
||||
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
|
||||
When True, uses model_description as-is without additional template formatting.
|
||||
update_model_card: (`bool`, optional): When True, regenerates the model card content even if one
|
||||
already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only
|
||||
supported for modular pipelines (i.e., `is_modular=True`).
|
||||
"""
|
||||
if not is_jinja_available():
|
||||
raise ValueError(
|
||||
@@ -145,17 +141,9 @@ def load_or_create_model_card(
|
||||
" To install it, please run `pip install Jinja2`."
|
||||
)
|
||||
|
||||
if update_model_card and not is_modular:
|
||||
raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).")
|
||||
|
||||
try:
|
||||
# Check if the model card is present on the remote repo
|
||||
model_card = ModelCard.load(repo_id_or_path, token=token)
|
||||
# For modular pipelines, regenerate card content when requested (preserve existing metadata)
|
||||
if update_model_card and is_modular and model_description is not None:
|
||||
existing_data = model_card.data
|
||||
model_card = ModelCard(model_description)
|
||||
model_card.data = existing_data
|
||||
except (EntryNotFoundError, RepositoryNotFoundError):
|
||||
# Otherwise create a model card from template
|
||||
if from_training:
|
||||
|
||||
14
test_automodel_meta.py
Normal file
14
test_automodel_meta.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
|
||||
repo = "meituan-longcat/LongCat-Image"
|
||||
subfolder = "transformer"
|
||||
|
||||
config = AutoModel.load_config(repo, subfolder=subfolder)
|
||||
|
||||
with torch.device("meta"):
|
||||
model = AutoModel.from_config(config)
|
||||
print(f"model.config:")
|
||||
for k, v in dict(model.config).items():
|
||||
if not k.startswith("_"):
|
||||
print(f" {k}: {v}")
|
||||
11
test_dataclass_config.py
Normal file
11
test_dataclass_config.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import dataclasses
|
||||
from diffusers import AutoModel, LongCatImageTransformer2DModel
|
||||
|
||||
config_dict = AutoModel.load_config(
|
||||
"meituan-longcat/LongCat-Image",
|
||||
subfolder="transformer",
|
||||
)
|
||||
# import DiT based on _class_name
|
||||
typed_config = LongCatImageTransformer2DModel._get_dataclass_from_config(config_dict)
|
||||
for f in dataclasses.fields(typed_config):
|
||||
print(f"{f.name}: {f.type}")
|
||||
29
test_pretrained_config.py
Normal file
29
test_pretrained_config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import dataclasses
|
||||
import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models import AutoModel
|
||||
|
||||
repo = "black-forest-labs/FLUX.2-dev"
|
||||
subfolder = "transformer"
|
||||
|
||||
print("=== From load_config (no model instantiation) ===")
|
||||
config_dict = FluxTransformer2DModel.load_config(repo, subfolder=subfolder)
|
||||
tc = FluxTransformer2DModel._get_dataclass_from_config(config_dict)
|
||||
print(f"Type: {type(tc).__name__}")
|
||||
for k, v in dataclasses.asdict(tc).items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
print()
|
||||
print("=== From AutoModel.from_config on meta device ===")
|
||||
with torch.device("meta"):
|
||||
model = AutoModel.from_config(repo, subfolder=subfolder)
|
||||
print(f"model.config:")
|
||||
for k, v in dict(model.config).items():
|
||||
if not k.startswith("_"):
|
||||
print(f" {k}: {v}")
|
||||
|
||||
print()
|
||||
print("=== Comparison ===")
|
||||
dc_dict = dataclasses.asdict(tc)
|
||||
config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")}
|
||||
print(f"Match: {dc_dict == config}")
|
||||
@@ -566,127 +566,3 @@ class GroupOffloadTests(unittest.TestCase):
|
||||
"layers_per_block": 1,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
|
||||
# Model with conditionally-executed modules, simulating Helios patch_short/patch_mid/patch_long behavior.
|
||||
# These modules are only called when optional inputs are provided, which means the lazy prefetch
|
||||
# execution order tracer may not see them on the first forward pass. This can cause a device mismatch
|
||||
# on subsequent calls when the modules ARE invoked but their weights were never onloaded.
|
||||
# See: https://github.com/huggingface/diffusers/pull/13211
|
||||
class DummyModelWithConditionalModules(ModelMixin):
|
||||
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
|
||||
self.activation = torch.nn.ReLU()
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
|
||||
)
|
||||
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
|
||||
|
||||
# These modules are only invoked when optional_input is not None.
|
||||
# Output dimension matches hidden_features so they can be added after linear_1.
|
||||
self.optional_proj_1 = torch.nn.Linear(in_features, hidden_features)
|
||||
self.optional_proj_2 = torch.nn.Linear(in_features, hidden_features)
|
||||
|
||||
def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = self.activation(x)
|
||||
if optional_input is not None:
|
||||
# Add optional projections after linear_1 so dimensions match (both hidden_features)
|
||||
x = x + self.optional_proj_1(optional_input)
|
||||
x = x + self.optional_proj_2(optional_input)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
|
||||
"""Tests for conditionally-executed modules under group offloading with streams.
|
||||
|
||||
Regression tests for the case where a module is not executed during the first forward pass
|
||||
(when the lazy prefetch execution order is traced), but IS executed on subsequent passes.
|
||||
Without the fix, the weights of such modules remain on CPU while the input is on GPU,
|
||||
causing a RuntimeError about tensor device mismatch.
|
||||
"""
|
||||
|
||||
def get_model(self):
|
||||
torch.manual_seed(0)
|
||||
return DummyModelWithConditionalModules(
|
||||
in_features=self.in_features,
|
||||
hidden_features=self.hidden_features,
|
||||
out_features=self.out_features,
|
||||
num_layers=self.num_layers,
|
||||
)
|
||||
|
||||
@parameterized.expand([("leaf_level",), ("block_level",)])
|
||||
@unittest.skipIf(
|
||||
torch.device(torch_device).type not in ["cuda", "xpu"],
|
||||
"Test requires a CUDA or XPU device.",
|
||||
)
|
||||
def test_conditional_modules_with_stream(self, offload_type: str):
|
||||
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams.
|
||||
|
||||
The model contains two optional Linear layers (optional_proj_1, optional_proj_2) that are only
|
||||
executed when `optional_input` is provided. This simulates modules like patch_short/patch_mid/
|
||||
patch_long in HeliosTransformer3DModel, which are only called when history latents are present.
|
||||
|
||||
When using streams, `LazyPrefetchGroupOffloadingHook` traces the execution order on the first
|
||||
forward pass and sets up a prefetch chain so each module pre-loads the next one's weights.
|
||||
Modules not executed during this tracing pass are excluded from the prefetch chain.
|
||||
|
||||
The bug: if a module was absent from the first (tracing) pass, its `onload_self` flag gets set
|
||||
to False (meaning "someone else will onload me"). But since it's not in the prefetch chain,
|
||||
nobody ever does — so its weights remain on CPU. When the module is eventually called in a
|
||||
subsequent pass, the input is on GPU but the weights are on CPU, causing a RuntimeError.
|
||||
|
||||
We therefore must invoke the model multiple times:
|
||||
1. First pass WITHOUT optional_input: triggers the lazy prefetch tracing. optional_proj_1/2
|
||||
are absent, so they are excluded from the prefetch chain.
|
||||
2. Second pass WITH optional_input: the regression case. Without the fix, this raises a
|
||||
RuntimeError because optional_proj_1/2 weights are still on CPU.
|
||||
3. Third pass WITHOUT optional_input: verifies the model remains stable after having seen
|
||||
both code paths.
|
||||
"""
|
||||
|
||||
model = self.get_model()
|
||||
model_ref = self.get_model()
|
||||
model_ref.load_state_dict(model.state_dict(), strict=True)
|
||||
model_ref.to(torch_device)
|
||||
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=1,
|
||||
use_stream=True,
|
||||
)
|
||||
|
||||
x = torch.randn(4, self.in_features).to(torch_device)
|
||||
optional_input = torch.randn(4, self.in_features).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
# First forward pass WITHOUT optional_input — this is when the lazy prefetch
|
||||
# execution order is traced. optional_proj_1/2 are NOT in the traced order.
|
||||
out_ref_no_opt = model_ref(x, optional_input=None)
|
||||
out_no_opt = model(x, optional_input=None)
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
|
||||
f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
|
||||
)
|
||||
|
||||
# Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
|
||||
out_ref_with_opt = model_ref(x, optional_input=optional_input)
|
||||
out_with_opt = model(x, optional_input=optional_input)
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
|
||||
f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
|
||||
)
|
||||
|
||||
# Third pass again without optional_input — verify stable behavior.
|
||||
out_ref_no_opt2 = model_ref(x, optional_input=None)
|
||||
out_no_opt2 = model(x, optional_input=None)
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
|
||||
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
|
||||
)
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, HeliosPipeline, HeliosTransformer3DModel
|
||||
|
||||
from ..testing_utils import floats_tensor, require_peft_backend, skip_mps
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class HeliosLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = HeliosPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"text_dim": 32,
|
||||
"freq_dim": 256,
|
||||
"ffn_dim": 32,
|
||||
"num_layers": 2,
|
||||
"cross_attn_norm": True,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_dim": (4, 4, 4),
|
||||
"has_multi_term_memory_patch": True,
|
||||
"guidance_cross_attn": True,
|
||||
"zero_history_timestep": True,
|
||||
"is_amplify_history": False,
|
||||
}
|
||||
transformer_cls = HeliosTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
vae_cls = AutoencoderKLWan
|
||||
has_two_text_encoders = True
|
||||
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
|
||||
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
supports_text_encoder_loras = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 33, 32, 32, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 16
|
||||
num_channels = 4
|
||||
num_frames = 9
|
||||
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
|
||||
sizes = (4, 4)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "",
|
||||
"num_frames": num_frames,
|
||||
"num_inference_steps": 1,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": sequence_length,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@unittest.skip("Not supported in Helios.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Helios.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Helios.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
@@ -1,300 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
import diffusers.models.autoencoders.autoencoder_rae as _rae_module
|
||||
from diffusers.models.autoencoders.autoencoder_rae import (
|
||||
_ENCODER_FORWARD_FNS,
|
||||
AutoencoderRAE,
|
||||
_build_encoder,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tiny test encoder for fast unit tests (no transformers dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _TinyTestEncoderModule(torch.nn.Module):
|
||||
"""Minimal encoder that mimics the patch-token interface without any HF model."""
|
||||
|
||||
def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
tokens = pooled.flatten(2).transpose(1, 2).contiguous()
|
||||
return tokens.repeat(1, 1, self.hidden_size)
|
||||
|
||||
|
||||
def _tiny_test_encoder_forward(model, images):
|
||||
return model(images)
|
||||
|
||||
|
||||
def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
|
||||
return _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size)
|
||||
|
||||
|
||||
# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE
|
||||
_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward
|
||||
_original_build_encoder = _build_encoder
|
||||
|
||||
|
||||
def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
|
||||
if encoder_type == "tiny_test":
|
||||
return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
|
||||
return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
|
||||
|
||||
|
||||
_rae_module._build_encoder = _patched_build_encoder
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AutoencoderRAETesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderRAE
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 16, 16)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"encoder_type": "tiny_test",
|
||||
"encoder_hidden_size": 16,
|
||||
"encoder_patch_size": 8,
|
||||
"encoder_input_size": 32,
|
||||
"patch_size": 4,
|
||||
"image_size": 16,
|
||||
"decoder_hidden_size": 32,
|
||||
"decoder_num_hidden_layers": 1,
|
||||
"decoder_num_attention_heads": 4,
|
||||
"decoder_intermediate_size": 64,
|
||||
"num_channels": 3,
|
||||
"encoder_norm_mean": [0.5, 0.5, 0.5],
|
||||
"encoder_norm_std": [0.5, 0.5, 0.5],
|
||||
"noise_tau": 0.0,
|
||||
"reshape_to_2d": True,
|
||||
"scaling_factor": 1.0,
|
||||
}
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {"sample": torch.randn(2, 3, 32, 32, generator=self.generator, device="cpu").to(torch_device)}
|
||||
|
||||
# Bridge for AutoencoderTesterMixin which still uses the old interface
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return self.get_init_dict(), self.get_dummy_inputs()
|
||||
|
||||
def _make_model(self, **overrides) -> AutoencoderRAE:
|
||||
config = self.get_init_dict()
|
||||
config.update(overrides)
|
||||
return AutoencoderRAE(**config).to(torch_device)
|
||||
|
||||
|
||||
class TestAutoEncoderRAE(AutoencoderRAETesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for AutoencoderRAE."""
|
||||
|
||||
@pytest.mark.skip(reason="AutoencoderRAE does not support torch dynamo yet")
|
||||
def test_from_save_pretrained_dynamo(self): ...
|
||||
|
||||
def test_fast_encode_decode_and_forward_shapes(self):
|
||||
model = self._make_model().eval()
|
||||
x = torch.rand(2, 3, 32, 32, device=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
z = model.encode(x).latent
|
||||
decoded = model.decode(z).sample
|
||||
recon = model(x).sample
|
||||
|
||||
assert z.shape == (2, 16, 4, 4)
|
||||
assert decoded.shape == (2, 3, 16, 16)
|
||||
assert recon.shape == (2, 3, 16, 16)
|
||||
assert torch.isfinite(recon).all().item()
|
||||
|
||||
def test_fast_scaling_factor_encode_and_decode_consistency(self):
|
||||
torch.manual_seed(0)
|
||||
model_base = self._make_model(scaling_factor=1.0).eval()
|
||||
torch.manual_seed(0)
|
||||
model_scaled = self._make_model(scaling_factor=2.0).eval()
|
||||
|
||||
x = torch.rand(2, 3, 32, 32, device=torch_device)
|
||||
with torch.no_grad():
|
||||
z_base = model_base.encode(x).latent
|
||||
z_scaled = model_scaled.encode(x).latent
|
||||
recon_base = model_base.decode(z_base).sample
|
||||
recon_scaled = model_scaled.decode(z_scaled).sample
|
||||
|
||||
assert torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4)
|
||||
assert torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4)
|
||||
|
||||
def test_fast_latents_normalization_matches_formula(self):
|
||||
latents_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32)
|
||||
latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32)
|
||||
|
||||
model_raw = self._make_model().eval()
|
||||
model_norm = self._make_model(latents_mean=latents_mean, latents_std=latents_std).eval()
|
||||
x = torch.rand(1, 3, 32, 32, device=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
z_raw = model_raw.encode(x).latent
|
||||
z_norm = model_norm.encode(x).latent
|
||||
|
||||
expected = (z_raw - latents_mean.to(z_raw.device, z_raw.dtype)) / (
|
||||
latents_std.to(z_raw.device, z_raw.dtype) + 1e-5
|
||||
)
|
||||
assert torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)
|
||||
|
||||
def test_fast_slicing_matches_non_slicing(self):
|
||||
model = self._make_model().eval()
|
||||
x = torch.rand(3, 3, 32, 32, device=torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
model.use_slicing = False
|
||||
z_no_slice = model.encode(x).latent
|
||||
out_no_slice = model.decode(z_no_slice).sample
|
||||
|
||||
model.use_slicing = True
|
||||
z_slice = model.encode(x).latent
|
||||
out_slice = model.decode(z_slice).sample
|
||||
|
||||
assert torch.allclose(z_slice, z_no_slice, atol=1e-6, rtol=1e-5)
|
||||
assert torch.allclose(out_slice, out_no_slice, atol=1e-6, rtol=1e-5)
|
||||
|
||||
def test_fast_noise_tau_applies_only_in_train(self):
|
||||
model = self._make_model(noise_tau=0.5).to(torch_device)
|
||||
x = torch.rand(2, 3, 32, 32, device=torch_device)
|
||||
|
||||
model.train()
|
||||
torch.manual_seed(0)
|
||||
z_train_1 = model.encode(x).latent
|
||||
torch.manual_seed(1)
|
||||
z_train_2 = model.encode(x).latent
|
||||
|
||||
model.eval()
|
||||
torch.manual_seed(0)
|
||||
z_eval_1 = model.encode(x).latent
|
||||
torch.manual_seed(1)
|
||||
z_eval_2 = model.encode(x).latent
|
||||
|
||||
assert z_train_1.shape == z_eval_1.shape
|
||||
assert not torch.allclose(z_train_1, z_train_2)
|
||||
assert torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)
|
||||
|
||||
|
||||
class TestAutoEncoderRAESlicingTiling(AutoencoderRAETesterConfig, AutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderRAE."""
|
||||
|
||||
|
||||
@slow
|
||||
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
|
||||
class AutoencoderRAEEncoderIntegrationTests:
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_dinov2_encoder_forward_shape(self):
|
||||
encoder = _build_encoder("dinov2", hidden_size=768, patch_size=14, num_hidden_layers=12).to(torch_device)
|
||||
x = torch.rand(1, 3, 224, 224, device=torch_device)
|
||||
y = _ENCODER_FORWARD_FNS["dinov2"](encoder, x)
|
||||
|
||||
assert y.ndim == 3
|
||||
assert y.shape[0] == 1
|
||||
assert y.shape[1] == 256 # (224/14)^2 - 5 (CLS + 4 register) = 251? Actually dinov2 has 256 patches
|
||||
assert y.shape[2] == 768
|
||||
|
||||
def test_siglip2_encoder_forward_shape(self):
|
||||
encoder = _build_encoder("siglip2", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
|
||||
x = torch.rand(1, 3, 224, 224, device=torch_device)
|
||||
y = _ENCODER_FORWARD_FNS["siglip2"](encoder, x)
|
||||
|
||||
assert y.ndim == 3
|
||||
assert y.shape[0] == 1
|
||||
assert y.shape[1] == 196 # (224/16)^2
|
||||
assert y.shape[2] == 768
|
||||
|
||||
def test_mae_encoder_forward_shape(self):
|
||||
encoder = _build_encoder("mae", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
|
||||
x = torch.rand(1, 3, 224, 224, device=torch_device)
|
||||
y = _ENCODER_FORWARD_FNS["mae"](encoder, x, patch_size=16)
|
||||
|
||||
assert y.ndim == 3
|
||||
assert y.shape[0] == 1
|
||||
assert y.shape[1] == 196 # (224/16)^2
|
||||
assert y.shape[2] == 768
|
||||
|
||||
|
||||
@slow
|
||||
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
|
||||
class AutoencoderRAEIntegrationTests:
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_autoencoder_rae_from_pretrained_dinov2(self):
|
||||
model = AutoencoderRAE.from_pretrained("nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08").to(torch_device)
|
||||
model.eval()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
|
||||
)
|
||||
image = image.convert("RGB").resize((224, 224))
|
||||
x = to_tensor(image).unsqueeze(0).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent
|
||||
assert latents.shape == (1, 768, 16, 16)
|
||||
|
||||
recon = model.decode(latents).sample
|
||||
assert recon.shape == (1, 3, 256, 256)
|
||||
assert torch.isfinite(recon).all().item()
|
||||
|
||||
# fmt: off
|
||||
expected_latent_slice = torch.tensor([0.7617, 0.8824, -0.4891])
|
||||
expected_recon_slice = torch.tensor([0.1263, 0.1355, 0.1435])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(latents[0, :3, 0, 0].float().cpu(), expected_latent_slice, atol=1e-3)
|
||||
assert torch_all_close(recon[0, 0, 0, :3].float().cpu(), expected_recon_slice, atol=1e-3)
|
||||
@@ -7,9 +7,7 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
from transformers import CLIPTextModel, LongformerModel
|
||||
|
||||
from diffusers import ConfigMixin
|
||||
from diffusers.models import AutoModel, UNet2DConditionModel
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class TestAutoModel(unittest.TestCase):
|
||||
@@ -145,51 +143,3 @@ class TestAutoModelFromConfig(unittest.TestCase):
|
||||
def test_from_config_raises_on_none(self):
    """from_config(None) must raise a ValueError rather than failing later with a confusing error."""
    # NOTE(review): `msg=` in assertRaises only customizes the failure message shown
    # when NO exception is raised — it does not assert the exception's text.
    # If matching the error message is intended, assertRaisesRegex would be needed.
    with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
        AutoModel.from_config(None)
|
||||
|
||||
|
||||
class TestRegisterForAutoClass(unittest.TestCase):
    """Tests for `register_for_auto_class` and the `auto_map` entry it adds to saved configs."""

    @staticmethod
    def _fresh_dummy_model_cls():
        # A brand-new class per call, so `_auto_class` state never leaks between tests.
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        return DummyModel

    @staticmethod
    def _save_and_read_config(model):
        # Round-trip the model's config through disk and return the parsed JSON dict.
        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_config(tmpdir)
            with open(os.path.join(tmpdir, "config.json"), "r") as f:
                return json.load(f)

    def test_register_for_auto_class_sets_attribute(self):
        """Registering should record the auto class name on the model class."""
        DummyModel = self._fresh_dummy_model_cls()
        DummyModel.register_for_auto_class("AutoModel")
        self.assertEqual(DummyModel._auto_class, "AutoModel")

    def test_register_for_auto_class_rejects_unsupported(self):
        """Anything other than 'AutoModel' must be rejected with a ValueError."""
        DummyModel = self._fresh_dummy_model_cls()
        with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
            DummyModel.register_for_auto_class("AutoPipeline")

    def test_auto_map_in_saved_config(self):
        """After registering, the saved config must carry an `auto_map` pointing at the class."""
        DummyModel = self._fresh_dummy_model_cls()
        DummyModel.register_for_auto_class("AutoModel")
        config = self._save_and_read_config(DummyModel())

        self.assertIn("auto_map", config)
        self.assertIn("AutoModel", config["auto_map"])
        # auto_map is expected to reference "<module>.<ClassName>".
        module_name = DummyModel.__module__.split(".")[-1]
        self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")

    def test_no_auto_map_without_register(self):
        """Without registration the saved config must NOT contain an `auto_map` entry."""
        DummyModel = self._fresh_dummy_model_cls()
        config = self._save_and_read_config(DummyModel())
        self.assertNotIn("auto_map", config)
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import HeliosTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HeliosTransformer3DTesterConfig(BaseModelTesterConfig):
    """Shared tester configuration for HeliosTransformer3DModel.

    Supplies a tiny model config and matching dummy inputs so the various
    tester mixins (core, memory, training, attention, compile) can all run
    against the same small model.
    """

    @property
    def model_class(self):
        return HeliosTransformer3DModel

    @property
    def pretrained_model_name_or_path(self):
        # Tiny checkpoint used by tests that exercise from_pretrained paths.
        return "hf-internal-testing/tiny-helios-base-transformer"

    @property
    def output_shape(self) -> tuple[int, ...]:
        # (channels, frames, height, width) — same as the input for this denoiser.
        return (4, 2, 16, 16)

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (4, 2, 16, 16)

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def generator(self):
        # Fresh seeded CPU generator per access, so every tensor draw is reproducible.
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
        """Minimal constructor kwargs for a tiny-but-structurally-complete model."""
        return {
            "patch_size": (1, 2, 2),
            "num_attention_heads": 2,
            "attention_head_dim": 12,
            "in_channels": 4,
            "out_channels": 4,
            "text_dim": 16,
            "freq_dim": 256,
            "ffn_dim": 32,
            "num_layers": 2,
            "cross_attn_norm": True,
            "qk_norm": "rms_norm_across_heads",
            "rope_dim": (4, 4, 4),
            "has_multi_term_memory_patch": True,
            "guidance_cross_attn": True,
            "zero_history_timestep": True,
            "is_amplify_history": False,
        }

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        """Build the full forward-call kwargs, including the three history-latent banks."""
        batch_size = 1
        num_channels = 4
        num_frames = 2
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12

        hidden_states = randn_tensor(
            (batch_size, num_channels, num_frames, height, width),
            generator=self.generator,
            device=torch_device,
        )
        timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device)
        encoder_hidden_states = randn_tensor(
            (batch_size, sequence_length, text_encoder_embedding_dim),
            generator=self.generator,
            device=torch_device,
        )
        # Frame-index tensors; all-ones placeholders are enough for a shape-level smoke test.
        indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device)
        indices_latents_history_short = torch.ones((batch_size, num_frames - 1)).to(torch_device)
        indices_latents_history_mid = torch.ones((batch_size, num_frames - 1)).to(torch_device)
        # The long-history bank covers 4x as many frames as the short/mid banks.
        indices_latents_history_long = torch.ones((batch_size, (num_frames - 1) * 4)).to(torch_device)
        latents_history_short = randn_tensor(
            (batch_size, num_channels, num_frames - 1, height, width),
            generator=self.generator,
            device=torch_device,
        )
        latents_history_mid = randn_tensor(
            (batch_size, num_channels, num_frames - 1, height, width),
            generator=self.generator,
            device=torch_device,
        )
        latents_history_long = randn_tensor(
            (batch_size, num_channels, (num_frames - 1) * 4, height, width),
            generator=self.generator,
            device=torch_device,
        )

        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "indices_hidden_states": indices_hidden_states,
            "indices_latents_history_short": indices_latents_history_short,
            "indices_latents_history_mid": indices_latents_history_mid,
            "indices_latents_history_long": indices_latents_history_long,
            "latents_history_short": latents_history_short,
            "latents_history_mid": latents_history_mid,
            "latents_history_long": latents_history_long,
        }
|
||||
|
||||
|
||||
class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Helios Transformer 3D."""

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        # Skip: fp16/bf16 require very high atol to pass, providing little signal.
        # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
        # Parametrization is kept so both dtypes still show up (as skipped) in test reports.
        pytest.skip("Tolerance requirements too high for meaningful test")
|
||||
|
||||
|
||||
# All tests come from MemoryTesterMixin; the config mixin supplies the tiny model.
class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Helios Transformer 3D."""
|
||||
|
||||
|
||||
class TestHeliosTransformer3DTraining(HeliosTransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Helios Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        # Only the top-level model class is expected to carry gradient checkpointing.
        expected_set = {"HeliosTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
# All tests come from AttentionTesterMixin; the config mixin supplies the tiny model.
class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Helios Transformer 3D."""
|
||||
|
||||
|
||||
class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for Helios Transformer 3D."""

    # xfail rather than skip: the test should start passing once the upstream
    # PyTorch issue is fixed, and xfail-pass will surface that.
    @pytest.mark.xfail(
        reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079"
    )
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()
|
||||
@@ -10,11 +10,6 @@ import torch
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import (
|
||||
ConditionalPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -24,13 +19,7 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
@@ -440,117 +429,6 @@ class ModularGuiderTesterMixin:
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
    """Tests that per-block `_requirements` dicts are merged, saved to
    `modular_config.json`, and warned about when unmet, for sequential,
    conditional, and loop block containers."""

    def get_dummy_block_pipe(self):
        """Sequential container holding one block with fake deps and one with real deps."""
        class DummyBlockOne:
            # keep two arbitrary deps so that we can test warnings.
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            # keep two dependencies that will be available during testing.
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        pipe = SequentialPipelineBlocks.from_blocks_dict(
            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
        )
        return pipe

    def get_dummy_conditional_block_pipe(self):
        """Conditional container over the same two dummy blocks; always selects block_one."""
        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        class DummyConditionalBlocks(ConditionalPipelineBlocks):
            block_classes = [DummyBlockOne, DummyBlockTwo]
            block_names = ["block_one", "block_two"]
            block_trigger_inputs = []

            def select_block(self, **kwargs):
                return "block_one"

        return DummyConditionalBlocks()

    def get_dummy_loop_block_pipe(self):
        """Loop container over the same two dummy blocks."""
        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})

    def test_sequential_block_requirements_save_load(self, tmp_path):
        """Saving a sequential pipeline must merge all sub-block requirements into the config."""
        pipe = self.get_dummy_block_pipe()
        pipe.save_pretrained(tmp_path)

        config_path = tmp_path / "modular_config.json"

        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        requirements = config["requirements"]

        # Union of both blocks' _requirements dicts.
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == requirements

    def test_sequential_block_requirements_warnings(self, tmp_path):
        """Requirements missing from the environment must each trigger a warning on save."""
        pipe = self.get_dummy_block_pipe()

        logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
        # 30 == logging.WARNING; make sure warnings are not filtered out.
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            pipe.save_pretrained(tmp_path)

        # Only the fake packages ("xyz", "abc") should be flagged as missing.
        template = "{req} was specified in the requirements but wasn't found in the current environment"
        msg_xyz = template.format(req="xyz")
        msg_abc = template.format(req="abc")
        assert msg_xyz in str(cap_logger.out)
        assert msg_abc in str(cap_logger.out)

    def test_conditional_block_requirements_save_load(self, tmp_path):
        """Conditional containers must merge requirements from ALL branches, not just the selected one."""
        pipe = self.get_dummy_conditional_block_pipe()
        pipe.save_pretrained(tmp_path)

        config_path = tmp_path / "modular_config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == config["requirements"]

    def test_loop_block_requirements_save_load(self, tmp_path):
        """Loop containers must also merge sub-block requirements into the saved config."""
        pipe = self.get_dummy_loop_block_pipe()
        pipe.save_pretrained(tmp_path)

        config_path = tmp_path / "modular_config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
@@ -605,7 +483,8 @@ class TestModularModelCardContent:
|
||||
"blocks_description",
|
||||
"components_description",
|
||||
"configs_section",
|
||||
"io_specification_section",
|
||||
"inputs_description",
|
||||
"outputs_description",
|
||||
"trigger_inputs_section",
|
||||
"tags",
|
||||
]
|
||||
@@ -702,19 +581,18 @@ class TestModularModelCardContent:
|
||||
blocks = self.create_mock_blocks(inputs=inputs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
io_section = content["io_specification_section"]
|
||||
assert "**Inputs:**" in io_section
|
||||
assert "prompt" in io_section
|
||||
assert "num_steps" in io_section
|
||||
assert "*optional*" in io_section
|
||||
assert "defaults to `50`" in io_section
|
||||
assert "**Required:**" in content["inputs_description"]
|
||||
assert "**Optional:**" in content["inputs_description"]
|
||||
assert "prompt" in content["inputs_description"]
|
||||
assert "num_steps" in content["inputs_description"]
|
||||
assert "default: `50`" in content["inputs_description"]
|
||||
|
||||
def test_inputs_description_empty(self):
    """Test handling of pipelines without specific inputs."""
    blocks = self.create_mock_blocks(inputs=[])
    content = generate_modular_model_card_content(blocks)

    # Both the combined I/O section and the standalone inputs section get the placeholder.
    assert "No specific inputs defined" in content["io_specification_section"]
    assert "No specific inputs defined" in content["inputs_description"]
|
||||
|
||||
def test_outputs_description_formatting(self):
|
||||
"""Test that outputs are correctly formatted."""
|
||||
@@ -724,16 +602,15 @@ class TestModularModelCardContent:
|
||||
blocks = self.create_mock_blocks(outputs=outputs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
io_section = content["io_specification_section"]
|
||||
assert "images" in io_section
|
||||
assert "Generated images" in io_section
|
||||
assert "images" in content["outputs_description"]
|
||||
assert "Generated images" in content["outputs_description"]
|
||||
|
||||
def test_outputs_description_empty(self):
    """Test handling of pipelines without specific outputs."""
    blocks = self.create_mock_blocks(outputs=[])
    content = generate_modular_model_card_content(blocks)

    # Both the combined I/O section and the standalone outputs section fall back to the generic text.
    assert "Standard pipeline outputs" in content["io_specification_section"]
    assert "Standard pipeline outputs" in content["outputs_description"]
|
||||
|
||||
def test_trigger_inputs_section_with_triggers(self):
|
||||
"""Test that trigger inputs section is generated when present."""
|
||||
@@ -751,6 +628,35 @@ class TestModularModelCardContent:
|
||||
|
||||
assert content["trigger_inputs_section"] == ""
|
||||
|
||||
def test_blocks_description_with_sub_blocks(self):
    """Test that blocks with sub-blocks are correctly described."""

    class MockBlockWithSubBlocks:
        def __init__(self):
            # NOTE: assigning self.__class__.__name__ renames the class itself,
            # mirroring how the other mock helpers in this suite fake block names.
            self.__class__.__name__ = "ParentBlock"
            self.description = "Parent block"
            self.sub_blocks = {
                "child1": self.create_child_block("ChildBlock1", "Child 1 description"),
                "child2": self.create_child_block("ChildBlock2", "Child 2 description"),
            }

        def create_child_block(self, name, desc):
            # Build a throwaway child mock with the given display name/description.
            class ChildBlock:
                def __init__(self):
                    self.__class__.__name__ = name
                    self.description = desc

            return ChildBlock()

    blocks = self.create_mock_blocks()
    blocks.sub_blocks["parent"] = MockBlockWithSubBlocks()

    content = generate_modular_model_card_content(blocks)

    # The parent AND both nested children must appear in the rendered description.
    assert "parent" in content["blocks_description"]
    assert "child1" in content["blocks_description"]
    assert "child2" in content["blocks_description"]
|
||||
|
||||
def test_model_description_includes_block_count(self):
|
||||
"""Test that model description includes the number of blocks."""
|
||||
blocks = self.create_mock_blocks(num_blocks=5)
|
||||
@@ -809,18 +715,6 @@ class TestLoadComponentsSkipBehavior:
|
||||
assert pipe.unet is not None
|
||||
assert getattr(pipe, "vae", None) is None
|
||||
|
||||
def test_load_components_selective_loading_incremental(self):
    """Loading a subset of components should not affect already-loaded components."""
    pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")

    # Load one component at a time; the second call must not drop the first.
    pipe.load_components(names="unet", torch_dtype=torch.float32)
    pipe.load_components(names="text_encoder", torch_dtype=torch.float32)

    assert hasattr(pipe, "unet")
    assert pipe.unet is not None
    assert hasattr(pipe, "text_encoder")
    assert pipe.text_encoder is not None
|
||||
|
||||
def test_load_components_skips_invalid_pretrained_path(self):
|
||||
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
|
||||
|
||||
@@ -883,36 +777,6 @@ class TestCustomModelSavePretrained:
|
||||
for key in original_state_dict:
|
||||
assert torch.equal(original_state_dict[key], loaded_state_dict[key]), f"Mismatch in {key}"
|
||||
|
||||
def test_save_pretrained_updates_index_for_model_with_no_load_id(self, tmp_path):
    """testing the workflow of update the pipeline with a custom model and save the pipeline,
    the modular_model_index.json should point to the save directory."""
    import json

    from diffusers import UNet2DConditionModel

    pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
    pipe.load_components(torch_dtype=torch.float32)

    # Load the unet directly (not via the modular pipeline), so it has no
    # load-tracking attribute.
    unet = UNet2DConditionModel.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet"
    )
    assert not hasattr(unet, "_diffusers_load_id")

    pipe.update_components(unet=unet)

    save_dir = str(tmp_path / "my-pipeline")
    pipe.save_pretrained(save_dir)

    with open(os.path.join(save_dir, "modular_model_index.json")) as f:
        index = json.load(f)

    # The swapped-in unet (no load id) must now point at the save directory...
    _library, _cls, unet_spec = index["unet"]
    assert unet_spec["pretrained_model_name_or_path"] == save_dir
    assert unet_spec["subfolder"] == "unet"

    # ...while untouched components keep their original hub reference.
    _library, _cls, vae_spec = index["vae"]
    assert vae_spec["pretrained_model_name_or_path"] == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
def test_save_pretrained_overwrite_modular_index(self, tmp_path):
|
||||
"""With overwrite_modular_index=True, all component references should point to the save directory."""
|
||||
import json
|
||||
|
||||
@@ -192,156 +192,6 @@ class TestModularCustomBlocks:
|
||||
assert len(pipe.components) == 1
|
||||
assert pipe.component_names[0] == "transformer"
|
||||
|
||||
def test_trust_remote_code_not_propagated_to_external_repo(self):
    """When a modular pipeline repo references a component from an external repo that has custom
    code (auto_map in config), calling load_components(trust_remote_code=True) should NOT
    propagate trust_remote_code to that external component. The external component should fail
    to load."""

    from diffusers import ModularPipeline

    # Source of the external custom model; written to modeling.py in the external
    # repo so loading it requires trust_remote_code.
    # NOTE(review): in-string indentation was reconstructed to standard 4-space
    # Python indents (the paste collapsed whitespace) — the string must be valid
    # Python since it is exec'd via trust_remote_code.
    CUSTOM_MODEL_CODE = (
        "import torch\n"
        "from diffusers import ModelMixin, ConfigMixin\n"
        "from diffusers.configuration_utils import register_to_config\n"
        "\n"
        "class CustomModel(ModelMixin, ConfigMixin):\n"
        "    @register_to_config\n"
        "    def __init__(self, hidden_size=8):\n"
        "        super().__init__()\n"
        "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
        "\n"
        "    def forward(self, x):\n"
        "        return self.linear(x)\n"
    )

    with tempfile.TemporaryDirectory() as external_repo_dir, tempfile.TemporaryDirectory() as pipeline_repo_dir:
        # Step 1: Create an external model repo with custom code (requires trust_remote_code)
        with open(os.path.join(external_repo_dir, "modeling.py"), "w") as f:
            f.write(CUSTOM_MODEL_CODE)

        config = {
            "_class_name": "CustomModel",
            "_diffusers_version": "0.0.0",
            "auto_map": {"AutoModel": "modeling.CustomModel"},
            "hidden_size": 8,
        }
        with open(os.path.join(external_repo_dir, "config.json"), "w") as f:
            json.dump(config, f)

        # Empty state dict is enough — the test only exercises the loading path.
        torch.save({}, os.path.join(external_repo_dir, "diffusion_pytorch_model.bin"))

        # Step 2: Create a custom block that references the external repo.
        # Define both the class (for direct use) and its code string (for block.py).
        class ExternalRefBlock(ModularPipelineBlocks):
            @property
            def expected_components(self):
                return [
                    ComponentSpec(
                        "custom_model",
                        AutoModel,
                        pretrained_model_name_or_path=external_repo_dir,
                    )
                ]

            @property
            def inputs(self) -> List[InputParam]:
                return [InputParam("prompt", type_hint=str, required=True)]

            @property
            def intermediate_inputs(self) -> List[InputParam]:
                return []

            @property
            def intermediate_outputs(self) -> List[OutputParam]:
                return [OutputParam("output", type_hint=str)]

            def __call__(self, components, state: PipelineState) -> PipelineState:
                block_state = self.get_block_state(state)
                block_state.output = "test"
                self.set_block_state(state, block_state)
                return components, state

        # Same block as source text; the f-string line bakes the temp dir path in.
        EXTERNAL_REF_BLOCK_CODE_STR = (
            "from typing import List\n"
            "from diffusers import AutoModel\n"
            "from diffusers.modular_pipelines import (\n"
            "    ComponentSpec,\n"
            "    InputParam,\n"
            "    ModularPipelineBlocks,\n"
            "    OutputParam,\n"
            "    PipelineState,\n"
            ")\n"
            "\n"
            "class ExternalRefBlock(ModularPipelineBlocks):\n"
            "    @property\n"
            "    def expected_components(self):\n"
            "        return [\n"
            "            ComponentSpec(\n"
            '                "custom_model",\n'
            "                AutoModel,\n"
            f'                pretrained_model_name_or_path="{external_repo_dir}",\n'
            "            )\n"
            "        ]\n"
            "\n"
            "    @property\n"
            "    def inputs(self) -> List[InputParam]:\n"
            '        return [InputParam("prompt", type_hint=str, required=True)]\n'
            "\n"
            "    @property\n"
            "    def intermediate_inputs(self) -> List[InputParam]:\n"
            "        return []\n"
            "\n"
            "    @property\n"
            "    def intermediate_outputs(self) -> List[OutputParam]:\n"
            '        return [OutputParam("output", type_hint=str)]\n'
            "\n"
            "    def __call__(self, components, state: PipelineState) -> PipelineState:\n"
            "        block_state = self.get_block_state(state)\n"
            '        block_state.output = "test"\n'
            "        self.set_block_state(state, block_state)\n"
            "        return components, state\n"
        )

        # Save the block config, write block.py, then load back via from_pretrained
        block = ExternalRefBlock()
        block.save_pretrained(pipeline_repo_dir)

        # auto_map will reference the module name derived from ExternalRefBlock.__module__,
        # which is "test_modular_pipelines_custom_blocks". Write the code file with that name.
        code_path = os.path.join(pipeline_repo_dir, "test_modular_pipelines_custom_blocks.py")
        with open(code_path, "w") as f:
            f.write(EXTERNAL_REF_BLOCK_CODE_STR)

        block = ModularPipelineBlocks.from_pretrained(pipeline_repo_dir, trust_remote_code=True)
        pipe = block.init_pipeline()
        pipe.save_pretrained(pipeline_repo_dir)

        # Step 3: Load the pipeline from the saved directory.
        loaded_pipe = ModularPipeline.from_pretrained(pipeline_repo_dir, trust_remote_code=True)

        assert loaded_pipe._pretrained_model_name_or_path == pipeline_repo_dir
        assert loaded_pipe._component_specs["custom_model"].pretrained_model_name_or_path == external_repo_dir
        # The external custom component must not have been loaded yet.
        assert getattr(loaded_pipe, "custom_model", None) is None

        # Step 4a: load_components WITHOUT trust_remote_code.
        # It should still fail
        loaded_pipe.load_components()
        assert getattr(loaded_pipe, "custom_model", None) is None

        # Step 4b: load_components with trust_remote_code=True.
        # trust_remote_code should be stripped for the external component, so it fails.
        # The warning should contain guidance about manually loading with trust_remote_code.
        loaded_pipe.load_components(trust_remote_code=True)
        assert getattr(loaded_pipe, "custom_model", None) is None

        # Step 4c: Manually load with AutoModel and update_components — this should work.
        from diffusers import AutoModel

        custom_model = AutoModel.from_pretrained(external_repo_dir, trust_remote_code=True)
        loaded_pipe.update_components(custom_model=custom_model)
        assert getattr(loaded_pipe, "custom_model", None) is not None
|
||||
|
||||
def test_custom_block_loads_from_hub(self):
|
||||
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
|
||||
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, HeliosPipeline, HeliosScheduler, HeliosTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast CPU tests for HeliosPipeline using tiny randomly-initialized components."""

    pipeline_class = HeliosPipeline
    # Helios has no cross_attention_kwargs support, so drop it from the standard param set.
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        """Build the smallest component set that still exercises the full pipeline graph."""
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            # "temperal" matches the AutoencoderKLWan constructor's parameter spelling.
            temperal_downsample=[False, True, True],
        )

        # Re-seed before each stochastic constructor so component weights are reproducible
        # independently of construction order.
        torch.manual_seed(0)
        scheduler = HeliosScheduler(stage_range=[0, 1], stages=1, use_dynamic_shifting=True)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder = T5EncoderModel(config)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = HeliosTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_dim=(4, 4, 4),
            has_multi_term_memory_patch=True,
            guidance_cross_attn=True,
            zero_history_timestep=True,
            is_amplify_history=False,
        )

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Minimal call kwargs for a 2-step, 16x16, 9-frame generation."""
        # MPS does not support device-local generators; fall back to the global one.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "negative",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 16,
            "width": 16,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        """End-to-end smoke test on CPU with pinned output values."""
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]
        # 33 output frames from 9 latent frames — presumably due to the VAE's
        # temporal upsampling ((9 - 1) * 4 + 1); confirm against pipeline docs.
        self.assertEqual(generated_video.shape, (33, 3, 16, 16))

        # fmt: off
        expected_slice = torch.tensor([0.4529, 0.4527, 0.4499, 0.4542, 0.4528, 0.4524, 0.4531, 0.4534, 0.5328,
                                       0.5340, 0.5012, 0.5135, 0.5322, 0.5203, 0.5144, 0.5101])
        # fmt: on

        # Compare only the first and last 8 values to keep the fixture small.
        generated_slice = generated_video.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    # Override to set a more lenient max diff threshold.
    def test_save_load_float16(self):
        super().test_save_load_float16(expected_max_diff=0.03)

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass

    @unittest.skip("Optional components not applicable for Helios")
    def test_save_load_optional_components(self):
        pass
|
||||
|
||||
|
||||
@slow
@require_torch_accelerator
class HeliosPipelineIntegrationTests(unittest.TestCase):
    """Slow integration tests for HeliosPipeline; requires a torch accelerator."""

    prompt = "A painting of a squirrel eating a burger."

    @staticmethod
    def _reclaim_memory():
        # Collect Python garbage first, then release cached accelerator memory.
        gc.collect()
        backend_empty_cache(torch_device)

    def setUp(self):
        super().setUp()
        self._reclaim_memory()

    def tearDown(self):
        super().tearDown()
        self._reclaim_memory()

    @unittest.skip("TODO: test needs to be implemented")
    def test_helios(self):
        pass
|
||||
Reference in New Issue
Block a user