mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-20 03:14:43 +08:00
Compare commits
2 Commits
push-test-
...
custom-blo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b6dcf696a | ||
|
|
351b2f172a |
45
.github/workflows/push_tests.yml
vendored
45
.github/workflows/push_tests.yml
vendored
@@ -76,7 +76,6 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -128,7 +127,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -179,7 +178,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers -y && pip uninstall huggingface_hub -y && python -m uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -200,6 +198,47 @@ jobs:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
|
||||
run_xformers_tests:
|
||||
name: PyTorch xformers CUDA tests
|
||||
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
|
||||
|
||||
@@ -119,8 +119,6 @@
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
- local: modular_diffusers/custom_blocks
|
||||
title: Building Custom Blocks
|
||||
title: Modular Diffusers
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -375,8 +373,6 @@
|
||||
title: QwenImageTransformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sana_video_transformer3d
|
||||
title: SanaVideoTransformer3DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/skyreels_v2_transformer_3d
|
||||
@@ -533,6 +529,8 @@
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kandinsky5
|
||||
title: Kandinsky 5
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
@@ -567,8 +565,6 @@
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/sana_video
|
||||
title: Sana Video
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
@@ -642,8 +638,6 @@
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/kandinsky5_video
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# SanaVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import SanaVideoTransformer3DModel
|
||||
import torch
|
||||
|
||||
transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## SanaVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] SanaVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
|
||||
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kandinsky 5.0 Video
|
||||
# Kandinsky 5.0
|
||||
|
||||
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
|
||||
|
||||
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
|
||||
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
|
||||
|
||||
pipe.transformer.set_attention_backend(
|
||||
"flex"
|
||||
) # <--- Sett attention bakend to Flex
|
||||
) # <--- Set attention backend to Flex
|
||||
pipe.transformer.compile(
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
dynamic=True
|
||||
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### Diffusion Distilled model
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
|
||||
|
||||
```python
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
|
||||
@@ -24,6 +24,9 @@ The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaVideoPipeline
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
|
||||
|
||||
This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
model_score = 30
|
||||
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_prompt = f" motion score: {model_score}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=50
|
||||
).frames[0]
|
||||
export_to_video(output, "sana-video-output.mp4", fps=16)
|
||||
```
|
||||
|
||||
## SanaVideoPipeline
|
||||
|
||||
[[autodoc]] SanaVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
|
||||
@@ -1,493 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Building Custom Blocks
|
||||
|
||||
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
|
||||
|
||||
<Tip>
|
||||
You can find examples of different types of custom blocks in the [Modular Diffusers Custom Blocks collection](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks)
|
||||
</Tip>
|
||||
|
||||
## Project Structure
|
||||
|
||||
Your custom block project should use the following structure:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── block.py
|
||||
└── modular_config.json
|
||||
```
|
||||
|
||||
- `block.py` contains the custom block implementation
|
||||
- `modular_config.json` contains the metadata needed to load the block
|
||||
|
||||
## Example: Florence 2 Inpainting Block
|
||||
|
||||
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
|
||||
|
||||
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
|
||||
|
||||
```py
|
||||
# Inside block.py
|
||||
from diffusers.modular_pipelines import (
|
||||
ModularPipelineBlocks,
|
||||
ComponentSpec,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
|
||||
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Availabe options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
```
|
||||
|
||||
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
|
||||
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Availabe options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
def get_annotations(self, components, images, prompts, task):
|
||||
task_prompts = [task + prompt for prompt in prompts]
|
||||
|
||||
inputs = components.image_annotator_processor(
|
||||
text=task_prompts, images=images, return_tensors="pt"
|
||||
).to(components.image_annotator.device, components.image_annotator.dtype)
|
||||
|
||||
generated_ids = components.image_annotator.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=1024,
|
||||
early_stopping=False,
|
||||
do_sample=False,
|
||||
num_beams=3,
|
||||
)
|
||||
annotations = components.image_annotator_processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=False
|
||||
)
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
outputs.append(
|
||||
components.image_annotator_processor.post_process_generation(
|
||||
annotation, task=task, image_size=(image.width, image.height)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
|
||||
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
|
||||
masks = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
|
||||
for _, _annotation in annotation.items():
|
||||
if "polygons" in _annotation:
|
||||
for polygon in _annotation["polygons"]:
|
||||
polygon = np.array(polygon).reshape(-1, 2)
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygon = polygon.reshape(-1).tolist()
|
||||
draw.polygon(polygon, fill=fill)
|
||||
|
||||
elif "bbox" in _annotation:
|
||||
bbox = _annotation["bbox"]
|
||||
draw.rectangle(bbox, fill="white")
|
||||
|
||||
masks.append(mask_image)
|
||||
|
||||
return masks
|
||||
|
||||
def prepare_bounding_boxes(self, images, annotations):
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
image_copy = image.copy()
|
||||
draw = ImageDraw.Draw(image_copy)
|
||||
for _, _annotation in annotation.items():
|
||||
bbox = _annotation["bbox"]
|
||||
label = _annotation["label"]
|
||||
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
|
||||
|
||||
outputs.append(image_copy)
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs(self, images, prompts):
|
||||
prompts = prompts or ""
|
||||
|
||||
if isinstance(images, Image.Image):
|
||||
images = [images]
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if len(images) != len(prompts):
|
||||
raise ValueError("Number of images and annotation prompts must match.")
|
||||
|
||||
return images, prompts
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
images, annotation_task_prompt = self.prepare_inputs(
|
||||
block_state.image, block_state.annotation_prompt
|
||||
)
|
||||
task = block_state.annotation_task
|
||||
fill = block_state.fill
|
||||
|
||||
annotations = self.get_annotations(
|
||||
components, images, annotation_task_prompt, task
|
||||
)
|
||||
block_state.annotations = annotations
|
||||
if block_state.annotation_output_type == "mask_image":
|
||||
block_state.mask_image = self.prepare_mask(images, annotations)
|
||||
else:
|
||||
block_state.mask_image = None
|
||||
|
||||
if block_state.annotation_output_type == "mask_overlay":
|
||||
block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
|
||||
|
||||
elif block_state.annotation_output_type == "bounding_box":
|
||||
block_state.image = self.prepare_bounding_boxes(images, annotations)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
```
|
||||
|
||||
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
|
||||
|
||||
<hfoptions id="share">
|
||||
<hfoption id="hf CLI">
|
||||
|
||||
```shell
|
||||
# In the folder with the `block.py` file, run:
|
||||
diffusers-cli custom_block
|
||||
```
|
||||
|
||||
Then upload the block to the Hub:
|
||||
|
||||
```shell
|
||||
hf upload <your repo id> . .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="push_to_hub">
|
||||
|
||||
```py
|
||||
from block import Florence2ImageAnnotatorBlock
|
||||
block = Florence2ImageAnnotatorBlock()
|
||||
block.push_to_hub("<your repo id>")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Using Custom Blocks
|
||||
|
||||
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
|
||||
|
||||
my_blocks = INPAINT_BLOCKS.copy()
|
||||
# insert the annotation block before the image encoding step
|
||||
my_blocks.insert("image_annotator", image_annotator_block, 1)
|
||||
|
||||
# Create our initial set of inpainting blocks
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
|
||||
|
||||
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
|
||||
pipe = blocks.init_pipeline(repo_id)
|
||||
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
|
||||
image = image.resize((1024, 1024))
|
||||
|
||||
prompt = ["A red car"]
|
||||
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
|
||||
annotation_prompt = ["the car"]
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
annotation_task=annotation_task,
|
||||
annotation_prompt=annotation_prompt,
|
||||
annotation_output_type="mask_image",
|
||||
num_inference_steps=35,
|
||||
guidance_scale=7.5,
|
||||
strength=0.95,
|
||||
output="images"
|
||||
)
|
||||
output[0].save("florence-inpainting.png")
|
||||
```
|
||||
|
||||
## Editing Custom Blocks
|
||||
|
||||
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
|
||||
```
|
||||
|
||||
Any changes made to the block files in this folder will be reflected when you load the block again.
|
||||
@@ -104,8 +104,6 @@ To use your own dataset, there are 2 ways:
|
||||
- you can either provide your own folder as `--train_data_dir`
|
||||
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
|
||||
|
||||
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
|
||||
|
||||
Below, we explain both in more detail.
|
||||
|
||||
#### Provide the dataset as a folder
|
||||
|
||||
@@ -52,24 +52,6 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
|
||||
"""
|
||||
if tensor.ndim == 2:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
channels = tensor.shape[0]
|
||||
if channels == 3:
|
||||
return tensor
|
||||
if channels == 1:
|
||||
return tensor.repeat(3, 1, 1)
|
||||
if channels == 2:
|
||||
return torch.cat([tensor, tensor[:1]], dim=0)
|
||||
if channels > 3:
|
||||
return tensor[:3]
|
||||
raise ValueError(f"Unsupported number of channels: {channels}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -278,11 +260,6 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preserve_input_precision",
|
||||
action="store_true",
|
||||
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -476,41 +453,19 @@ def main(args):
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets and DataLoaders creation.
|
||||
spatial_augmentations = [
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
]
|
||||
|
||||
augmentations = transforms.Compose(
|
||||
spatial_augmentations
|
||||
+ [
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
precision_augmentations = transforms.Compose(
|
||||
[
|
||||
transforms.PILToTensor(),
|
||||
transforms.Lambda(_ensure_three_channels),
|
||||
transforms.ConvertImageDtype(torch.float32),
|
||||
]
|
||||
+ spatial_augmentations
|
||||
+ [transforms.Normalize([0.5], [0.5])]
|
||||
)
|
||||
|
||||
def transform_images(examples):
|
||||
processed = []
|
||||
for image in examples["image"]:
|
||||
if not args.preserve_input_precision:
|
||||
processed.append(augmentations(image.convert("RGB")))
|
||||
else:
|
||||
precise_image = image
|
||||
if precise_image.mode == "P":
|
||||
precise_image = precise_image.convert("RGB")
|
||||
processed.append(precision_augmentations(precise_image))
|
||||
return {"input": processed}
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaVideoPipeline,
|
||||
SanaVideoTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
|
||||
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
|
||||
|
||||
|
||||
def main(args):
|
||||
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
|
||||
|
||||
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
|
||||
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
|
||||
snapshot_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
file_path = hf_hub_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
else:
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
|
||||
all_state_dict = torch.load(file_path, weights_only=True)
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 8.0
|
||||
|
||||
# model config
|
||||
layer_num = 20
|
||||
# Positional embedding interpolation scale.
|
||||
qk_norm = True
|
||||
|
||||
# sample size
|
||||
if args.video_size == 480:
|
||||
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
|
||||
patch_size = (1, 2, 2)
|
||||
elif args.video_size == 720:
|
||||
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
|
||||
patch_size = (1, 1, 1)
|
||||
else:
|
||||
raise ValueError(f"Video size {args.video_size} is not supported.")
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"blocks.{depth}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Linear Attention is all you need 🤘
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.t_conv.weight"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# Final block.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer_kwargs = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 20,
|
||||
"attention_head_dim": 112,
|
||||
"num_layers": 20,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 3.0,
|
||||
"attention_bias": False,
|
||||
"sample_size": sample_size,
|
||||
"patch_size": patch_size,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 1024,
|
||||
}
|
||||
|
||||
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
|
||||
|
||||
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
|
||||
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
transformer = transformer.to(weight_dtype)
|
||||
|
||||
if not args.save_full_pipeline:
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
elif args.scheduler_type == "uni-pc":
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction",
|
||||
use_flow_sigmas=True,
|
||||
num_train_timesteps=1000,
|
||||
flow_shift=flow_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaVideoPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_size",
|
||||
default=480,
|
||||
type=int,
|
||||
choices=[480, 720],
|
||||
required=False,
|
||||
help="Video size of pretrained model, 480 or 720.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="SanaVideo",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaVideo",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
@@ -246,7 +246,6 @@ else:
|
||||
"QwenImageTransformer2DModel",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SanaVideoTransformer3DModel",
|
||||
"SD3ControlNetModel",
|
||||
"SD3MultiControlNetModel",
|
||||
"SD3Transformer2DModel",
|
||||
@@ -545,7 +544,6 @@ else:
|
||||
"SanaPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -953,7 +951,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageTransformer2DModel,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SD3Transformer2DModel,
|
||||
@@ -1222,7 +1219,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaVideoPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
|
||||
@@ -1045,39 +1045,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
Args:
|
||||
image (`Union[np.ndarray, torch.Tensor]`):
|
||||
The RGB-like depth image to convert.
|
||||
|
||||
Returns:
|
||||
`Union[np.ndarray, torch.Tensor]`:
|
||||
The corresponding depth map.
|
||||
"""
|
||||
# 1. Cast the tensor to a larger integer type (e.g., int32)
|
||||
# to safely perform the multiplication by 256.
|
||||
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
|
||||
# 3. Cast the final result to the desired depth map type (uint16) if needed
|
||||
# before returning, though leaving it as int32/int64 is often safer
|
||||
# for return value from a library function.
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.to(torch.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# You may want to cast the final result to uint16, but casting to a
|
||||
# larger int type (like int32) is sufficient to fix the overflow.
|
||||
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.to(original_dtype)
|
||||
|
||||
elif isinstance(image, np.ndarray):
|
||||
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.astype(np.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.astype(original_dtype)
|
||||
else:
|
||||
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
|
||||
@@ -2213,10 +2213,6 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
state_dict = {convert_key(k): v for k, v in state_dict.items()}
|
||||
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
|
||||
@@ -4940,8 +4940,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
||||
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
|
||||
@@ -102,7 +102,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -205,7 +204,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PRXTransformer2DModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3Transformer2DModel,
|
||||
SkyReelsV2Transformer3DModel,
|
||||
StableAudioDiTModel,
|
||||
|
||||
@@ -649,86 +649,6 @@ def _(
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
def _native_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
# Native attention does not return_lse
|
||||
if return_lse:
|
||||
raise ValueError("Native attention does not support return_lse=True")
|
||||
|
||||
# used for backward pass
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.attn_mask = attn_mask
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.is_causal = is_causal
|
||||
ctx.scale = scale
|
||||
ctx.enable_gqa = enable_gqa
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _native_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
|
||||
query.requires_grad_(True)
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
grad_value = grad_value_t.permute(0, 2, 1, 3)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
|
||||
# forward declaration:
|
||||
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
@@ -1603,7 +1523,6 @@ def _native_flex_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -1619,35 +1538,18 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op=_native_attention_forward_op,
|
||||
backward_op=_native_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -147,13 +147,14 @@ class AutoModel(ConfigMixin):
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
|
||||
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
|
||||
|
||||
library = None
|
||||
orig_class_name = None
|
||||
@@ -204,6 +205,7 @@ class AutoModel(ConfigMixin):
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
@@ -36,7 +36,6 @@ if is_torch_available():
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_prx import PRXTransformer2DModel
|
||||
from .transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from .transformer_sana_video import SanaVideoTransformer3DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
@@ -1,703 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GLUMBTempConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: Optional[str] = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
self.conv_temp = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
batch_size, num_frames, height, width, num_channels = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
# Temporal aggregation
|
||||
hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
|
||||
0, 2, 1, 3
|
||||
)
|
||||
hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaLinearAttnProcessor3_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product linear attention.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
# B,N,H,C
|
||||
|
||||
query = F.relu(query)
|
||||
key = F.relu(key)
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cos: torch.Tensor,
|
||||
freqs_sin: torch.Tensor,
|
||||
):
|
||||
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos = freqs_cos[..., 0::2]
|
||||
sin = freqs_sin[..., 1::2]
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x1 * cos - x2 * sin
|
||||
out[..., 1::2] = x1 * sin + x2 * cos
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
query_rotate = apply_rotary_emb(query, *rotary_emb)
|
||||
key_rotate = apply_rotary_emb(key, *rotary_emb)
|
||||
|
||||
# B,H,C,N
|
||||
query = query.permute(0, 2, 3, 1)
|
||||
key = key.permute(0, 2, 3, 1)
|
||||
query_rotate = query_rotate.permute(0, 2, 3, 1)
|
||||
key_rotate = key_rotate.permute(0, 2, 3, 1)
|
||||
value = value.permute(0, 2, 3, 1)
|
||||
|
||||
query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
|
||||
|
||||
z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15)
|
||||
|
||||
scores = torch.matmul(value, key_rotate.transpose(-1, -2))
|
||||
hidden_states = torch.matmul(scores, query_rotate)
|
||||
|
||||
hidden_states = hidden_states * z
|
||||
# B,H,C,N
|
||||
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int,
|
||||
theta: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq_cos, freq_sin = get_1d_rotary_pos_embed(
|
||||
dim,
|
||||
max_seq_len,
|
||||
theta,
|
||||
use_real=True,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
freqs_cos.append(freq_cos)
|
||||
freqs_sin.append(freq_sin)
|
||||
|
||||
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
|
||||
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
|
||||
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
|
||||
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
|
||||
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
|
||||
class SanaModulatedNorm(nn.Module):
|
||||
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
guidance_proj = self.guidance_condition_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
|
||||
conditioning = timesteps_emb + guidance_emb
|
||||
|
||||
return self.linear(self.silu(conditioning)), conditioning
|
||||
|
||||
|
||||
class SanaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block introduced in [Sana-Video](https://huggingface.co/papers/2509.24695).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2240,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
dropout: float = 0.0,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
attention_bias: bool = True,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 3.0,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
kv_heads=num_attention_heads if qk_norm is not None else None,
|
||||
qk_norm=qk_norm,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
processor=SanaLinearAttnProcessor3_0(),
|
||||
)
|
||||
|
||||
# 2. Cross Attention
|
||||
if cross_attention_dim is not None:
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
qk_norm=qk_norm,
|
||||
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=SanaAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = GLUMBTempConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
frames: int = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# 1. Modulation
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 2. Self Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + gate_msa * attn_output
|
||||
|
||||
# 3. Cross Attention
|
||||
if self.attn2 is not None:
|
||||
attn_output = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = ff_output.flatten(1, 3)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
|
||||
r"""
|
||||
A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `20`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `112`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
|
||||
The number of heads to use for cross-attention.
|
||||
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
|
||||
The number of channels in each head for cross-attention.
|
||||
cross_attention_dim (`int`, *optional*, defaults to `2240`):
|
||||
The number of channels in the cross-attention output.
|
||||
caption_channels (`int`, defaults to `2304`):
|
||||
The number of channels in the caption embeddings.
|
||||
mlp_ratio (`float`, defaults to `2.5`):
|
||||
The expansion ratio to use in the GLUMBConv layer.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether to use bias in the attention layer.
|
||||
sample_size (`int`, defaults to `32`):
|
||||
The base size of the input latent.
|
||||
patch_size (`int`, defaults to `1`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
norm_elementwise_affine (`bool`, defaults to `False`):
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for the query and key.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SanaVideoTransformerBlock", "SanaModulatedNorm"]
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_attention_heads: int = 20,
|
||||
attention_head_dim: int = 112,
|
||||
num_layers: int = 20,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
caption_channels: int = 2304,
|
||||
mlp_ratio: float = 2.5,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = False,
|
||||
sample_size: int = 30,
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2),
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
if guidance_embeds:
|
||||
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
|
||||
else:
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
SanaVideoTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
num_cross_attention_heads=num_cross_attention_heads,
|
||||
cross_attention_head_dim=cross_attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output blocks
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
guidance: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
if guidance is not None:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
|
||||
|
||||
# 2. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
rotary_emb,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
else:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
rotary_emb,
|
||||
)
|
||||
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
|
||||
|
||||
# 3. Normalization
|
||||
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -164,11 +164,7 @@ class AutoOffloadStrategy:
|
||||
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
try:
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
|
||||
|
||||
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
|
||||
mem_on_device = mem_on_device - self.memory_reserve_margin
|
||||
if current_module_size < mem_on_device:
|
||||
return []
|
||||
@@ -703,8 +699,6 @@ class ComponentsManager:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
|
||||
|
||||
# TODO: add a warning if mem_get_info isn't available on `device`.
|
||||
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
remove_hook_from_module(component, recurse=True)
|
||||
|
||||
@@ -598,7 +598,7 @@ class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
|
||||
and getattr(block_state, "image_width", None) is not None
|
||||
):
|
||||
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
|
||||
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
|
||||
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
img_ids = FluxPipeline._prepare_latent_image_ids(
|
||||
None, image_latent_height // 2, image_latent_width // 2, device, dtype
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=False,
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
@@ -141,7 +141,7 @@ class FluxKontextLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=False,
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
|
||||
@@ -95,7 +95,7 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
@@ -143,6 +143,10 @@ class FluxProcessImagesInputStep(ModularPipelineBlocks):
|
||||
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
def __init__(self, _auto_resize=True):
|
||||
self._auto_resize = _auto_resize
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
@@ -163,7 +167,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
|
||||
return [InputParam("image")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
@@ -191,8 +195,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
img = images[0]
|
||||
image_height, image_width = components.image_processor.get_default_height_width(img)
|
||||
aspect_ratio = image_width / image_height
|
||||
_auto_resize = block_state._auto_resize
|
||||
if _auto_resize:
|
||||
if self._auto_resize:
|
||||
# Kontext is trained on specific resolutions, using one of them is recommended
|
||||
_, image_width, image_height = min(
|
||||
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
|
||||
|
||||
@@ -112,10 +112,6 @@ class FluxTextInputStep(ModularPipelineBlocks):
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
|
||||
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_images_per_prompt, -1
|
||||
)
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
@@ -307,13 +307,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"local_files_only",
|
||||
"local_dir",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
|
||||
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
|
||||
config = cls.load_config(pretrained_model_name_or_path)
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
@@ -2130,13 +2131,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
component_load_kwargs[key] = value["default"]
|
||||
try:
|
||||
components_to_register[name] = spec.load(**component_load_kwargs)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create component '{name}': {e}")
|
||||
|
||||
# Register all components at once
|
||||
self.register_components(**components_to_register)
|
||||
|
||||
@@ -308,7 +308,6 @@ else:
|
||||
"SanaSprintPipeline",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaVideoPipeline",
|
||||
]
|
||||
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
|
||||
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
|
||||
@@ -736,13 +735,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImagePipeline,
|
||||
)
|
||||
from .sana import (
|
||||
SanaControlNetPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaVideoPipeline,
|
||||
)
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
|
||||
|
||||
@@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline(
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -26,7 +26,6 @@ else:
|
||||
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
|
||||
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
|
||||
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
|
||||
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -40,7 +39,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_sana_controlnet import SanaControlNetPipeline
|
||||
from .pipeline_sana_sprint import SanaSprintPipeline
|
||||
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
|
||||
from .pipeline_sana_video import SanaVideoPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
@@ -20,18 +19,3 @@ class SanaPipelineOutput(BaseOutput):
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SanaVideoPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Sana-Video pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
@@ -1308,21 +1308,6 @@ class SanaTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SanaVideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SD3ControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2177,21 +2177,6 @@ class SanaSprintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class SanaVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -358,7 +358,6 @@ def get_cached_module_file(
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
local_dir=local_dir,
|
||||
revision=revision,
|
||||
token=token,
|
||||
)
|
||||
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
|
||||
|
||||
@@ -13,12 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
|
||||
|
||||
@@ -112,65 +111,3 @@ class VideoProcessor(VaeImageProcessor):
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the binned height and width based on the aspect ratio.
|
||||
|
||||
Args:
|
||||
height (`int`): The height of the image.
|
||||
width (`int`): The width of the image.
|
||||
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The closest binned height and width.
|
||||
"""
|
||||
ar = float(height / width)
|
||||
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
||||
default_hw = ratios[closest_ratio]
|
||||
return int(default_hw[0]), int(default_hw[1])
|
||||
|
||||
@staticmethod
|
||||
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
|
||||
r"""
|
||||
Resizes and crops a tensor of videos to the specified dimensions.
|
||||
|
||||
Args:
|
||||
samples (`torch.Tensor`):
|
||||
A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the
|
||||
number of frames, H is the height, and W is the width.
|
||||
new_width (`int`): The desired width of the output videos.
|
||||
new_height (`int`): The desired height of the output videos.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A tensor containing the resized and cropped videos.
|
||||
"""
|
||||
orig_height, orig_width = samples.shape[3], samples.shape[4]
|
||||
|
||||
# Check if resizing is needed
|
||||
if orig_height != new_height or orig_width != new_width:
|
||||
ratio = max(new_height / orig_height, new_width / orig_width)
|
||||
resized_width = int(orig_width * ratio)
|
||||
resized_height = int(orig_height * ratio)
|
||||
|
||||
# Reshape to (N*T, C, H, W) for interpolation
|
||||
n, c, t, h, w = samples.shape
|
||||
samples = samples.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w)
|
||||
|
||||
# Resize
|
||||
samples = F.interpolate(
|
||||
samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
# Center Crop
|
||||
start_x = (resized_width - new_width) // 2
|
||||
end_x = start_x + new_width
|
||||
start_y = (resized_height - new_height) // 2
|
||||
end_y = start_y + new_height
|
||||
samples = samples[:, :, start_y:end_y, start_x:end_x]
|
||||
|
||||
# Reshape back to (N, C, T, H, W)
|
||||
samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return samples
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import SanaVideoTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SanaVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 2, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"num_layers": 2,
|
||||
"num_cross_attention_heads": 2,
|
||||
"cross_attention_head_dim": 12,
|
||||
"cross_attention_dim": 24,
|
||||
"caption_channels": 16,
|
||||
"mlp_ratio": 2.5,
|
||||
"dropout": 0.0,
|
||||
"attention_bias": False,
|
||||
"sample_size": 8,
|
||||
"patch_size": (1, 2, 2),
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"SanaVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = SanaVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
@@ -1,172 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.modular_pipelines import (
|
||||
FluxAutoBlocks,
|
||||
FluxKontextAutoBlocks,
|
||||
FluxKontextModularPipeline,
|
||||
FluxModularPipeline,
|
||||
ModularPipeline,
|
||||
)
|
||||
|
||||
from ...testing_utils import floats_tensor, torch_device
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
|
||||
class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = super().get_pipeline(components_manager, torch_dtype)
|
||||
|
||||
# Override `vae_scale_factor` here as currently, `image_processor` is initialized with
|
||||
# fixed constants instead of
|
||||
# https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
|
||||
pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "pt",
|
||||
}
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
|
||||
inputs["image"] = init_image
|
||||
inputs["strength"] = 0.5
|
||||
|
||||
return inputs
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
|
||||
class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 48,
|
||||
"output_type": "pt",
|
||||
}
|
||||
image = PIL.Image.new("RGB", (32, 32), 0)
|
||||
|
||||
inputs["image"] = image
|
||||
inputs["max_area"] = inputs["height"] * inputs["width"]
|
||||
inputs["_auto_resize"] = False
|
||||
|
||||
return inputs
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
base_pipe = self.get_pipeline().to(torch_device)
|
||||
pipes.append(base_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
base_pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe.to(torch_device)
|
||||
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
|
||||
|
||||
pipes.append(pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
@@ -14,43 +14,93 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from diffusers import (
|
||||
ClassifierFreeGuidance,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from diffusers.loaders import ModularIPAdapterMixin
|
||||
|
||||
from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
from ...models.unets.test_models_unet_2d_condition import (
|
||||
create_ip_adapter_state_dict,
|
||||
)
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modular_pipelines_common import (
|
||||
ModularPipelineTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SDXLModularTesterMixin:
|
||||
class SDXLModularTests:
|
||||
"""
|
||||
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
|
||||
"""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
"mask_image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
sd_pipe = self.get_pipeline()
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs, output="images")
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == expected_image_shape
|
||||
max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
|
||||
assert max_diff < expected_max_diff, f"Image slice does not match expected slice. Max Difference: {max_diff}"
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
|
||||
"Image Slice does not match expected slice"
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularIPAdapterTesterMixin:
|
||||
class SDXLModularIPAdapterTests:
|
||||
"""
|
||||
This mixin is designed to test IP Adapter.
|
||||
"""
|
||||
@@ -89,7 +139,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
if "image" in parameters and "strength" in parameters:
|
||||
inputs["num_inference_steps"] = 4
|
||||
|
||||
inputs["output_type"] = "pt"
|
||||
inputs["output_type"] = "np"
|
||||
return inputs
|
||||
|
||||
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
|
||||
@@ -114,7 +164,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
|
||||
|
||||
# forward pass without ip adapter
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
if expected_pipe_slice is None:
|
||||
output_without_adapter = pipe(**inputs, output="images")
|
||||
else:
|
||||
@@ -125,7 +175,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
|
||||
|
||||
# forward pass with single ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(0.0)
|
||||
@@ -134,7 +184,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
|
||||
pipe.set_ip_adapter_scale(42.0)
|
||||
@@ -142,8 +192,8 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_adapter_scale = torch.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = torch.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
|
||||
|
||||
assert max_diff_without_adapter_scale < expected_max_diff, (
|
||||
"Output without ip-adapter must be same as normal inference"
|
||||
@@ -156,7 +206,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
|
||||
|
||||
# forward pass with multi ip adapter, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([0.0, 0.0])
|
||||
@@ -165,7 +215,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with multi ip adapter, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
|
||||
pipe.set_ip_adapter_scale([42.0, 42.0])
|
||||
@@ -173,10 +223,10 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
if expected_pipe_slice is not None:
|
||||
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_multi_adapter_scale = torch.abs(
|
||||
max_diff_without_multi_adapter_scale = np.abs(
|
||||
output_without_multi_adapter_scale - output_without_adapter
|
||||
).max()
|
||||
max_diff_with_multi_adapter_scale = torch.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
|
||||
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
|
||||
"Output without multi-ip-adapter must be same as normal inference"
|
||||
)
|
||||
@@ -185,7 +235,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
)
|
||||
|
||||
|
||||
class SDXLModularControlNetTesterMixin:
|
||||
class SDXLModularControlNetTests:
|
||||
"""
|
||||
This mixin is designed to test ControlNet.
|
||||
"""
|
||||
@@ -224,26 +274,24 @@ class SDXLModularControlNetTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward pass without controlnet
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_controlnet = pipe(**inputs, output="images")
|
||||
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but scale=0 which should have no effect
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 0.0
|
||||
output_without_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# forward pass with single controlnet, but with scale of adapter weights
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
inputs["controlnet_conditioning_scale"] = 42.0
|
||||
output_with_controlnet_scale = pipe(**inputs, output="images")
|
||||
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
|
||||
|
||||
max_diff_without_controlnet_scale = torch.abs(
|
||||
output_without_controlnet_scale - output_without_controlnet
|
||||
).max()
|
||||
max_diff_with_controlnet_scale = torch.abs(output_with_controlnet_scale - output_without_controlnet).max()
|
||||
max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
|
||||
max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
|
||||
|
||||
assert max_diff_without_controlnet_scale < expected_max_diff, (
|
||||
"Output without controlnet must be same as normal inference"
|
||||
@@ -259,21 +307,21 @@ class SDXLModularControlNetTesterMixin:
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
|
||||
inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
max_diff = torch.abs(out_cfg - out_no_cfg).max()
|
||||
max_diff = np.abs(out_cfg - out_no_cfg).max()
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class SDXLModularGuiderTesterMixin:
|
||||
class SDXLModularGuiderTests:
|
||||
def test_guider_cfg(self):
|
||||
pipe = self.get_pipeline()
|
||||
pipe = pipe.to(torch_device)
|
||||
@@ -283,13 +331,13 @@ class SDXLModularGuiderTesterMixin:
|
||||
guider = ClassifierFreeGuidance(guidance_scale=1.0)
|
||||
pipe.update_components(guider=guider)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_no_cfg = pipe(**inputs, output="images")
|
||||
|
||||
# forward pass with CFG applied
|
||||
guider = ClassifierFreeGuidance(guidance_scale=7.5)
|
||||
pipe.update_components(guider=guider)
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
out_cfg = pipe(**inputs, output="images")
|
||||
|
||||
assert out_cfg.shape == out_no_cfg.shape
|
||||
@@ -297,57 +345,30 @@ class SDXLModularGuiderTesterMixin:
|
||||
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestSDXLModularPipelineFast(
|
||||
SDXLModularTesterMixin,
|
||||
SDXLModularIPAdapterTesterMixin,
|
||||
SDXLModularControlNetTesterMixin,
|
||||
SDXLModularGuiderTesterMixin,
|
||||
class SDXLModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=self.expected_image_output_shape,
|
||||
expected_slice=torch.tensor(
|
||||
[
|
||||
0.5966781,
|
||||
0.62939394,
|
||||
0.48465094,
|
||||
0.51573336,
|
||||
0.57593524,
|
||||
0.47035995,
|
||||
0.53410417,
|
||||
0.51436996,
|
||||
0.47313565,
|
||||
],
|
||||
device=torch_device,
|
||||
),
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.5966781,
|
||||
0.62939394,
|
||||
0.48465094,
|
||||
0.51573336,
|
||||
0.57593524,
|
||||
0.47035995,
|
||||
0.53410417,
|
||||
0.51436996,
|
||||
0.47313565,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
@@ -355,65 +376,39 @@ class TestSDXLModularPipelineFast(
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
class TestSDXLImg2ImgModularPipelineFast(
|
||||
SDXLModularTesterMixin,
|
||||
SDXLModularIPAdapterTesterMixin,
|
||||
SDXLModularControlNetTesterMixin,
|
||||
SDXLModularGuiderTesterMixin,
|
||||
class SDXLImg2ImgModularPipelineFastTests(
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 4,
|
||||
"output_type": "pt",
|
||||
}
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
inputs["image"] = init_image
|
||||
inputs["strength"] = 0.5
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
inputs["image"] = image
|
||||
inputs["strength"] = 0.8
|
||||
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=self.expected_image_output_shape,
|
||||
expected_slice=torch.tensor(
|
||||
[
|
||||
0.56943184,
|
||||
0.4702148,
|
||||
0.48048905,
|
||||
0.6235963,
|
||||
0.551138,
|
||||
0.49629188,
|
||||
0.60031277,
|
||||
0.5688907,
|
||||
0.43996853,
|
||||
],
|
||||
device=torch_device,
|
||||
),
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.56943184,
|
||||
0.4702148,
|
||||
0.48048905,
|
||||
0.6235963,
|
||||
0.551138,
|
||||
0.49629188,
|
||||
0.60031277,
|
||||
0.5688907,
|
||||
0.43996853,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
@@ -422,43 +417,20 @@ class TestSDXLImg2ImgModularPipelineFast(
|
||||
|
||||
|
||||
class SDXLInpaintingModularPipelineFastTests(
|
||||
SDXLModularTesterMixin,
|
||||
SDXLModularIPAdapterTesterMixin,
|
||||
SDXLModularControlNetTesterMixin,
|
||||
SDXLModularGuiderTesterMixin,
|
||||
SDXLModularTests,
|
||||
SDXLModularIPAdapterTests,
|
||||
SDXLModularControlNetTests,
|
||||
SDXLModularGuiderTests,
|
||||
ModularPipelineTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt",
|
||||
"cross_attention_kwargs",
|
||||
"image",
|
||||
"mask_image",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 4,
|
||||
"output_type": "pt",
|
||||
}
|
||||
inputs = super().get_dummy_inputs(device, seed)
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
# create mask
|
||||
image[8:, 8:, :] = 255
|
||||
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
|
||||
@@ -471,21 +443,18 @@ class SDXLInpaintingModularPipelineFastTests(
|
||||
|
||||
def test_stable_diffusion_xl_euler(self):
|
||||
self._test_stable_diffusion_xl_euler(
|
||||
expected_image_shape=self.expected_image_output_shape,
|
||||
expected_slice=torch.tensor(
|
||||
[
|
||||
0.40872607,
|
||||
0.38842705,
|
||||
0.34893104,
|
||||
0.47837183,
|
||||
0.43792963,
|
||||
0.5332134,
|
||||
0.3716843,
|
||||
0.47274873,
|
||||
0.45000193,
|
||||
],
|
||||
device=torch_device,
|
||||
),
|
||||
expected_image_shape=(1, 64, 64, 3),
|
||||
expected_slice=[
|
||||
0.40872607,
|
||||
0.38842705,
|
||||
0.34893104,
|
||||
0.47837183,
|
||||
0.43792963,
|
||||
0.5332134,
|
||||
0.3716843,
|
||||
0.47274873,
|
||||
0.45000193,
|
||||
],
|
||||
expected_max_diff=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
@@ -17,9 +19,17 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
This mixin is designed to be used with unittest.TestCase classes.
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
including:
|
||||
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
|
||||
@@ -47,8 +57,9 @@ class ModularPipelineTesterMixin:
|
||||
]
|
||||
)
|
||||
|
||||
def get_generator(self, seed=0):
|
||||
generator = torch.Generator("cpu").manual_seed(seed)
|
||||
def get_generator(self, seed):
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device).manual_seed(seed)
|
||||
return generator
|
||||
|
||||
@property
|
||||
@@ -71,7 +82,13 @@ class ModularPipelineTesterMixin:
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
def get_pipeline(self):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_pipeline(self)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
raise NotImplementedError(
|
||||
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
|
||||
"See existing pipeline tests for reference."
|
||||
@@ -106,23 +123,20 @@ class ModularPipelineTesterMixin:
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setup_method(self):
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
super().setUp()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline.load_components(torch_dtype=torch_dtype)
|
||||
return pipeline
|
||||
|
||||
def test_pipeline_call_signature(self):
|
||||
pipe = self.get_pipeline()
|
||||
input_parameters = pipe.blocks.input_names
|
||||
@@ -142,7 +156,7 @@ class ModularPipelineTesterMixin:
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
@@ -182,7 +196,7 @@ class ModularPipelineTesterMixin:
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
@@ -212,9 +226,10 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert output_batch.shape[0] == batch_size
|
||||
|
||||
max_diff = torch.abs(output_batch[0] - output[0]).max()
|
||||
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
|
||||
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
|
||||
|
||||
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
|
||||
@require_accelerator
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
pipe = self.get_pipeline()
|
||||
@@ -225,13 +240,13 @@ class ModularPipelineTesterMixin:
|
||||
pipe_fp16.to(torch_device, torch.float16)
|
||||
pipe_fp16.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in inputs:
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
output = pipe(**inputs, output="images")
|
||||
|
||||
fp16_inputs = self.get_dummy_inputs()
|
||||
fp16_inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is used inside dummy inputs
|
||||
if "generator" in fp16_inputs:
|
||||
fp16_inputs["generator"] = self.get_generator(0)
|
||||
@@ -268,8 +283,8 @@ class ModularPipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to("cpu")
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(), output="images")
|
||||
assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
|
||||
output = pipe(**self.get_dummy_inputs("cpu"), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
|
||||
|
||||
@require_accelerator
|
||||
def test_inference_is_not_nan(self):
|
||||
@@ -277,8 +292,8 @@ class ModularPipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(), output="images")
|
||||
assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
|
||||
output = pipe(**self.get_dummy_inputs(torch_device), output="images")
|
||||
assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
pipe = self.get_pipeline()
|
||||
@@ -294,7 +309,7 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
@@ -314,12 +329,12 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
image_slices = []
|
||||
for pipe in [base_pipe, offload_pipe]:
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_save_from_pretrained(self):
|
||||
pipes = []
|
||||
@@ -336,9 +351,9 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs, output="images")
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
@@ -1,225 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
|
||||
|
||||
from diffusers import AutoencoderKLWan, DPMSolverMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SanaVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = SanaVideoPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = DPMSolverMultistepScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = Gemma2Config(
|
||||
head_dim=16,
|
||||
hidden_size=8,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=64,
|
||||
max_position_embeddings=8192,
|
||||
model_type="gemma2",
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=1,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=8,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
text_encoder = Gemma2Model(text_encoder_config)
|
||||
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = SanaVideoTransformer3DModel(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
num_layers=2,
|
||||
num_cross_attention_heads=2,
|
||||
cross_attention_head_dim=12,
|
||||
cross_attention_dim=24,
|
||||
caption_channels=8,
|
||||
mlp_ratio=2.5,
|
||||
dropout=0.0,
|
||||
attention_bias=False,
|
||||
sample_size=8,
|
||||
patch_size=(1, 2, 2),
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
"complex_human_instruction": [],
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_save_load_local(self, expected_max_difference=5e-4):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
# TODO(aryan): Create a dummy gemma model with smol vocab size
|
||||
@unittest.skip(
|
||||
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
|
||||
)
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
|
||||
)
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
def test_float16_inference(self):
|
||||
# Requires higher tolerance as model seems very sensitive to dtype
|
||||
super().test_float16_inference(expected_max_diff=0.08)
|
||||
|
||||
def test_save_load_float16(self):
|
||||
# Requires higher tolerance as model seems very sensitive to dtype
|
||||
super().test_save_load_float16(expected_max_diff=0.2)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class SanaVideoPipelineIntegrationTests(unittest.TestCase):
|
||||
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@unittest.skip("TODO: test needs to be implemented")
|
||||
def test_sana_video_480p(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user