mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-24 21:34:55 +08:00
Compare commits
17 Commits
add-uv-scr
...
deal-audio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d095f3b85 | ||
|
|
5937e11d85 | ||
|
|
9c1d4e3be1 | ||
|
|
7ea065c507 | ||
|
|
7a7a487396 | ||
|
|
4efb4db9d0 | ||
|
|
639fd12a20 | ||
|
|
69a9828f4d | ||
|
|
11d22e0e80 | ||
|
|
9a38fab5ae | ||
|
|
cb8e61ed2f | ||
|
|
8e53cd959e | ||
|
|
359b605f4b | ||
|
|
6febc08bfc | ||
|
|
9a2eaed002 | ||
|
|
0c71189abe | ||
|
|
58d2b10a2e |
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
group: aws-g6e-4xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
14
.github/workflows/nightly_tests.yml
vendored
14
.github/workflows/nightly_tests.yml
vendored
@@ -61,7 +61,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -178,7 +178,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -222,7 +222,7 @@ jobs:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -270,7 +270,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-minimum-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -344,7 +344,7 @@ jobs:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "20gb" --ipc host --gpus 0
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -396,7 +396,7 @@ jobs:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "20gb" --ipc host --gpus 0
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
6
.github/workflows/pr_tests_gpu.yml
vendored
6
.github/workflows/pr_tests_gpu.yml
vendored
@@ -118,7 +118,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -183,7 +183,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -253,7 +253,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
10
.github/workflows/push_tests.yml
vendored
10
.github/workflows/push_tests.yml
vendored
@@ -64,7 +64,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -210,7 +210,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -252,7 +252,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
|
||||
12
.github/workflows/release_tests_fast.yml
vendored
12
.github/workflows/release_tests_fast.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -163,7 +163,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-minimum-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -222,7 +222,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -265,7 +265,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -307,7 +307,7 @@ jobs:
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
|
||||
2
.github/workflows/run_tests_from_a_pr.yml
vendored
2
.github/workflows/run_tests_from_a_pr.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
|
||||
steps:
|
||||
- name: Validate test files input
|
||||
|
||||
2
.github/workflows/ssh-runner.yml
vendored
2
.github/workflows/ssh-runner.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
group: "${{ github.event.inputs.runner_type }}"
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0 --privileged
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus all --privileged
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
|
||||
@@ -366,6 +366,8 @@
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/qwenimage_transformer2d
|
||||
title: QwenImageTransformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
@@ -418,6 +420,8 @@
|
||||
title: AutoencoderKLMagvit
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/autoencoderkl_qwenimage
|
||||
title: AutoencoderKLQwenImage
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/consistency_decoder_vae
|
||||
@@ -554,6 +558,8 @@
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
|
||||
@@ -30,6 +30,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
<Tip>
|
||||
@@ -105,6 +106,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
|
||||
|
||||
## QwenImageLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
35
docs/source/en/api/models/autoencoderkl_qwenimage.md
Normal file
35
docs/source/en/api/models/autoencoderkl_qwenimage.md
Normal file
@@ -0,0 +1,35 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLQwenImage
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLQwenImage
|
||||
|
||||
vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae")
|
||||
```
|
||||
|
||||
## AutoencoderKLQwenImage
|
||||
|
||||
[[autodoc]] AutoencoderKLQwenImage
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
28
docs/source/en/api/models/qwenimage_transformer2d.md
Normal file
28
docs/source/en/api/models/qwenimage_transformer2d.md
Normal file
@@ -0,0 +1,28 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# QwenImageTransformer2DModel
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
|
||||
transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/QwenImage-20B", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## QwenImageTransformer2DModel
|
||||
|
||||
[[autodoc]] QwenImageTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
35
docs/source/en/api/pipelines/qwenimage.md
Normal file
35
docs/source/en/api/pipelines/qwenimage.md
Normal file
@@ -0,0 +1,35 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# QwenImage
|
||||
|
||||
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
|
||||
|
||||
Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
|
||||
@@ -29,6 +29,7 @@
|
||||
You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.
|
||||
|
||||
The following Wan models are supported in Diffusers:
|
||||
|
||||
- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
|
||||
- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
|
||||
- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
|
||||
@@ -36,6 +37,9 @@ The following Wan models are supported in Diffusers:
|
||||
- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
|
||||
- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)
|
||||
- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)
|
||||
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
|
||||
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
|
||||
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Wan2.1 models in the right sidebar for more examples of video generation.
|
||||
@@ -327,6 +331,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
|
||||
- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images.
|
||||
|
||||
- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
|
||||
@@ -12,183 +12,156 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Installation
|
||||
|
||||
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
Diffusers is tested on Python 3.8+, PyTorch 1.4+, and Flax 0.4.1+. Follow the installation instructions for the deep learning library you're using, [PyTorch](https://pytorch.org/get-started/locally/) or [Flax](https://flax.readthedocs.io/en/latest/).
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions
|
||||
|
||||
## Install with pip
|
||||
|
||||
You should install 🤗 Diffusers in a [virtual environment](https://docs.python.org/3/library/venv.html).
|
||||
If you're unfamiliar with Python virtual environments, take a look at this [guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
A virtual environment makes it easier to manage different projects and avoid compatibility issues between dependencies.
|
||||
|
||||
Create a virtual environment with Python or [uv](https://docs.astral.sh/uv/) (refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), a fast Rust-based Python package and project manager.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
Create a [virtual environment](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) for easier management of separate projects and to avoid compatibility issues between dependencies. Use [uv](https://docs.astral.sh/uv/), a Rust-based Python package and project manager, to create a virtual environment and install Diffusers.
|
||||
|
||||
```bash
|
||||
uv venv my-env
|
||||
source my-env/bin/activate
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Python">
|
||||
Install Diffusers with one of the following methods.
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="pip">
|
||||
|
||||
PyTorch only supports Python 3.8 - 3.11 on Windows.
|
||||
|
||||
```bash
|
||||
python -m venv my-env
|
||||
source my-env/bin/activate
|
||||
uv pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models.
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
PyTorch only supports Python 3.8 - 3.11 on Windows. Install Diffusers with uv.
|
||||
|
||||
```bash
|
||||
uv install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["torch"] transformers
|
||||
```
|
||||
|
||||
</pt>
|
||||
<jax>
|
||||
|
||||
Install Diffusers with uv.
|
||||
Use the command below for Flax.
|
||||
|
||||
```bash
|
||||
uv pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
You can also install Diffusers with pip.
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"] transformers
|
||||
```
|
||||
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
## Install with conda
|
||||
|
||||
After activating your virtual environment, with `conda` (maintained by the community):
|
||||
</hfoption>
|
||||
<hfoption id="conda">
|
||||
|
||||
```bash
|
||||
conda install -c conda-forge diffusers
|
||||
```
|
||||
|
||||
## Install from source
|
||||
</hfoption>
|
||||
<hfoption id="source">
|
||||
|
||||
Before installing 🤗 Diffusers from source, make sure you have PyTorch and 🤗 Accelerate installed.
|
||||
A source install installs the `main` version instead of the latest `stable` version. The `main` version is useful for staying updated with the latest changes but it may not always be stable. If you run into a problem, open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and we will try to resolve it as soon as possible.
|
||||
|
||||
To install 🤗 Accelerate:
|
||||
Make sure [Accelerate](https://huggingface.co/docs/accelerate/index) is installed.
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
uv pip install accelerate
|
||||
```
|
||||
|
||||
Then install 🤗 Diffusers from source:
|
||||
Install Diffusers from source with the command below.
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/diffusers
|
||||
uv pip install git+https://github.com/huggingface/diffusers
|
||||
```
|
||||
|
||||
This command installs the bleeding edge `main` version rather than the latest `stable` version.
|
||||
The `main` version is useful for staying up-to-date with the latest developments.
|
||||
For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet.
|
||||
However, this means the `main` version may not always be stable.
|
||||
We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day.
|
||||
If you run into a problem, please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can fix it even sooner!
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Editable install
|
||||
|
||||
You will need an editable install if you'd like to:
|
||||
An editable install is recommended for development workflows or if you're using the `main` version of the source code. A special link is created between the cloned repository and the Python library paths. This avoids reinstalling a package after every change.
|
||||
|
||||
* Use the `main` version of the source code.
|
||||
* Contribute to 🤗 Diffusers and need to test changes in the code.
|
||||
Clone the repository and install Diffusers with the following commands.
|
||||
|
||||
Clone the repository and install 🤗 Diffusers with the following commands:
|
||||
<hfoptions id="editable">
|
||||
<hfoption id="PyTorch">
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
uv pip install -e ".[torch]"
|
||||
```
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
</hfoption>
|
||||
<hfoption id="Flax">
|
||||
|
||||
```bash
|
||||
pip install -e ".[torch]"
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
uv pip install -e ".[flax]"
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
```bash
|
||||
pip install -e ".[flax]"
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||
Python will now look inside the folder you cloned to in addition to the normal library paths.
|
||||
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.10/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip warning={true}>
|
||||
> [!WARNING]
|
||||
> You must keep the `diffusers` folder if you want to keep using the library with the editable install.
|
||||
|
||||
You must keep the `diffusers` folder if you want to keep using the library.
|
||||
|
||||
</Tip>
|
||||
|
||||
Now you can easily update your clone to the latest version of 🤗 Diffusers with the following command:
|
||||
Update your cloned repository to the latest version of Diffusers with the command below.
|
||||
|
||||
```bash
|
||||
cd ~/diffusers/
|
||||
git pull
|
||||
```
|
||||
|
||||
Your Python environment will find the `main` version of 🤗 Diffusers on the next run.
|
||||
|
||||
## Cache
|
||||
|
||||
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
|
||||
Model weights and files are downloaded from the Hub to a cache, which is usually your home directory. Change the cache location with the [HF_HOME](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome) or [HF_HUB_CACHE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache) environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
|
||||
<hfoptions id="cache">
|
||||
<hfoption id="env variable">
|
||||
|
||||
```bash
|
||||
export HF_HOME="/path/to/your/cache"
|
||||
export HF_HUB_CACHE="/path/to/your/hub/cache"
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="from_pretrained">
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
cache_dir="/path/to/your/cache"
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Cached files allow you to use Diffusers offline. Set the [HF_HUB_OFFLINE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhuboffline) environment variable to `1` to prevent Diffusers from connecting to the internet.
|
||||
|
||||
```shell
|
||||
export HF_HUB_OFFLINE=1
|
||||
```
|
||||
|
||||
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
|
||||
For more details about managing and cleaning the cache, take a look at the [Understand caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
|
||||
|
||||
## Telemetry logging
|
||||
|
||||
Our library gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
|
||||
The data gathered includes the version of 🤗 Diffusers and PyTorch/Flax, the requested model or pipeline class,
|
||||
and the path to a pretrained checkpoint if it is hosted on the Hugging Face Hub.
|
||||
Diffusers gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.
|
||||
The data gathered includes the Diffusers and PyTorch/Flax version, the requested model or pipeline class,
|
||||
and the path to a pretrained checkpoint if it is hosted on the Hub.
|
||||
|
||||
This usage data helps us debug issues and prioritize new features.
|
||||
Telemetry is only sent when loading models and pipelines from the Hub,
|
||||
and it is not collected if you're loading local files.
|
||||
|
||||
We understand that not everyone wants to share additional information,and we respect your privacy.
|
||||
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
|
||||
Opt-out and disable telemetry collection with the [HF_HUB_DISABLE_TELEMETRY](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubdisabletelemetry) environment variable.
|
||||
|
||||
On Linux/MacOS:
|
||||
<hfoptions id="telemetry">
|
||||
<hfoption id="Linux/macOS">
|
||||
|
||||
```bash
|
||||
export HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
On Windows:
|
||||
</hfoption>
|
||||
<hfoption id="Windows">
|
||||
|
||||
```bash
|
||||
set HF_HUB_DISABLE_TELEMETRY=1
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
136
examples/dreambooth/README_qwen.md
Normal file
136
examples/dreambooth/README_qwen.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# DreamBooth training example for Qwen Image
|
||||
|
||||
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_lora_qwen_image.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Qwen Image](https://huggingface.co/Qwen/Qwen-Image).
|
||||
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sana.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="Qwen/Qwen-Image"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-sana-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sana.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=2e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
For using `push_to_hub`, make you're logged into your Hugging Face account:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
## Notes
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
|
||||
|
||||
We provide several options for optimizing memory optimization:
|
||||
|
||||
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage) of the `QwenImagePipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
|
||||
|
||||
## Using quantization
|
||||
|
||||
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:
|
||||
|
||||
```json
|
||||
{
|
||||
"load_in_4bit": true,
|
||||
"bnb_4bit_quant_type": "nf4"
|
||||
}
|
||||
```
|
||||
248
examples/dreambooth/test_dreambooth_lora_qwenimage.py
Normal file
248
examples/dreambooth/test_dreambooth_lora_qwenimage.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
|
||||
transformer_layer_type = "transformer_blocks.0.attn.to_k"
|
||||
|
||||
def test_dreambooth_lora_qwen(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
1687
examples/dreambooth/train_dreambooth_lora_qwen_image.py
Normal file
1687
examples/dreambooth/train_dreambooth_lora_qwen_image.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -174,6 +174,7 @@ else:
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLQwenImage",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
@@ -215,6 +216,7 @@ else:
|
||||
"OmniGenTransformer2DModel",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"QwenImageTransformer2DModel",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SD3ControlNetModel",
|
||||
@@ -486,6 +488,7 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"QwenImagePipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaPAGPipeline",
|
||||
@@ -832,6 +835,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLQwenImage,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
@@ -873,6 +877,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3ControlNetModel,
|
||||
@@ -1119,6 +1124,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
QwenImagePipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaControlNetPipeline,
|
||||
SanaPAGPipeline,
|
||||
|
||||
@@ -153,6 +153,7 @@ def _register_transformer_blocks_metadata():
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
|
||||
# BasicTransformerBlock
|
||||
@@ -255,6 +256,15 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# QwenImage
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=QwenImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
|
||||
@@ -79,6 +79,7 @@ if is_torch_available():
|
||||
"WanLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
"QwenImageLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -118,6 +119,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
QwenImageLoraLoaderMixin,
|
||||
SanaLoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
SkyReelsV2LoraLoaderMixin,
|
||||
|
||||
@@ -1974,6 +1974,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
|
||||
if original_key in original_state_dict:
|
||||
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
|
||||
bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
|
||||
if bias_key_theirs in original_state_dict:
|
||||
bias_key = converted_key.removesuffix(".weight") + ".bias"
|
||||
converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
diff = all(".diff" in k for k in original_state_dict)
|
||||
|
||||
@@ -6538,6 +6538,348 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
return_lora_metadata (`bool`, *optional*, defaults to False):
|
||||
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`QwenImageTransformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
metadata (`dict`):
|
||||
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
|
||||
from the state dict.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the transformer.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -61,6 +61,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
@@ -88,6 +89,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -126,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLQwenImage,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
@@ -177,6 +180,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
SkyReelsV2Transformer3DModel,
|
||||
|
||||
@@ -8,6 +8,7 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
|
||||
@@ -90,7 +90,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
force_upcast: float = True,
|
||||
force_upcast: bool = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
mid_block_add_attention: bool = True,
|
||||
|
||||
@@ -168,7 +168,9 @@ class CosmosPatchEmbed3d(nn.Module):
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p = self.patch_size
|
||||
|
||||
hidden_states = torch.reshape(batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p)
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4).contiguous()
|
||||
return hidden_states
|
||||
|
||||
|
||||
1070
src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
Normal file
1070
src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -913,38 +913,21 @@ def patchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
# x shape: [batch_size, channels, height, width]
|
||||
batch_size, channels, height, width = x.shape
|
||||
|
||||
# Ensure height and width are divisible by patch_size
|
||||
if height % patch_size != 0 or width % patch_size != 0:
|
||||
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
||||
|
||||
# Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
|
||||
x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
|
||||
|
||||
# Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
|
||||
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
|
||||
x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
|
||||
|
||||
elif x.dim() == 5:
|
||||
# x shape: [batch_size, channels, frames, height, width]
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# Ensure height and width are divisible by patch_size
|
||||
if height % patch_size != 0 or width % patch_size != 0:
|
||||
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
||||
|
||||
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
|
||||
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
|
||||
|
||||
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
|
||||
x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
|
||||
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
|
||||
|
||||
else:
|
||||
if x.dim() != 5:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
# x shape: [batch_size, channels, frames, height, width]
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# Ensure height and width are divisible by patch_size
|
||||
if height % patch_size != 0 or width % patch_size != 0:
|
||||
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
||||
|
||||
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
|
||||
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
|
||||
|
||||
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
|
||||
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
|
||||
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
|
||||
|
||||
return x
|
||||
|
||||
@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
|
||||
if patch_size == 1:
|
||||
return x
|
||||
|
||||
if x.dim() == 4:
|
||||
# x shape: [b, (c * patch_size * patch_size), h, w]
|
||||
batch_size, c_patches, height, width = x.shape
|
||||
channels = c_patches // (patch_size * patch_size)
|
||||
if x.dim() != 5:
|
||||
raise ValueError(f"Invalid input shape: {x.shape}")
|
||||
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
|
||||
batch_size, c_patches, frames, height, width = x.shape
|
||||
channels = c_patches // (patch_size * patch_size)
|
||||
|
||||
# Reshape to [b, c, patch_size, patch_size, h, w]
|
||||
x = x.view(batch_size, channels, patch_size, patch_size, height, width)
|
||||
# Reshape to [b, c, patch_size, patch_size, f, h, w]
|
||||
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
|
||||
|
||||
# Rearrange to [b, c, h * patch_size, w * patch_size]
|
||||
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
|
||||
x = x.view(batch_size, channels, height * patch_size, width * patch_size)
|
||||
|
||||
elif x.dim() == 5:
|
||||
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
|
||||
batch_size, c_patches, frames, height, width = x.shape
|
||||
channels = c_patches // (patch_size * patch_size)
|
||||
|
||||
# Reshape to [b, c, patch_size, patch_size, f, h, w]
|
||||
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
|
||||
|
||||
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
|
||||
x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
|
||||
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
|
||||
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
|
||||
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
|
||||
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
|
||||
|
||||
return x
|
||||
|
||||
@@ -1044,7 +1016,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
patch_size: Optional[int] = None,
|
||||
scale_factor_temporal: Optional[int] = 4,
|
||||
scale_factor_spatial: Optional[int] = 8,
|
||||
clip_output: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1244,10 +1215,11 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
if self.config.clip_output:
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
if self.config.patch_size is not None:
|
||||
out = unpatchify(out, patch_size=self.config.patch_size)
|
||||
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
|
||||
self.clear_cache()
|
||||
if not return_dict:
|
||||
return (out,)
|
||||
|
||||
@@ -30,6 +30,7 @@ if is_torch_available():
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
628
src/diffusers/models/transformers/transformer_qwenimage.py
Normal file
628
src/diffusers/models/transformers/transformer_qwenimage.py
Normal file
@@ -0,0 +1,628 @@
|
||||
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
Args
|
||||
timesteps (torch.Tensor):
|
||||
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (int):
|
||||
the dimension of the output.
|
||||
flip_sin_to_cos (bool):
|
||||
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
||||
downscale_freq_shift (float):
|
||||
Controls the delta between frequencies between dimensions
|
||||
scale (float):
|
||||
Scaling factor applied to the embeddings.
|
||||
max_period (int):
|
||||
Controls the maximum frequency of the embeddings
|
||||
Returns
|
||||
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent).to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def apply_rotary_emb_qwen(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
||||
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
||||
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(1)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, hidden_states):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
|
||||
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class QwenEmbedRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(1024)
|
||||
neg_index = torch.arange(1024).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.rope_cache = {}
|
||||
|
||||
# 是否使用 scale rope
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
def rope_params(self, index, dim, theta=10000):
|
||||
"""
|
||||
Args:
|
||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(self, video_fhw, txt_seq_lens, device):
|
||||
"""
|
||||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
||||
txt_length: [bs] a list of 1 integers representing the length of the text
|
||||
"""
|
||||
if self.pos_freqs.device != device:
|
||||
self.pos_freqs = self.pos_freqs.to(device)
|
||||
self.neg_freqs = self.neg_freqs.to(device)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
frame, height, width = video_fhw
|
||||
rope_key = f"{frame}_{height}_{width}"
|
||||
|
||||
if rope_key not in self.rope_cache:
|
||||
seq_lens = frame * height * width
|
||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||
vid_freqs = self.rope_cache[rope_key]
|
||||
|
||||
if self.scale_rope:
|
||||
max_vid_index = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index = max(height, width)
|
||||
|
||||
max_len = max(txt_seq_lens)
|
||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
|
||||
class QwenDoubleStreamAttnProcessor2_0:
|
||||
"""
|
||||
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
|
||||
implements joint attention computation where text and image streams are processed together.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor, # Image stream
|
||||
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
||||
encoder_hidden_states_mask: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
|
||||
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# Compute QKV for image stream (sample projections)
|
||||
img_query = attn.to_q(hidden_states)
|
||||
img_key = attn.to_k(hidden_states)
|
||||
img_value = attn.to_v(hidden_states)
|
||||
|
||||
# Compute QKV for text stream (context projections)
|
||||
txt_query = attn.add_q_proj(encoder_hidden_states)
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states)
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
||||
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
||||
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
||||
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
||||
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
# Apply QK normalization
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
if attn.norm_added_q is not None:
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# Apply RoPE
|
||||
if image_rotary_emb is not None:
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
|
||||
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
|
||||
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
|
||||
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
|
||||
|
||||
# Concatenate for joint attention
|
||||
# Order: [text, image]
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
|
||||
# Compute joint attention
|
||||
joint_hidden_states = dispatch_attention_fn(
|
||||
joint_query,
|
||||
joint_key,
|
||||
joint_value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
joint_hidden_states = joint_hidden_states.flatten(2, 3)
|
||||
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
|
||||
|
||||
# Split attention outputs back
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
||||
|
||||
# Apply output projections
|
||||
img_attn_output = attn.to_out[0](img_attn_output)
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
||||
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class QwenImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
# Image processing modules
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
||||
)
|
||||
self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None, # Enable cross attention for joint computation
|
||||
added_kv_proj_dim=dim, # Enable added KV projections for text stream
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=QwenDoubleStreamAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
)
|
||||
self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
# Text processing modules
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
|
||||
)
|
||||
self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
# Text doesn't need separate attention - it's handled by img_attn joint computation
|
||||
self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
||||
self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def _modulate(self, x, mod_params):
|
||||
"""Apply modulation to input tensor"""
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Get modulation parameters for both streams
|
||||
img_mod_params = self.img_mod(temb) # [B, 6*dim]
|
||||
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
|
||||
|
||||
# Split modulation parameters for norm1 and norm2
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
|
||||
|
||||
# Process image stream - norm1 + modulation
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
|
||||
# Process text stream - norm1 + modulation
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# Use QwenAttnProcessor2_0 for joint attention computation
|
||||
# This directly implements the DoubleStreamLayerMegatron logic:
|
||||
# 1. Computes QKV for both streams
|
||||
# 2. Applies QK normalization and RoPE
|
||||
# 3. Concatenates and runs joint attention
|
||||
# 4. Splits results back to separate streams
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
attn_output = self.attn(
|
||||
hidden_states=img_modulated, # Image stream (will be processed as "sample")
|
||||
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
|
||||
img_attn_output, txt_attn_output = attn_output
|
||||
|
||||
# Apply attention gates and add residual (like in Megatron)
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# Process image stream - norm2 + MLP
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
img_mlp_output = self.img_mlp(img_modulated2)
|
||||
hidden_states = hidden_states + img_gate2 * img_mlp_output
|
||||
|
||||
# Process text stream - norm2 + MLP
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
txt_mlp_output = self.txt_mlp(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
# Clip to prevent overflow for fp16
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
"""
|
||||
The Transformer model introduced in Qwen.
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `64`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `60`):
|
||||
The number of layers of dual stream DiT blocks to use.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of attention heads to use.
|
||||
joint_attention_dim (`int`, defaults to `3584`):
|
||||
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
||||
`encoder_hidden_states`).
|
||||
guidance_embeds (`bool`, defaults to `False`):
|
||||
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
||||
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 16,
|
||||
num_layers: int = 60,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 3584,
|
||||
guidance_embeds: bool = False, # TODO: this should probably be removed
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
|
||||
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
||||
|
||||
self.img_in = nn.Linear(in_channels, self.inner_dim)
|
||||
self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
QwenImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
|
||||
txt_seq_lens: Optional[List[int]] = None,
|
||||
guidance: torch.Tensor = None, # TODO: this should probably be removed
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`QwenTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
|
||||
Mask of the input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = (
|
||||
self.time_text_embed(timestep, hidden_states)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
encoder_hidden_states_mask,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=attention_kwargs,
|
||||
)
|
||||
|
||||
# Use only the image part (hidden_states) from the dual-stream blocks
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -180,6 +180,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
cross_attention_dim_head: Optional[int] = None,
|
||||
processor=None,
|
||||
is_cross_attention=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -207,6 +208,8 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def fuse_projections(self):
|
||||
@@ -324,7 +327,7 @@ class WanTimeTextImageEmbedding(nn.Module):
|
||||
):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
if timestep_seq_len is not None:
|
||||
timestep = timestep.unflatten(0, (1, timestep_seq_len))
|
||||
timestep = timestep.unflatten(0, (-1, timestep_seq_len))
|
||||
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
|
||||
@@ -387,6 +387,7 @@ else:
|
||||
"SkyReelsV2ImageToVideoPipeline",
|
||||
"SkyReelsV2Pipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = ["QwenImagePipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -703,6 +704,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .qwenimage import QwenImagePipeline
|
||||
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
|
||||
@@ -312,15 +312,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.0.dev0"):
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
|
||||
49
src/diffusers/pipelines/qwenimage/__init__.py
Normal file
49
src/diffusers/pipelines/qwenimage/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["QwenImagePipelineOutput", "QwenImagePriorReduxPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
|
||||
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_qwenimage import QwenImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
21
src/diffusers/pipelines/qwenimage/pipeline_output.py
Normal file
21
src/diffusers/pipelines/qwenimage/pipeline_output.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class QwenImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
730
src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
Normal file
730
src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
Normal file
@@ -0,0 +1,730 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import QwenImageLoraLoaderMixin
|
||||
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import QwenImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import QwenImagePipeline
|
||||
|
||||
>>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
>>> image.save("qwenimage.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
||||
r"""
|
||||
The QwenImage pipeline for text-to-image generation.
|
||||
|
||||
Args:
|
||||
transformer ([`QwenImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
|
||||
[Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
|
||||
tokenizer (`QwenTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2Tokenizer,
|
||||
transformer: QwenImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 1024
|
||||
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.prompt_template_encode_start_idx = 34
|
||||
self.default_sample_size = 128
|
||||
|
||||
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
|
||||
return split_result
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = self.prompt_template_encode
|
||||
drop_idx = self.prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = self.tokenizer(
|
||||
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
encoder_hidden_states = self.text_encoder(
|
||||
input_ids=txt_tokens.input_ids,
|
||||
attention_mask=txt_tokens.attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
hidden_states = encoder_hidden_states.hidden_states[-1]
|
||||
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
|
||||
)
|
||||
encoder_attention_mask = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 1024:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height, width, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
|
||||
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
||||
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, 1, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
true_cfg_scale: float = 4.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 1.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
||||
not greater than `1`).
|
||||
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
||||
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_true_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if self.attention_kwargs is None:
|
||||
self._attention_kwargs = {}
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return QwenImagePipelineOutput(images=image)
|
||||
@@ -1034,7 +1034,8 @@ class StableDiffusionPipeline(
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
if hasattr(self.scheduler, "scale_model_input"):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
|
||||
@@ -125,15 +125,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2"]
|
||||
_optional_components = ["transformer", "transformer_2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
transformer: Optional[WanTransformer3DModel] = None,
|
||||
transformer_2: Optional[WanTransformer3DModel] = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False, # Wan2.2 ti2v
|
||||
@@ -526,7 +526,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
@@ -536,7 +536,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
num_channels_latents = (
|
||||
self.transformer.config.in_channels
|
||||
if self.transformer is not None
|
||||
else self.transformer_2.config.in_channels
|
||||
)
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
|
||||
@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
|
||||
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
image_processor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModel = None,
|
||||
transformer: WanTransformer3DModel = None,
|
||||
transformer_2: WanTransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False,
|
||||
@@ -669,12 +669,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)
|
||||
|
||||
# Encode image embedding
|
||||
transformer_dtype = self.transformer.dtype
|
||||
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
if self.config.boundary_ratio is None and not self.config.expand_timesteps:
|
||||
# only wan 2.1 i2v transformer accepts image_embeds
|
||||
if self.transformer is not None and self.transformer.config.image_dim is not None:
|
||||
if image_embeds is None:
|
||||
if last_image is None:
|
||||
image_embeds = self.encode_image(image, device)
|
||||
@@ -709,6 +710,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
last_image,
|
||||
)
|
||||
if self.config.expand_timesteps:
|
||||
# wan 2.2 5b i2v use firt_frame_mask to mask timesteps
|
||||
latents, condition, first_frame_mask = latents_outputs
|
||||
else:
|
||||
latents, condition = latents_outputs
|
||||
|
||||
@@ -423,6 +423,21 @@ class AutoencoderKLMochi(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLQwenImage(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1038,6 +1053,21 @@ class PriorTransformer(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class QwenImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SanaControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1742,6 +1742,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ReduxImageEncoder(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
129
tests/lora/test_lora_layers_qwenimage.py
Normal file
129
tests/lora/test_lora_layers_qwenimage.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLQwenImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
QwenImagePipeline,
|
||||
QwenImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = QwenImagePipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
transformer_cls = QwenImageTransformer2DModel
|
||||
z_dim = 4
|
||||
vae_kwargs = {
|
||||
"base_dim": z_dim * 6,
|
||||
"z_dim": z_dim,
|
||||
"dim_mult": [1, 2, 4],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True],
|
||||
"latents_mean": [0.0] * 4,
|
||||
"latents_std": [1.0] * 4,
|
||||
}
|
||||
vae_cls = AutoencoderKLQwenImage
|
||||
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
|
||||
text_encoder_cls, text_encoder_id = (
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
"hf-internal-testing/tiny-random-Qwen25VLForCondGen",
|
||||
)
|
||||
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 8, 8, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 10
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 0.0,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Qwen Image.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
@@ -45,6 +45,7 @@ from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import is_transformers_version
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
@@ -220,6 +221,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
}
|
||||
return inputs
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.54.1"),
|
||||
reason="Test currently fails on Transformers version 4.54.1.",
|
||||
strict=False,
|
||||
)
|
||||
def test_audioldm2_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
@@ -312,7 +318,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
components = self.get_dummy_components()
|
||||
audioldm_pipe = AudioLDM2Pipeline(**components)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe = audioldm_pipe.to(torch_device)
|
||||
audioldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
@@ -371,6 +376,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
assert np.abs(audio_1 - audio_2).max() < 1e-2
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">=", "4.54.1"),
|
||||
reason="Test currently fails on Transformers version 4.54.1.",
|
||||
strict=False,
|
||||
)
|
||||
def test_audioldm2_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
0
tests/pipelines/qwenimage/__init__.py
Normal file
0
tests/pipelines/qwenimage/__init__.py
Normal file
236
tests/pipelines/qwenimage/test_qwenimage.py
Normal file
236
tests/pipelines/qwenimage/test_qwenimage.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLQwenImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
QwenImagePipeline,
|
||||
QwenImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = QwenImagePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = QwenImageTransformer2DModel(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
out_channels=4,
|
||||
num_layers=2,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=3,
|
||||
joint_attention_dim=16,
|
||||
guidance_embeds=False,
|
||||
axes_dims_rope=(8, 4, 4),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
z_dim = 4
|
||||
vae = AutoencoderKLQwenImage(
|
||||
base_dim=z_dim * 6,
|
||||
z_dim=z_dim,
|
||||
dim_mult=[1, 2, 4],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True],
|
||||
# fmt: off
|
||||
latents_mean=[0.0] * 4,
|
||||
latents_std=[1.0] * 4,
|
||||
# fmt: on
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen2_5_VLConfig(
|
||||
text_config={
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 2,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [1, 1, 2],
|
||||
"rope_type": "default",
|
||||
"type": "default",
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
},
|
||||
vision_config={
|
||||
"depth": 2,
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_heads": 2,
|
||||
"out_hidden_size": 16,
|
||||
},
|
||||
hidden_size=16,
|
||||
vocab_size=152064,
|
||||
vision_end_token_id=151653,
|
||||
vision_start_token_id=151652,
|
||||
vision_token_id=151654,
|
||||
)
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"true_cfg_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.563, 0.6358, 0.6028, 0.5656, 0.5806, 0.5512, 0.5712, 0.6331, 0.4147, 0.3558, 0.5625, 0.4831, 0.4957, 0.5258, 0.4075, 0.5018])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.2):
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling
|
||||
pipe.vae.enable_tiling(
|
||||
tile_sample_min_height=96,
|
||||
tile_sample_min_width=96,
|
||||
tile_sample_stride_height=64,
|
||||
tile_sample_stride_width=64,
|
||||
)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
@@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -85,29 +87,13 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer_2 = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": transformer_2,
|
||||
"transformer_2": None,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -155,6 +141,45 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
# _optional_components include transformer, transformer_2, but only transformer_2 is optional for this wan2.1 t2v pipeline
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = "transformer_2"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components[optional_component] = None
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
367
tests/pipelines/wan/test_wan_22.py
Normal file
367
tests/pipelines/wan/test_wan_22.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = WanPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer_2 = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": transformer_2,
|
||||
"boundary_ratio": 0.875,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 16,
|
||||
"width": 16,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(
|
||||
**components,
|
||||
)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = "transformer"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components[optional_component] = None
|
||||
components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, "transformer") is None,
|
||||
"`transformer` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = WanPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=48,
|
||||
in_channels=12,
|
||||
out_channels=12,
|
||||
is_residual=True,
|
||||
patch_size=2,
|
||||
latents_mean=[0.0] * 48,
|
||||
latents_std=[1.0] * 48,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
scale_factor_spatial=16,
|
||||
scale_factor_temporal=4,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=48,
|
||||
out_channels=48,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": None,
|
||||
"boundary_ratio": None,
|
||||
"expand_timesteps": True,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative", # TODO
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(
|
||||
**components,
|
||||
)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([[[0.4814, 0.4298, 0.5094, 0.4289, 0.5061, 0.4301, 0.5043, 0.4284, 0.5375,
|
||||
0.5965, 0.5527, 0.6014, 0.5228, 0.6076, 0.6644, 0.5651]]])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(
|
||||
torch.allclose(generated_slice, expected_slice, atol=1e-3),
|
||||
f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
|
||||
)
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components.pop("boundary_ratio")
|
||||
init_components.pop("expand_timesteps")
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = "transformer_2"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components[optional_component] = None
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
392
tests/pipelines/wan/test_wan_22_image_to_video.py
Normal file
392
tests/pipelines/wan/test_wan_22_image_to_video.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Wan22ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = WanImageToVideoPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=36,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer_2 = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=36,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": transformer_2,
|
||||
"image_encoder": None,
|
||||
"image_processor": None,
|
||||
"boundary_ratio": 0.875,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image_height = 16
|
||||
image_width = 16
|
||||
image = Image.new("RGB", (image_width, image_height))
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative", # TODO
|
||||
"height": image_height,
|
||||
"width": image_width,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(
|
||||
**components,
|
||||
)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.4527, 0.4526, 0.4498, 0.4539, 0.4521, 0.4524, 0.4533, 0.4535, 0.5154,
|
||||
0.5353, 0.5200, 0.5174, 0.5434, 0.5301, 0.5199, 0.5216])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(
|
||||
torch.allclose(generated_slice, expected_slice, atol=1e-3),
|
||||
f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
|
||||
)
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = ["transformer", "image_encoder", "image_processor"]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
for component in optional_component:
|
||||
components[component] = None
|
||||
components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for component in optional_component:
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, component) is None,
|
||||
f"`{component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
class Wan225BImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = WanImageToVideoPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=48,
|
||||
in_channels=12,
|
||||
out_channels=12,
|
||||
is_residual=True,
|
||||
patch_size=2,
|
||||
latents_mean=[0.0] * 48,
|
||||
latents_std=[1.0] * 48,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
scale_factor_spatial=16,
|
||||
scale_factor_temporal=4,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=48,
|
||||
out_channels=48,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_2": None,
|
||||
"image_encoder": None,
|
||||
"image_processor": None,
|
||||
"boundary_ratio": None,
|
||||
"expand_timesteps": True,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image_height = 32
|
||||
image_width = 32
|
||||
image = Image.new("RGB", (image_width, image_height))
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative", # TODO
|
||||
"height": image_height,
|
||||
"width": image_width,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(
|
||||
**components,
|
||||
)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([[0.4833, 0.4305, 0.5100, 0.4299, 0.5056, 0.4298, 0.5052, 0.4332, 0.5550,
|
||||
0.6092, 0.5536, 0.5928, 0.5199, 0.5864, 0.6705, 0.5493]])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(
|
||||
torch.allclose(generated_slice, expected_slice, atol=1e-3),
|
||||
f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
|
||||
)
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components.pop("boundary_ratio")
|
||||
init_components.pop("expand_timesteps")
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = ["transformer_2", "image_encoder", "image_processor"]
|
||||
|
||||
components = self.get_dummy_components()
|
||||
for component in optional_component:
|
||||
components[component] = None
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
for component in optional_component:
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, component) is None,
|
||||
f"`{component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_callback_inputs(self):
|
||||
pass
|
||||
@@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
@@ -25,7 +27,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
@@ -86,23 +88,6 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_dim=4,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer_2 = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=36,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
image_dim=4,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_encoder_config = CLIPVisionConfig(
|
||||
hidden_size=4,
|
||||
@@ -126,7 +111,7 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"tokenizer": tokenizer,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
"transformer_2": transformer_2,
|
||||
"transformer_2": None,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -182,11 +167,44 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
|
||||
)
|
||||
def test_save_load_optional_components(self):
|
||||
pass
|
||||
# _optional_components include transformer, transformer_2 and image_encoder, image_processor, but only transformer_2 is optional for wan2.1 i2v pipeline
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = "transformer_2"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components[optional_component] = None
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
|
||||
class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
@@ -242,24 +260,6 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pos_embed_seq_len=2 * (4 * 4 + 1),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer_2 = WanTransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=36,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=2,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
image_dim=4,
|
||||
pos_embed_seq_len=2 * (4 * 4 + 1),
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image_encoder_config = CLIPVisionConfig(
|
||||
hidden_size=4,
|
||||
@@ -283,7 +283,7 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"tokenizer": tokenizer,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
"transformer_2": transformer_2,
|
||||
"transformer_2": None,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -341,8 +341,41 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
|
||||
)
|
||||
def test_save_load_optional_components(self):
|
||||
pass
|
||||
# _optional_components include transformer, transformer_2 and image_encoder, image_processor, but only transformer_2 is optional for wan2.1 FLFT2V pipeline
|
||||
def test_save_load_optional_components(self, expected_max_difference=1e-4):
|
||||
optional_component = "transformer_2"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
components[optional_component] = None
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir, safe_serialization=False)
|
||||
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
|
||||
for component in pipe_loaded.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe_loaded.to(torch_device)
|
||||
pipe_loaded.set_progress_bar_config(disable=None)
|
||||
|
||||
self.assertTrue(
|
||||
getattr(pipe_loaded, optional_component) is None,
|
||||
f"`{optional_component}` did not stay set to None after loading.",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
torch.manual_seed(0)
|
||||
output_loaded = pipe_loaded(**inputs)[0]
|
||||
|
||||
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
|
||||
self.assertLess(max_diff, expected_max_difference)
|
||||
|
||||
Reference in New Issue
Block a user