Compare commits

..

30 Commits

Author SHA1 Message Date
sayakpaul
0e602e4be3 include auto-docstring check in the modular ci. 2026-01-20 15:42:51 +05:30
github-actions[bot]
23d06423ab Apply style fixes 2026-01-19 09:23:31 +00:00
YiYi Xu
aba551c868 Merge branch 'main' into modular-doc-improv 2026-01-18 23:20:36 -10:00
yiyixuxu
1f9576a2ca fix 2026-01-19 09:56:14 +01:00
yiyixuxu
d75fbc43c7 Merge branch 'modular-doc-improv' of github.com:huggingface/diffusers into modular-doc-improv 2026-01-19 09:54:46 +01:00
yiyixuxu
b7127ce7a7 revert change in z 2026-01-19 09:54:40 +01:00
YiYi Xu
7e9d2b954e Apply suggestions from code review 2026-01-18 22:44:44 -10:00
yiyixuxu
94525200fd rmove space in make docstring 2026-01-19 09:35:39 +01:00
yiyixuxu
f056af1fbb make style 2026-01-19 09:27:40 +01:00
yiyixuxu
8d45ff5bf6 apply auto docstring 2026-01-19 09:22:04 +01:00
yiyixuxu
fb15752d55 up up up 2026-01-19 08:10:31 +01:00
yiyixuxu
1f2dbc9dd2 up 2026-01-19 04:10:17 +01:00
yiyixuxu
002c3e8239 add template method 2026-01-19 03:24:34 +01:00
yiyixuxu
de03d7f100 refactor based on dhruv's feedback: remove the class method 2026-01-18 00:35:01 +01:00
yiyixuxu
25c968a38f add TODO in the description for empty docstring 2026-01-17 09:57:56 +01:00
yiyixuxu
aea0d046f6 address feedbacks 2026-01-17 09:36:58 +01:00
David El Malih
3996788b60 [Docs] Replace root CONTRIBUTING.md with symlink to source docs (#12986)
Chore: Replace CONTRIBUTING.md with a symlink to documentation
2026-01-16 12:36:50 -08:00
David El Malih
9fedfe58b7 Improve docstrings and type hints in scheduling_cosine_dpmsolver_multistep.py (#12936)
* docs: improve docstring scheduling_cosine_dpmsolver_multistep.py

* Update src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-16 12:23:49 -08:00
Sayak Paul
ebf891a254 [core] gracefully error out when attn-backend x cp combo isn't supported. (#12832)
* gracefully error out when attn-backend x cp combo isn't supported.

* Revert "gracefully error out when attn-backend x cp combo isn't supported."

This reverts commit c8abb5d7c0.

* gracefully error out when attn-backend x cp combo isn't supported.

* up

* address PR feedback.

* up

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* dot.

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2026-01-16 21:29:42 +05:30
yiyixuxu
1c90ce33f2 up 2026-01-10 12:21:26 +01:00
yiyixuxu
507953f415 more more 2026-01-10 12:19:14 +01:00
yiyixuxu
f0555af1c6 up up up 2026-01-10 12:15:53 +01:00
yiyixuxu
2a81f2ec54 style 2026-01-10 12:15:36 +01:00
yiyixuxu
d20f413f78 more auto docstring 2026-01-10 12:11:28 +01:00
yiyixuxu
ff09bf1a63 add modular_auto_docstring! 2026-01-10 11:55:03 +01:00
yiyixuxu
34a743e2dc style 2026-01-10 10:57:27 +01:00
yiyixuxu
43ab14845d update outputs 2026-01-10 10:56:54 +01:00
YiYi Xu
fbfe5c8d6b Merge branch 'main' into modular-doc-improv 2026-01-09 23:54:23 -10:00
yiyixuxu
b29873dee7 up up 2026-01-10 10:52:53 +01:00
yiyixuxu
7b499de6d0 up 2026-01-10 03:35:15 +01:00
38 changed files with 4283 additions and 7088 deletions

View File

@@ -75,9 +75,27 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
check_auto_docs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[quality]
- name: Check auto docs
run: make modular-autodoctrings
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Auto docstring checks failed. Please run `python utils/modular_auto_docstring.py --fix_and_overwrite`." >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
needs: [check_code_quality, check_repository_consistency, check_auto_docs]
name: Fast PyTorch Modular Pipeline CPU tests
runs-on:

View File

@@ -1,506 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# How to contribute to Diffusers 🧨
We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation not just code are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it!
Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. <a href="https://discord.gg/G7tWnz98XR"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=Discord&logoColor=white"></a>
Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.
We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.
## Overview
You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
the core library.
In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).
* 8. Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).
* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).
As said before, **all contributions are valuable to the community**.
In the following, we will explain each contribution a bit more in detail.
For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr).
### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):
- Reports of training or inference experiments in an attempt to share knowledge
- Presentation of personal projects
- Questions to non-official training examples
- Project proposals
- General feedback
- Paper summaries
- Asking for help on personal projects that build on top of the Diffusers library
- General questions
- Ethical questions regarding diffusion models
- ...
Every question that is asked on the forum or on Discord actively encourages the community to publicly
share knowledge and might very well help a beginner in the future who has the same question you're
having. Please do pose any questions you might have.
In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
**NOTE about channels**:
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
In addition, questions and answers posted in the forum can easily be linked to.
In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.
While it will most likely take less time for you to get an answer to your question on Discord, your
question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.
### 2. Opening new issues on the GitHub issues tab
The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of
the problems they encounter. So thank you for reporting an issue.
Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.
In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
**Please consider the following guidelines when opening a new issue**:
- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).
- Please never report a new issue on another (related) issue. If another issue is highly related, please
open a new issue nevertheless and link to the related issue.
- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.
- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c "import diffusers; print(diffusers.__version__)"` is higher or matches the latest Diffusers version.
- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues.
New issues usually include the following.
#### 2.1. Reproducible, minimal bug reports
A bug report should always have a reproducible code snippet and be as minimal and concise as possible.
This means in more detail:
- Narrow the bug down as much as you can, **do not just dump your whole code file**.
- Format your code.
- Do not include any external libraries except for Diffusers depending on them.
- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).
#### 2.2. Feature requests
A world-class feature request addresses the following points:
1. Motivation first:
* Is it related to a problem/frustration with the library? If so, please explain
why. Providing a code snippet that demonstrates the problem is best.
* Is it related to something you would need for a project? We'd love to hear
about it!
* Is it something you worked on and think could benefit the community?
Awesome! Tell us what problem it solved for you.
2. Write a *full paragraph* describing the feature;
3. Provide a **code snippet** that demonstrates its future use;
4. In case this is related to a paper, please attach a link;
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
#### 2.3 Feedback
Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed.
If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
#### 2.4 Technical questions
Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on
why this part of the code is difficult to understand.
You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).
#### 2.5 Proposal to add a new model, scheduler, or pipeline
If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:
* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.
* Link to any of its open-source implementation.
* Link to the model weights if they are available.
If you are willing to contribute to the model yourself, let us know so we can best guide you. Also, don't forget
to tag the original author of the component (model, scheduler, pipeline, etc.) by GitHub handle if you can find it.
You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).
### 3. Answering issues on the GitHub issues tab
Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.
Some tips to give a high-quality answer to an issue:
- Be as concise and minimal as possible.
- Stay on topic. An answer to the issue should concern the issue and only the issue.
- Provide links to code, papers, or other sources that prove or encourage your point.
- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.
Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great
help to the maintainers if you can answer such issues, encouraging the author of the issue to be
more precise, provide the link to a duplicated issue or redirect them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
If you have verified that the issued bug report is correct and requires a correction in the source code,
please have a look at the next sections.
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
### 4. Fixing a "Good first issue"
*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already
explains how a potential solution should look so that it is easier to fix.
If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message "I would like to try this issue.". There are usually three scenarios:
- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.
- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.
- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.
### 5. Contribute to the documentation
A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly
valuable contribution**.
Contributing to the library can have many forms:
- Correcting spelling or grammatical errors.
- Correct incorrect formatting of the docstring. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it.
- Correct the shape or dimensions of a docstring input or output tensor.
- Clarify documentation that is hard to understand or incorrect.
- Update outdated code examples.
- Translating the documentation to another language.
Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected, adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).
Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.
### 6. Contribute a community pipeline
[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user.
Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview).
We support two types of pipelines:
- Official Pipelines
- Community Pipelines
Both official and community pipelines follow the same design and consist of the same type of components.
Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code
resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested.
They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution.
The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all
possible ways diffusion models can be used for inference, but some of them may be of interest to the community.
Officially released diffusion pipelines,
such as Stable Diffusion are added to the core src/diffusers/pipelines package which ensures
high quality of maintenance, no backward-breaking code changes, and testing.
More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library.
To add a community pipeline, one should add a <name-of-the-community>.py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline.
An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400).
Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors.
Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the
core package.
### 7. Contribute to training examples
Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).
We support two types of training examples:
- Official training examples
- Research training examples
Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.
The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community.
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
training examples, it is required to clone the repository:
```bash
git clone https://github.com/huggingface/diffusers
```
as well as to install all additional dependencies required for training:
```bash
cd diffusers
pip install -r examples/<your-example-folder>/requirements.txt
```
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
Training examples of the Diffusers library should adhere to the following philosophy:
- All the code necessary to run the examples should be found in a single Python file.
- One should be able to run the example from the command line with `python <your-example>.py --args`.
- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.
To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of how they should look like.
We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated
with Diffusers.
Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include:
- An example command on how to run the example script as shown [here e.g.](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).
- A link to some training results (logs, models, ...) that show what the user can expect as shown [here e.g.](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
### 8. Fixing a "Good second issue"
*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are
usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
The issue description usually gives less guidance on how to fix the issue and requires
a decent understanding of the library by the interested contributor.
If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.
Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged.
### 9. Adding pipelines, models, schedulers
Pipelines, models, and schedulers are the most important pieces of the Diffusers library.
They provide easy access to state-of-the-art diffusion technologies and thus allow the community to
build powerful generative AI applications.
By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem.
Diffusers has a couple of open feature requests for all three components - feel free to gloss over them
if you don't know yet what specific component you would like to add:
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
original author directly on the PR so that they can follow the progress and potentially help with questions.
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
## How to write a good issue
**The better your issue is written, the higher the chances that it will be quickly resolved.**
1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).
2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers".
3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data.
4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.
5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.
6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.
7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.
## How to write a good PR
1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.
2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.
3. If helpful, try to add a code snippet that displays an example of how your addition can be used.
4. The title of your pull request should be a summary of its contribution.
5. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
6. To indicate a work in progress please prefix the title with `[WIP]`. These
are useful to avoid duplicated work, and to differentiate it from PRs ready
to be merged;
7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).
8. Make sure existing tests pass;
9. Add high-coverage tests. No quality testing = no merge.
- If you are adding new `@slow` tests, make sure they pass using
`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
CircleCI does not run the slow tests, but GitHub Actions does every night!
10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example.
11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.
If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
to this dataset.
## How to open a PR
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)):
1. Fork the [repository](https://github.com/huggingface/diffusers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
$ git clone git@github.com:<your GitHub handle>/diffusers.git
$ cd diffusers
$ git remote add upstream https://github.com/huggingface/diffusers.git
```
3. Create a new branch to hold your development changes:
```bash
$ git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
4. Set up a development environment by running the following command in a virtual environment:
```bash
$ pip install -e ".[dev]"
```
If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the
library.
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this:
```bash
$ pytest tests/<TEST_TO_RUN>.py
```
Before you run the tests, please make sure you install the dependencies required for testing. You can do so
with this command:
```bash
$ pip install -e ".[test]"
```
You can also run the full test suite with the following command, but it takes
a beefy machine to produce a result in a decent amount of time now that
Diffusers has grown a lot. Here is the command for it:
```bash
$ make test
```
🧨 Diffusers relies on `ruff` and `isort` to format its source code
consistently. After you make changes, apply automatic style corrections and code verifications
that can't be automated in one go with:
```bash
$ make style
```
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
control runs in CI, however, you can also run the same checks with:
```bash
$ make quality
```
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
```bash
$ git add modified_file.py
$ git commit -m "A descriptive message about your changes."
```
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
$ git pull upstream main
```
Push the changes to your account using:
```bash
$ git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied, go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors
too! So everyone can see the changes in the Pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).
We like `pytest` and `pytest-xdist` because it's faster. From the root of the
repository, here's how to run tests with `pytest` for the library:
```bash
$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
```
In fact, that's how `make test` is implemented!
You can specify a smaller set of tests in order to test only the feature
you're working on.
By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to
`yes` to run them. This will download many gigabytes of models — make sure you
have enough disk space and a good Internet connection, or a lot of patience!
```bash
$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
```
`unittest` is fully supported, here's how to run tests with it:
```bash
$ python -m unittest discover -s tests -t . -v
$ python -m unittest discover -s examples -t examples -v
```
### Syncing forked main with upstream (HuggingFace) main
To avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs,
when syncing the main branch of a forked repository, please, follow these steps:
1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.
2. If a PR is absolutely necessary, use the following steps after checking out your branch:
```bash
$ git checkout -b your-branch-for-syncing
$ git pull --squash --no-commit upstream main
$ git commit -m '<your message without GitHub references>'
$ git push --set-upstream origin your-branch-for-syncing
```
### Style guide
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).

1
CONTRIBUTING.md Symbolic link
View File

@@ -0,0 +1 @@
docs/source/en/conceptual/contribution.md

View File

@@ -70,6 +70,10 @@ fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
# Auto docstrings in modular blocks
modular-autodoctrings:
python utils/modular_auto_docstring.py
# Run tests for the library
test:

View File

@@ -235,6 +235,10 @@ class _AttentionBackendRegistry:
def get_active_backend(cls):
return cls._active_backend, cls._backends[cls._active_backend]
@classmethod
def set_active_backend(cls, backend: str):
cls._active_backend = backend
@classmethod
def list_backends(cls):
return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
_maybe_download_kernel_for_backend(backend)
old_backend = _AttentionBackendRegistry._active_backend
_AttentionBackendRegistry._active_backend = backend
_AttentionBackendRegistry.set_active_backend(backend)
try:
yield
finally:
_AttentionBackendRegistry._active_backend = old_backend
_AttentionBackendRegistry.set_active_backend(old_backend)
def dispatch_attention_fn(
@@ -348,6 +352,7 @@ def dispatch_attention_fn(
check(**kwargs)
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
return backend_fn(**kwargs)

View File

@@ -599,6 +599,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .attention import AttentionModuleMixin
from .attention_dispatch import (
AttentionBackendName,
_AttentionBackendRegistry,
_check_attention_backend_requirements,
_maybe_download_kernel_for_backend,
)
@@ -607,6 +608,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .attention_processor import Attention, MochiAttention
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
parallel_config_set = False
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if getattr(processor, "_parallel_config", None) is not None:
parallel_config_set = True
break
backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
@@ -614,10 +625,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
raise ValueError(
f"Context parallelism is enabled but current attention backend '{backend.value}' "
f"does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
)
_check_attention_backend_requirements(backend)
_maybe_download_kernel_for_backend(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
@@ -626,6 +644,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
continue
processor._attention_backend = backend
# Important to set the active backend so that it propagates gracefully throughout.
_AttentionBackendRegistry.set_active_backend(backend)
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
@@ -1538,7 +1559,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `enable_parallelism()`."
f"calling `model.enable_parallelism()`."
)
# All modules use the same attention processor and backend. We don't need to

View File

@@ -18,6 +18,7 @@ from collections import OrderedDict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union
import PIL.Image
import torch
from ..configuration_utils import ConfigMixin, FrozenDict
@@ -323,11 +324,192 @@ class ConfigSpec:
description: Optional[str] = None
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
# however some fields are not relevant for intermediate_inputs
# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
# -> should we use different class for inputs and intermediate_inputs?
# ======================================================
# InputParam and OutputParam templates
# ======================================================
INPUT_PARAM_TEMPLATES = {
"prompt": {
"type_hint": str,
"required": True,
"description": "The prompt or prompts to guide image generation.",
},
"negative_prompt": {
"type_hint": str,
"description": "The prompt or prompts not to guide the image generation.",
},
"max_sequence_length": {
"type_hint": int,
"default": 512,
"description": "Maximum sequence length for prompt encoding.",
},
"height": {
"type_hint": int,
"description": "The height in pixels of the generated image.",
},
"width": {
"type_hint": int,
"description": "The width in pixels of the generated image.",
},
"num_inference_steps": {
"type_hint": int,
"default": 50,
"description": "The number of denoising steps.",
},
"num_images_per_prompt": {
"type_hint": int,
"default": 1,
"description": "The number of images to generate per prompt.",
},
"generator": {
"type_hint": torch.Generator,
"description": "Torch generator for deterministic generation.",
},
"sigmas": {
"type_hint": List[float],
"description": "Custom sigmas for the denoising process.",
},
"strength": {
"type_hint": float,
"default": 0.9,
"description": "Strength for img2img/inpainting.",
},
"image": {
"type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]],
"required": True,
"description": "Reference image(s) for denoising. Can be a single image or list of images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Pre-generated noisy latents for image generation.",
},
"timesteps": {
"type_hint": torch.Tensor,
"description": "Timesteps for the denoising process.",
},
"output_type": {
"type_hint": str,
"default": "pil",
"description": "Output format: 'pil', 'np', 'pt'.",
},
"attention_kwargs": {
"type_hint": Dict[str, Any],
"description": "Additional kwargs for attention processors.",
},
"denoiser_input_fields": {
"name": None,
"kwargs_type": "denoiser_input_fields",
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
},
# inpainting
"mask_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Mask image for inpainting.",
},
"padding_mask_crop": {
"type_hint": int,
"description": "Padding for mask cropping in inpainting.",
},
# controlnet
"control_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Control image for ControlNet conditioning.",
},
"control_guidance_start": {
"type_hint": float,
"default": 0.0,
"description": "When to start applying ControlNet.",
},
"control_guidance_end": {
"type_hint": float,
"default": 1.0,
"description": "When to stop applying ControlNet.",
},
"controlnet_conditioning_scale": {
"type_hint": float,
"default": 1.0,
"description": "Scale for ControlNet conditioning.",
},
"layers": {
"type_hint": int,
"default": 4,
"description": "Number of layers to extract from the image",
},
# common intermediate inputs
"prompt_embeds": {
"type_hint": torch.Tensor,
"required": True,
"description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.",
},
"prompt_embeds_mask": {
"type_hint": torch.Tensor,
"required": True,
"description": "mask for the text embeddings. Can be generated from text_encoder step.",
},
"negative_prompt_embeds": {
"type_hint": torch.Tensor,
"description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.",
},
"negative_prompt_embeds_mask": {
"type_hint": torch.Tensor,
"description": "mask for the negative text embeddings. Can be generated from text_encoder step.",
},
"image_latents": {
"type_hint": torch.Tensor,
"required": True,
"description": "image latents used to guide the image generation. Can be generated from vae_encoder step.",
},
"batch_size": {
"type_hint": int,
"default": 1,
"description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
},
"dtype": {
"type_hint": torch.dtype,
"default": torch.float32,
"description": "The dtype of the model inputs, can be generated in input step.",
},
}
OUTPUT_PARAM_TEMPLATES = {
"images": {
"type_hint": List[PIL.Image.Image],
"description": "Generated images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Denoised latents.",
},
# intermediate outputs
"prompt_embeds": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The prompt embeddings.",
},
"prompt_embeds_mask": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The encoder attention mask.",
},
"negative_prompt_embeds": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The negative prompt embeddings.",
},
"negative_prompt_embeds_mask": {
"type_hint": torch.Tensor,
"kwargs_type": "denoiser_input_fields",
"description": "The negative prompt embeddings mask.",
},
"image_latents": {
"type_hint": torch.Tensor,
"description": "The latent representation of the input image.",
},
}
@dataclass
class InputParam:
"""Specification for an input parameter."""
@@ -337,11 +519,31 @@ class InputParam:
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
kwargs_type: str = None
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if template_name not in INPUT_PARAM_TEMPLATES:
raise ValueError(f"InputParam template for {template_name} not found")
template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy()
# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))
if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"
template_kwargs.update(overrides)
return cls(name=name, **template_kwargs)
@dataclass
class OutputParam:
@@ -350,13 +552,33 @@ class OutputParam:
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
kwargs_type: str = None
def __repr__(self):
return (
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)
@classmethod
def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam":
"""Get template for name if exists, otherwise raise ValueError."""
if template_name not in OUTPUT_PARAM_TEMPLATES:
raise ValueError(f"OutputParam template for {template_name} not found")
template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy()
# Determine the actual param name:
# 1. From overrides if provided
# 2. From template if present
# 3. Fall back to template_name
name = overrides.pop("name", template_kwargs.pop("name", template_name))
if note and "description" in template_kwargs:
template_kwargs["description"] = f"{template_kwargs['description']} ({note})"
template_kwargs.update(overrides)
return cls(name=name, **template_kwargs)
def format_inputs_short(inputs):
"""
@@ -509,10 +731,12 @@ def format_params(params, header="Args", indent_level=4, max_line_length=115):
desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
else:
param_str += f"\n{desc_indent}TODO: Add description."
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
return "\n".join(formatted_params)
def format_input_params(input_params, indent_level=4, max_line_length=115):
@@ -582,7 +806,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
if field_value:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
@@ -669,17 +893,17 @@ def make_doc_string(
# Add description
if description:
desc_lines = description.strip().split("\n")
aligned_desc = "\n".join(" " + line for line in desc_lines)
aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines)
output += aligned_desc + "\n\n"
# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
output += components_str + "\n\n"
# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
output += configs_str + "\n\n"
# Add inputs section

View File

@@ -118,7 +118,40 @@ def get_timesteps(scheduler, num_inference_steps, strength):
# ====================
# auto_docstring
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
"""
Prepare initial random noise for the generation process
Components:
pachifier (`QwenImagePachifier`)
Inputs:
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.
Outputs:
height (`int`):
if not set, updated to default value
width (`int`):
if not set, updated to default value
latents (`Tensor`):
The initial latents to use for the denoising process
"""
model_name = "qwenimage"
@property
@@ -134,28 +167,20 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam(
name="batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam(
name="dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
InputParam.template("latents"),
InputParam.template("height"),
InputParam.template("width"),
InputParam.template("num_images_per_prompt"),
InputParam.template("generator"),
InputParam.template("batch_size"),
InputParam.template("dtype"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
@@ -209,7 +234,42 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
"""
Prepare initial random noise (B, layers+1, C, H, W) for the generation process
Components:
pachifier (`QwenImageLayeredPachifier`)
Inputs:
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.
Outputs:
height (`int`):
if not set, updated to default value
width (`int`):
if not set, updated to default value
latents (`Tensor`):
The initial latents to use for the denoising process
"""
model_name = "qwenimage-layered"
@property
@@ -225,29 +285,21 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="layers", default=4),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam(
name="batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam(
name="dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
InputParam.template("latents"),
InputParam.template("height"),
InputParam.template("width"),
InputParam.template("layers"),
InputParam.template("num_images_per_prompt"),
InputParam.template("generator"),
InputParam.template("batch_size"),
InputParam.template("dtype"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="height", type_hint=int, description="if not set, updated to default value"),
OutputParam(name="width", type_hint=int, description="if not set, updated to default value"),
OutputParam(
name="latents",
type_hint=torch.Tensor,
@@ -301,7 +353,31 @@ class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
"""
Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps,
prepare_latents. Both noise and image latents should alreadybe patchified.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
latents (`Tensor`):
The initial random noised, can be generated in prepare latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
generated from vae encoder and updated in input step.)
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
Outputs:
initial_noise (`Tensor`):
The initial random noised used for inpainting denoising.
latents (`Tensor`):
The scaled noisy latents to use for inpainting/image-to-image denoising.
"""
model_name = "qwenimage"
@property
@@ -323,12 +399,7 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The initial random noised, can be generated in prepare latent step.",
),
InputParam(
name="image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
),
InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."),
InputParam(
name="timesteps",
required=True,
@@ -345,6 +416,11 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The initial random noised used for inpainting denoising.",
),
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The scaled noisy latents to use for inpainting/image-to-image denoising.",
),
]
@staticmethod
@@ -382,7 +458,29 @@ class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
"""
Step that creates mask latents from preprocessed mask_image by interpolating to latent space.
Components:
pachifier (`QwenImagePachifier`)
Inputs:
processed_mask_image (`Tensor`):
The processed mask to use for the inpainting process.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.
Outputs:
mask (`Tensor`):
The mask to use for the inpainting process.
"""
model_name = "qwenimage"
@property
@@ -404,9 +502,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The processed mask to use for the inpainting process.",
),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="dtype", required=True),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("dtype"),
]
@property
@@ -450,7 +548,28 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
# ====================
# auto_docstring
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
"""
Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents
step.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
latents (`Tensor`):
The initial random noised latents for the denoising process. Can be generated in prepare latents step.
Outputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process
"""
model_name = "qwenimage"
@property
@@ -466,13 +585,13 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.",
),
]
@@ -516,7 +635,27 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
"""
Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process.
"""
model_name = "qwenimage-layered"
@property
@@ -532,15 +671,17 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50, type_hint=int),
InputParam("sigmas", type_hint=List[float]),
InputParam("image_latents", required=True, type_hint=torch.Tensor),
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam.template("image_latents"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="timesteps", type_hint=torch.Tensor),
OutputParam(
name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process."
),
]
@torch.no_grad()
@@ -574,7 +715,32 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
"""
Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after
prepare latents step.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
latents (`Tensor`):
The latents to use for the denoising process. Can be generated in prepare latents step.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
Outputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process.
num_inference_steps (`int`):
The number of denoising steps to perform at inference time. Updated based on strength.
"""
model_name = "qwenimage"
@property
@@ -590,15 +756,15 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam.template("num_inference_steps"),
InputParam.template("sigmas"),
InputParam(
name="latents",
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
description="The latents to use for the denoising process. Can be generated in prepare latents step.",
),
InputParam(name="strength", default=0.9),
InputParam.template("strength", default=0.9),
]
@property
@@ -607,7 +773,12 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
OutputParam(
name="timesteps",
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
description="The timesteps to use for the denoising process.",
),
OutputParam(
name="num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time. Updated based on strength.",
),
]
@@ -654,7 +825,33 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
## RoPE inputs for denoiser
# auto_docstring
class QwenImageRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step
Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`List`):
The shapes of the images latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
"""
model_name = "qwenimage"
@property
@@ -666,11 +863,11 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@property
@@ -702,7 +899,38 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after
prepare_latents step
Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`int`):
The height of the reference image. Can be generated in input step.
image_width (`int`):
The width of the reference image. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`List`):
The shapes of the images latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
"""
model_name = "qwenimage"
@property
@@ -712,13 +940,23 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True),
InputParam(name="image_width", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam(
name="image_height",
required=True,
type_hint=int,
description="The height of the reference image. Can be generated in input step.",
),
InputParam(
name="image_width",
required=True,
type_hint=int,
description="The width of the reference image. Can be generated in input step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@property
@@ -756,7 +994,39 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.
Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed
after prepare_latents step.
Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`List`):
The heights of the reference images. Can be generated in input step.
image_width (`List`):
The widths of the reference images. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
"""
model_name = "qwenimage-edit-plus"
@property
@@ -770,13 +1040,23 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True, type_hint=List[int]),
InputParam(name="image_width", required=True, type_hint=List[int]),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam(
name="image_height",
required=True,
type_hint=List[int],
description="The heights of the reference images. Can be generated in input step.",
),
InputParam(
name="image_width",
required=True,
type_hint=List[int],
description="The widths of the reference images. Can be generated in input step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@property
@@ -832,7 +1112,37 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
"""
Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step
Inputs:
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
additional_t_cond (`Tensor`):
The additional t cond, used for RoPE calculation
"""
model_name = "qwenimage-layered"
@property
@@ -844,12 +1154,12 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="layers", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
InputParam.template("batch_size"),
InputParam.template("layers"),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds_mask"),
]
@property
@@ -914,7 +1224,34 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
## ControlNet inputs for denoiser
# auto_docstring
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
"""
step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step.
Components:
controlnet (`QwenImageControlNetModel`)
Inputs:
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
Outputs:
controlnet_keep (`List`):
The controlnet keep values
"""
model_name = "qwenimage"
@property
@@ -930,12 +1267,17 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("control_image_latents", required=True),
InputParam.template("control_guidance_start"),
InputParam.template("control_guidance_end"),
InputParam.template("controlnet_conditioning_scale"),
InputParam(
"timesteps",
name="control_image_latents",
required=True,
type_hint=torch.Tensor,
description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.",
),
InputParam(
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",

View File

@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import Any, Dict, List
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
@@ -31,7 +29,30 @@ logger = logging.get_logger(__name__)
# after denoising loop (unpack latents)
# auto_docstring
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
"""
Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size,
channels, 1, height, width)
Components:
pachifier (`QwenImagePachifier`)
Inputs:
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
latents (`Tensor`):
The latents to decode, can be generated in the denoise step.
Outputs:
latents (`Tensor`):
The denoisedlatents unpacked to B, C, 1, H, W
"""
model_name = "qwenimage"
@property
@@ -49,13 +70,21 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step",
description="The latents to decode, can be generated in the denoise step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The denoisedlatents unpacked to B, C, 1, H, W"
),
]
@@ -72,7 +101,29 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
"""
Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising.
Components:
pachifier (`QwenImageLayeredPachifier`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
Outputs:
latents (`Tensor`):
Denoised latents. (unpacked to B, C, layers+1, H, W)
"""
model_name = "qwenimage-layered"
@property
@@ -88,10 +139,21 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("height", required=True, type_hint=int),
InputParam("width", required=True, type_hint=int),
InputParam("layers", required=True, type_hint=int),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step.",
),
InputParam.template("height", required=True),
InputParam.template("width", required=True),
InputParam.template("layers"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"),
]
@torch.no_grad()
@@ -112,7 +174,26 @@ class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
# decode step
# auto_docstring
class QwenImageDecoderStep(ModularPipelineBlocks):
"""
Step that decodes the latents to images
Components:
vae (`AutoencoderKLQwenImage`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
model_name = "qwenimage"
@property
@@ -134,19 +215,13 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to decode, can be generated in the denoise step",
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
)
]
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images", note="tensor output of the vae decoder.")]
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
@@ -176,7 +251,26 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
"""
Decode unpacked latents (B, C, layers+1, H, W) into layer images.
Components:
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage-layered"
@property
@@ -198,14 +292,19 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("output_type", default="pil", type_hint=str),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.",
),
InputParam.template("output_type"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
OutputParam.template("images"),
]
@torch.no_grad()
@@ -251,7 +350,27 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
# postprocess the decoded images
# auto_docstring
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
"""
postprocess the generated image
Components:
image_processor (`VaeImageProcessor`)
Inputs:
images (`Tensor`):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage"
@property
@@ -272,15 +391,19 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step",
),
InputParam.template("output_type"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images")]
@staticmethod
def check_inputs(output_type):
if output_type not in ["pil", "np", "pt"]:
@@ -301,7 +424,28 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
"""
postprocess the generated image, optional apply the mask overally to the original image..
Components:
image_mask_processor (`InpaintProcessor`)
Inputs:
images (`Tensor`):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage"
@property
@@ -322,16 +466,24 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("images", required=True, description="the generated image from decoders step"),
InputParam(
name="output_type",
default="pil",
type_hint=str,
description="The type of the output images, can be 'pil', 'np', 'pt'",
name="images",
required=True,
type_hint=torch.Tensor,
description="the generated image tensor from decoders step",
),
InputParam.template("output_type"),
InputParam(
name="mask_overlay_kwargs",
type_hint=Dict[str, Any],
description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep.",
),
InputParam("mask_overlay_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam.template("images")]
@staticmethod
def check_inputs(output_type, mask_overlay_kwargs):
if output_type not in ["pil", "np", "pt"]:

View File

@@ -50,7 +50,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
@@ -80,17 +80,12 @@ class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
name="latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
),
InputParam.template("image_latents"),
]
@torch.no_grad()
@@ -134,29 +129,12 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."),
InputParam(
"controlnet_conditioning_scale",
type_hint=float,
description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"controlnet_keep",
name="controlnet_keep",
required=True,
type_hint=List[float],
description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs for the denoiser. "
"It should contain prompt_embeds/negative_prompt_embeds."
),
description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.",
),
]
@@ -217,28 +195,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.template("attention_kwargs"),
InputParam.template("denoiser_input_fields"),
InputParam(
"img_shapes",
required=True,
type_hint=List[Tuple[int, int]],
description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.",
),
]
@@ -317,23 +280,8 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
InputParam.template("attention_kwargs"),
InputParam.template("denoiser_input_fields"),
InputParam(
"img_shapes",
required=True,
@@ -415,7 +363,7 @@ class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
OutputParam.template("latents"),
]
@torch.no_grad()
@@ -456,24 +404,19 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"image_latents",
required=True,
type_hint=torch.Tensor,
description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam.template("image_latents"),
InputParam(
"initial_noise",
required=True,
type_hint=torch.Tensor,
description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
),
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam.template("latents"),
]
@torch.no_grad()
@@ -515,17 +458,12 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
name="timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam.template("num_inference_steps", required=True),
]
@torch.no_grad()
@@ -557,7 +495,42 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
# Qwen Image (text2image, image2image)
# auto_docstring
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports text2image and image2image tasks for QwenImage.
Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
@@ -570,8 +543,8 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"Denoise step that iteratively denoise the latents.\n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageLoopBeforeDenoiser`\n"
" - `QwenImageLoopDenoiser`\n"
@@ -581,7 +554,47 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
# Qwen Image (inpainting)
# auto_docstring
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks for QwenImage.
Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -606,7 +619,47 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
# Qwen Image (text2image, image2image) with controlnet
# auto_docstring
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopBeforeDenoiserControlNet`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports text2img/img2img tasks with controlnet for QwenImage.
Components:
guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer
(`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
control_image_latents (`Tensor`):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -631,7 +684,54 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
# Qwen Image (inpainting) with controlnet
# auto_docstring
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageLoopBeforeDenoiser`
- `QwenImageLoopBeforeDenoiserControlNet`
- `QwenImageLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks with controlnet for QwenImage.
Components:
guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer
(`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
control_image_latents (`Tensor`):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
@@ -664,7 +764,42 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
# Qwen Image Edit (image2image)
# auto_docstring
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports QwenImage Edit.
Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
@@ -687,7 +822,47 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
# Qwen Image Edit (inpainting)
# auto_docstring
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
- `QwenImageLoopAfterDenoiserInpaint`
This block supports inpainting tasks for QwenImage Edit.
Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
initial_noise (`Tensor`):
The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
@@ -712,7 +887,42 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
# Qwen Image Layered (image2image)
# auto_docstring
class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
"""
Denoise step that iteratively denoise the latents.
Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks
defined in `sub_blocks` sequencially:
- `QwenImageEditLoopBeforeDenoiser`
- `QwenImageEditLoopDenoiser`
- `QwenImageLoopAfterDenoiser`
This block supports QwenImage Layered.
Components:
guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)
Inputs:
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
num_inference_steps (`int`):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-layered"
block_classes = [
QwenImageEditLoopBeforeDenoiser,

File diff suppressed because it is too large Load Diff

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
@@ -109,7 +109,44 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in
return height, width
# auto_docstring
class QwenImageTextInputsStep(ModularPipelineBlocks):
"""
Text input processing step that standardizes text embeddings for the pipeline.
This step:
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)
This block should be placed after all encoder steps to process the text embeddings before they are used in
subsequent pipeline steps.
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
"""
model_name = "qwenimage"
@property
@@ -129,26 +166,22 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
InputParam.template("num_images_per_prompt"),
InputParam.template("prompt_embeds"),
InputParam.template("prompt_embeds_mask"),
InputParam.template("negative_prompt_embeds"),
InputParam.template("negative_prompt_embeds_mask"),
]
@property
def intermediate_outputs(self) -> List[str]:
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"),
OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"),
OutputParam.template("prompt_embeds", note="batch-expanded"),
OutputParam.template("prompt_embeds_mask", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds", note="batch-expanded"),
OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"),
]
@staticmethod
@@ -221,20 +254,76 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage: update height/width, expand batch, patchify."""
"""
Input processing step that:
1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size
2. For additional batch inputs: Expands batch dimensions to match final batch size
Configured inputs:
- Image latent inputs: ['image_latents']
This block should be placed after the encoder steps and the text input step.
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
"""
model_name = "qwenimage"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
image_latent_inputs: Optional[List[InputParam]] = None,
additional_batch_inputs: Optional[List[InputParam]] = None,
):
# by default, process `image_latents`
if image_latent_inputs is None:
image_latent_inputs = [InputParam.template("image_latents")]
if additional_batch_inputs is None:
additional_batch_inputs = []
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
else:
for input_param in image_latent_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
else:
for input_param in additional_batch_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
)
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
@@ -252,9 +341,9 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
@@ -269,23 +358,19 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
InputParam.template("num_images_per_prompt"),
InputParam.template("batch_size"),
InputParam.template("height"),
InputParam.template("width"),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
# default is `image_latents`
inputs += self._image_latent_inputs + self._additional_batch_inputs
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
outputs = [
OutputParam(
name="image_height",
type_hint=int,
@@ -298,11 +383,43 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
),
]
# `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided
if len(self._image_latent_inputs) > 0:
outputs.append(
OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
)
outputs.append(
OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
)
# image latent inputs are modified in place (patchified and batch-expanded)
for input_param in self._image_latent_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (patchified and batch-expanded)",
)
)
# additional batch inputs (batch-expanded only)
for input_param in self._additional_batch_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (batch-expanded)",
)
)
return outputs
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
for input_param in self._image_latent_inputs:
image_latent_input_name = input_param.name
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
@@ -331,7 +448,8 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
for input_param in self._additional_batch_inputs:
input_name = input_param.name
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
@@ -349,20 +467,76 @@ class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
"""
Input processing step for Edit Plus that:
1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch
2. For additional batch inputs: Expands batch dimensions to match final batch size
Height/width defaults to last image in the list.
Configured inputs:
- Image latent inputs: ['image_latents']
This block should be placed after the encoder steps and the text input step.
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
image_height (`List`):
The image heights calculated from the image latents dimension
image_width (`List`):
The image widths calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
concatenated, and batch-expanded)
"""
model_name = "qwenimage-edit-plus"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
image_latent_inputs: Optional[List[InputParam]] = None,
additional_batch_inputs: Optional[List[InputParam]] = None,
):
if image_latent_inputs is None:
image_latent_inputs = [InputParam.template("image_latents")]
if additional_batch_inputs is None:
additional_batch_inputs = []
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
else:
for input_param in image_latent_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
else:
for input_param in additional_batch_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
)
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
@@ -381,9 +555,9 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
@@ -398,23 +572,20 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
InputParam.template("num_images_per_prompt"),
InputParam.template("batch_size"),
InputParam.template("height"),
InputParam.template("width"),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
# default is `image_latents`
inputs += self._image_latent_inputs + self._additional_batch_inputs
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
outputs = [
OutputParam(
name="image_height",
type_hint=List[int],
@@ -427,11 +598,43 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
),
]
# `height`/`width` are updated if any image latent inputs are provided
if len(self._image_latent_inputs) > 0:
outputs.append(
OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
)
outputs.append(
OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
)
# image latent inputs are modified in place (patchified, concatenated, and batch-expanded)
for input_param in self._image_latent_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (patchified, concatenated, and batch-expanded)",
)
)
# additional batch inputs (batch-expanded only)
for input_param in self._additional_batch_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (batch-expanded)",
)
)
return outputs
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
for input_param in self._image_latent_inputs:
image_latent_input_name = input_param.name
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
@@ -476,7 +679,8 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
for input_param in self._additional_batch_inputs:
input_name = input_param.name
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
@@ -494,22 +698,75 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
return components, state
# YiYi TODO: support define config default component from the ModularPipeline level.
# it is same as QwenImageAdditionalInputsStep, but with layered pachifier.
# same as QwenImageAdditionalInputsStep, but with layered pachifier.
# auto_docstring
class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
"""
Input processing step for Layered that:
1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch
size
2. For additional batch inputs: Expands batch dimensions to match final batch size
Configured inputs:
- Image latent inputs: ['image_latents']
This block should be placed after the encoder steps and the text input step.
Components:
pachifier (`QwenImageLayeredPachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
with layered pachifier and batch-expanded)
"""
model_name = "qwenimage-layered"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
image_latent_inputs: Optional[List[InputParam]] = None,
additional_batch_inputs: Optional[List[InputParam]] = None,
):
if image_latent_inputs is None:
image_latent_inputs = [InputParam.template("image_latents")]
if additional_batch_inputs is None:
additional_batch_inputs = []
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
else:
for input_param in image_latent_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
else:
for input_param in additional_batch_inputs:
if not isinstance(input_param, InputParam):
raise ValueError(
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
)
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
@@ -527,9 +784,9 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
@@ -544,21 +801,18 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam.template("num_images_per_prompt"),
InputParam.template("batch_size"),
]
# default is `image_latents`
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
inputs += self._image_latent_inputs + self._additional_batch_inputs
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
outputs = [
OutputParam(
name="image_height",
type_hint=int,
@@ -569,15 +823,44 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
type_hint=int,
description="The image width calculated from the image latents dimension",
),
OutputParam(name="height", type_hint=int, description="The height of the image output"),
OutputParam(name="width", type_hint=int, description="The width of the image output"),
]
if len(self._image_latent_inputs) > 0:
outputs.append(
OutputParam(name="height", type_hint=int, description="if not provided, updated to image height")
)
outputs.append(
OutputParam(name="width", type_hint=int, description="if not provided, updated to image width")
)
# Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded)
for input_param in self._image_latent_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (patchified with layered pachifier and batch-expanded)",
)
)
# Add outputs for additional batch inputs (batch-expanded only)
for input_param in self._additional_batch_inputs:
outputs.append(
OutputParam(
name=input_param.name,
type_hint=input_param.type_hint,
description=input_param.description + " (batch-expanded)",
)
)
return outputs
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
for input_param in self._image_latent_inputs:
image_latent_input_name = input_param.name
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
@@ -608,7 +891,8 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
for input_param in self._additional_batch_inputs:
input_name = input_param.name
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
@@ -626,7 +910,34 @@ class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
return components, state
# auto_docstring
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
"""
prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps.
Inputs:
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
Outputs:
control_image_latents (`Tensor`):
The control image latents (patchified and batch-expanded).
height (`int`):
if not provided, updated to control image height
width (`int`):
if not provided, updated to control image width
"""
model_name = "qwenimage"
@property
@@ -636,11 +947,28 @@ class QwenImageControlNetInputsStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="control_image_latents", required=True),
InputParam(name="batch_size", required=True),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="height"),
InputParam(name="width"),
InputParam(
name="control_image_latents",
required=True,
type_hint=torch.Tensor,
description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.",
),
InputParam.template("batch_size"),
InputParam.template("num_images_per_prompt"),
InputParam.template("height"),
InputParam.template("width"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="control_image_latents",
type_hint=torch.Tensor,
description="The control image latents (patchified and batch-expanded).",
),
OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"),
OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"),
]
@torch.no_grad()

View File

@@ -12,14 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
@@ -59,11 +56,91 @@ logger = logging.get_logger(__name__)
# ====================
# 1. VAE ENCODER
# 1. TEXT ENCODER
# ====================
# auto_docstring
class QwenImageAutoTextEncoderStep(AutoPipelineBlocks):
"""
Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block.
Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
The tokenizer to use guider (`ClassifierFreeGuidance`)
Inputs:
prompt (`str`, *optional*):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 1024):
Maximum sequence length for prompt encoding.
Outputs:
prompt_embeds (`Tensor`):
The prompt embeddings.
prompt_embeds_mask (`Tensor`):
The encoder attention mask.
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings.
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask.
"""
model_name = "qwenimage"
block_classes = [QwenImageTextEncoderStep()]
block_names = ["text_encoder"]
block_trigger_inputs = ["prompt"]
@property
def description(self) -> str:
return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block."
" - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided."
" - if `prompt` is not provided, step will be skipped."
# ====================
# 2. VAE ENCODER
# ====================
# auto_docstring
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
"""
This step is used for processing image and mask inputs for inpainting tasks. It:
- Resizes the image to the target size, based on `height` and `width`.
- Processes and updates `image` and `mask_image`.
- Creates `image_latents`.
Components:
image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`)
Inputs:
mask_image (`Image`):
Mask image for inpainting.
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*):
Padding for mask cropping in inpainting.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
processed_image (`Tensor`):
The processed image
processed_mask_image (`Tensor`):
The processed mask image
mask_overlay_kwargs (`Dict`):
The kwargs for the postprocess step to apply the mask overlay
image_latents (`Tensor`):
The latent representation of the input image.
"""
model_name = "qwenimage"
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
block_names = ["preprocess", "encode"]
@@ -78,7 +155,31 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
)
# auto_docstring
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
"""
Vae encoder step that preprocess andencode the image inputs into their latent representations.
Components:
image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
processed_image (`Tensor`):
The processed image
image_latents (`Tensor`):
The latent representation of the input image.
"""
model_name = "qwenimage"
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
@@ -89,7 +190,6 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
# Auto VAE encoder
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
block_names = ["inpaint", "img2img"]
@@ -107,7 +207,33 @@ class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
# optional controlnet vae encoder
# auto_docstring
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
"""
Vae encoder step that encode the image inputs into their latent representations.
This is an auto pipeline block.
- `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.
- if `control_image` is not provided, step will be skipped.
Components:
vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor
(`VaeImageProcessor`)
Inputs:
control_image (`Image`, *optional*):
Control image for ControlNet conditioning.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
control_image_latents (`Tensor`):
The latents representing the control image
"""
block_classes = [QwenImageControlNetVaeEncoderStep]
block_names = ["controlnet"]
block_trigger_inputs = ["control_image"]
@@ -123,14 +249,65 @@ class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
# ====================
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================
# assemble input steps
# auto_docstring
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the img2img denoising step. It:
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
"""
model_name = "qwenimage"
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()]
block_names = ["text_inputs", "additional_inputs"]
@property
@@ -140,12 +317,69 @@ class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
" - update height/width based `image_latents`, patchify `image_latents`."
# auto_docstring
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the inpainting denoising step. It:
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`, *optional*):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
processed_mask_image (`Tensor`):
The processed mask image (batch-expanded)
"""
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
additional_batch_inputs=[
InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
]
),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -158,7 +392,42 @@ class QwenImageInpaintInputStep(SequentialPipelineBlocks):
# assemble prepare latents steps
# auto_docstring
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
"""
This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:
- Add noise to the image latents to create the latents input for the denoiser.
- Create the pachified latents `mask` based on the processedmask image.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)
Inputs:
latents (`Tensor`):
The initial random noised, can be generated in prepare latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
generated from vae encoder and updated in input step.)
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
processed_mask_image (`Tensor`):
The processed mask to use for the inpainting process.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.
Outputs:
initial_noise (`Tensor`):
The initial random noised used for inpainting denoising.
latents (`Tensor`):
The scaled noisy latents to use for inpainting/image-to-image denoising.
mask (`Tensor`):
The mask to use for the inpainting process.
"""
model_name = "qwenimage"
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
block_names = ["add_noise_to_latents", "create_mask_latents"]
@@ -176,7 +445,49 @@ class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
# Qwen Image (text2image)
# auto_docstring
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
"""
step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs
(timesteps, latents, rope inputs etc.).
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
@@ -199,9 +510,63 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Qwen Image (inpainting)
# auto_docstring
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint
task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`, *optional*):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageInpaintInputStep(),
@@ -226,9 +591,61 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Qwen Image (image2image)
# auto_docstring
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img
task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageImg2ImgInputStep(),
@@ -253,9 +670,66 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Qwen Image (text2image) with controlnet
# auto_docstring
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
"""
step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs
(timesteps, latents, rope inputs etc.).
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
@@ -282,9 +756,72 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Qwen Image (inpainting) with controlnet
# auto_docstring
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint
task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`, *optional*):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageInpaintInputStep(),
@@ -313,9 +850,70 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Qwen Image (image2image) with controlnet
# auto_docstring
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img
task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet
(`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
control_image_latents (`Tensor`):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage"
block_classes = [
QwenImageImg2ImgInputStep(),
@@ -344,6 +942,12 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Auto denoise step for QwenImage
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@@ -402,19 +1006,36 @@ class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
),
OutputParam.template("latents"),
]
# ====================
# 3. DECODE
# 4. DECODE
# ====================
# standard decode step works for most tasks except for inpaint
# auto_docstring
class QwenImageDecodeStep(SequentialPipelineBlocks):
"""
Decode step that decodes the latents to images and postprocess the generated image.
Components:
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
model_name = "qwenimage"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@@ -425,7 +1046,30 @@ class QwenImageDecodeStep(SequentialPipelineBlocks):
# Inpaint decode step
# auto_docstring
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
"""
Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask
overally to the original image.
Components:
vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
model_name = "qwenimage"
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@@ -452,11 +1096,11 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):
# ====================
# 4. AUTO BLOCKS & PRESETS
# 5. AUTO BLOCKS & PRESETS
# ====================
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("text_encoder", QwenImageAutoTextEncoderStep()),
("vae_encoder", QwenImageAutoVaeEncoderStep()),
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
("denoise", QwenImageAutoCoreDenoiseStep()),
@@ -465,7 +1109,89 @@ AUTO_BLOCKS = InsertableDict(
)
# auto_docstring
class QwenImageAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
- for image-to-image generation, you need to provide `image`
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
- to run the controlnet workflow, you need to provide `control_image`
- for text-to-image generation, all you need to provide is `prompt`
Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
(`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`)
control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
(`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
Inputs:
prompt (`str`, *optional*):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 1024):
Maximum sequence length for prompt encoding.
mask_image (`Image`, *optional*):
Mask image for inpainting.
image (`Union[Image, List]`, *optional*):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*):
Padding for mask cropping in inpainting.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
control_image (`Image`, *optional*):
Control image for ControlNet conditioning.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
latents (`Tensor`):
Pre-generated noisy latents for image generation.
num_inference_steps (`int`):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
image_latents (`Tensor`, *optional*):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
control_image_latents (`Tensor`, *optional*):
The control image latents to use for the denoising process. Can be generated in controlnet vae encoder
step.
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
control_guidance_end (`float`, *optional*, defaults to 1.0):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage"
block_classes = AUTO_BLOCKS.values()
@@ -476,7 +1202,7 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
@@ -484,5 +1210,5 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
OutputParam.template("images"),
]

View File

@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from typing import Optional
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
from .before_denoise import (
QwenImageCreateMaskLatentsStep,
QwenImageEditRoPEInputsStep,
@@ -59,8 +58,35 @@ logger = logging.get_logger(__name__)
# ====================
# auto_docstring
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
"""VL encoder that takes both image and text prompts."""
"""
QwenImage-Edit VL encoder step that encode the image and text prompts together.
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
Outputs:
resized_image (`List`):
The resized images
prompt_embeds (`Tensor`):
The prompt embeddings.
prompt_embeds_mask (`Tensor`):
The encoder attention mask.
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings.
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask.
"""
model_name = "qwenimage-edit"
block_classes = [
@@ -80,7 +106,30 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
# Edit VAE encoder
# auto_docstring
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
"""
Vae encoder step that encode the image inputs into their latent representations.
Components:
image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
(`AutoencoderKLQwenImage`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
resized_image (`List`):
The resized images
processed_image (`Tensor`):
The processed image
image_latents (`Tensor`):
The latent representation of the input image.
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeStep(),
@@ -95,12 +144,46 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
# Edit Inpaint VAE encoder
# auto_docstring
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
"""
This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:
- resize the image for target area (1024 * 1024) while maintaining the aspect ratio.
- process the resized image and mask image.
- create image latents.
Components:
image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae
(`AutoencoderKLQwenImage`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
mask_image (`Image`):
Mask image for inpainting.
padding_mask_crop (`int`, *optional*):
Padding for mask cropping in inpainting.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
resized_image (`List`):
The resized images
processed_image (`Tensor`):
The processed image
processed_mask_image (`Tensor`):
The processed mask image
mask_overlay_kwargs (`Dict`):
The kwargs for the postprocess step to apply the mask overlay
image_latents (`Tensor`):
The latent representation of the input image.
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeStep(),
QwenImageEditInpaintProcessImagesInputStep(),
QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
QwenImageVaeEncoderStep(),
]
block_names = ["resize", "preprocess", "encode"]
@@ -137,11 +220,64 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
# assemble input steps
# auto_docstring
class QwenImageEditInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the edit denoising step. It:
- make sure the text embeddings have consistent batch size as well as the additional inputs.
- update height/width based `image_latents`, patchify `image_latents`.
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
QwenImageAdditionalInputsStep(),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -154,12 +290,71 @@ class QwenImageEditInputStep(SequentialPipelineBlocks):
)
# auto_docstring
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the edit inpaint denoising step. It:
- make sure the text embeddings have consistent batch size as well as the additional inputs.
- update height/width based `image_latents`, patchify `image_latents`.
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and
batch-expanded)
processed_mask_image (`Tensor`):
The processed mask image (batch-expanded)
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
additional_batch_inputs=[
InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image")
]
),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -174,7 +369,42 @@ class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
# assemble prepare latents steps
# auto_docstring
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
"""
This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:
- Add noise to the image latents to create the latents input for the denoiser.
- Create the patchified latents `mask` based on the processed mask image.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)
Inputs:
latents (`Tensor`):
The initial random noised, can be generated in prepare latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
generated from vae encoder and updated in input step.)
timesteps (`Tensor`):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
processed_mask_image (`Tensor`):
The processed mask to use for the inpainting process.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
dtype (`dtype`, *optional*, defaults to torch.float32):
The dtype of the model inputs, can be generated in input step.
Outputs:
initial_noise (`Tensor`):
The initial random noised used for inpainting denoising.
latents (`Tensor`):
The scaled noisy latents to use for inpainting/image-to-image denoising.
mask (`Tensor`):
The mask to use for the inpainting process.
"""
model_name = "qwenimage-edit"
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
block_names = ["add_noise_to_latents", "create_mask_latents"]
@@ -189,7 +419,50 @@ class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
# Qwen Image Edit (image2image) core denoise step
# auto_docstring
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoising workflow for QwenImage-Edit edit (img2img) task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInputStep(),
@@ -212,9 +485,62 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Qwen Image Edit (inpainting) core denoise step
# auto_docstring
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoising workflow for QwenImage-Edit edit inpaint task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInpaintInputStep(),
@@ -239,6 +565,12 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
def description(self):
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Auto core denoise step for QwenImage Edit
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
@@ -267,6 +599,12 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
)
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# ====================
# 4. DECODE
@@ -274,7 +612,26 @@ class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
# Decode step (standard)
# auto_docstring
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
"""
Decode step that decodes the latents to images and postprocess the generated image.
Components:
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
model_name = "qwenimage-edit"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@@ -285,7 +642,30 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks):
# Inpaint decode step
# auto_docstring
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
"""
Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask
overlay to the original image.
Components:
vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
model_name = "qwenimage-edit"
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@@ -313,9 +693,7 @@ class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
),
OutputParam.template("latents"),
]
@@ -333,7 +711,66 @@ EDIT_AUTO_BLOCKS = InsertableDict(
)
# auto_docstring
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.
- for edit (img2img) generation, you need to provide `image`
- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide
`padding_mask_crop`
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
(`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
(`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
mask_image (`Image`, *optional*):
Mask image for inpainting.
padding_mask_crop (`int`, *optional*):
Padding for mask cropping in inpainting.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
processed_mask_image (`Tensor`, *optional*):
The processed mask image
latents (`Tensor`):
Pre-generated noisy latents for image generation.
num_inference_steps (`int`):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage-edit"
block_classes = EDIT_AUTO_BLOCKS.values()
block_names = EDIT_AUTO_BLOCKS.keys()
@@ -349,5 +786,5 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
OutputParam.template("images"),
]

View File

@@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -53,12 +48,41 @@ logger = logging.get_logger(__name__)
# ====================
# auto_docstring
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
"""
QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together.
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
Outputs:
resized_image (`List`):
Images resized to 1024x1024 target area for VAE encoding
resized_cond_image (`List`):
Images resized to 384x384 target area for VL text encoding
prompt_embeds (`Tensor`):
The prompt embeddings.
prompt_embeds_mask (`Tensor`):
The encoder attention mask.
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings.
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask.
"""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
QwenImageEditPlusResizeStep(),
QwenImageEditPlusTextEncoderStep(),
]
block_names = ["resize", "encode"]
@@ -73,12 +97,36 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
# ====================
# auto_docstring
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
"""
VAE encoder step that encodes image inputs into latent representations.
Each image is resized independently based on its own aspect ratio to 1024x1024 target area.
Components:
image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
(`AutoencoderKLQwenImage`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
resized_image (`List`):
Images resized to 1024x1024 target area for VAE encoding
resized_cond_image (`List`):
Images resized to 384x384 target area for VL text encoding
processed_image (`Tensor`):
The processed image
image_latents (`Tensor`):
The latent representation of the input image.
"""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
QwenImageEditPlusResizeStep(),
QwenImageEditPlusProcessImagesInputStep(),
QwenImageVaeEncoderStep(),
]
@@ -98,11 +146,66 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
# assemble input steps
# auto_docstring
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the Edit Plus denoising step. It:
- Standardizes text embeddings batch size.
- Processes list of image latents: patchifies, concatenates along dim=1, expands batch.
- Outputs lists of image_height/image_width for RoPE calculation.
- Defaults height/width from last image in the list.
Components:
pachifier (`QwenImagePachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`List`):
The image heights calculated from the image latents dimension
image_width (`List`):
The image widths calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified,
concatenated, and batch-expanded)
"""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageTextInputsStep(),
QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
QwenImageEditPlusAdditionalInputsStep(),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -118,7 +221,50 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
# Qwen Image Edit Plus (image2image) core denoise step
# auto_docstring
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoising workflow for QwenImage-Edit Plus edit (img2img) task.
Components:
pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusInputStep(),
@@ -144,9 +290,7 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
),
OutputParam.template("latents"),
]
@@ -155,7 +299,26 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
# ====================
# auto_docstring
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
"""
Decode step that decodes the latents to images and postprocesses the generated image.
Components:
vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`)
Inputs:
latents (`Tensor`):
The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
model_name = "qwenimage-edit-plus"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@@ -179,7 +342,53 @@ EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
)
# auto_docstring
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.
- `image` is required input (can be single image or list of images).
- Each image is resized independently based on its own aspect ratio.
- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area.
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae
(`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`)
transformer (`QwenImageTransformer2DModel`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage-edit-plus"
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
@@ -196,5 +405,5 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
OutputParam.template("images"),
]

View File

@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
@@ -55,8 +49,44 @@ logger = logging.get_logger(__name__)
# ====================
# auto_docstring
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
"""Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
"""
QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not
provided.
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
prompt (`str`, *optional*):
The prompt or prompts to guide image generation.
use_en_prompt (`bool`, *optional*, defaults to False):
Whether to use English prompt template
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 1024):
Maximum sequence length for prompt encoding.
Outputs:
resized_image (`List`):
The resized images
prompt (`str`):
The prompt or prompts to guide image generation. If not provided, updated using image caption
prompt_embeds (`Tensor`):
The prompt embeddings.
prompt_embeds_mask (`Tensor`):
The encoder attention mask.
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings.
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask.
"""
model_name = "qwenimage-layered"
block_classes = [
@@ -77,7 +107,32 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
# Edit VAE encoder
# auto_docstring
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
"""
Vae encoder step that encode the image inputs into their latent representations.
Components:
image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae
(`AutoencoderKLQwenImage`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
resized_image (`List`):
The resized images
processed_image (`Tensor`):
The processed image
image_latents (`Tensor`):
The latent representation of the input image.
"""
model_name = "qwenimage-layered"
block_classes = [
QwenImageLayeredResizeStep(),
@@ -98,11 +153,60 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
# assemble input steps
# auto_docstring
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the layered denoising step. It:
- make sure the text embeddings have consistent batch size as well as the additional inputs.
- update height/width based `image_latents`, patchify `image_latents`.
Components:
pachifier (`QwenImageLayeredPachifier`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
batch_size (`int`):
The batch size of the prompt embeddings
dtype (`dtype`):
The data type of the prompt embeddings
prompt_embeds (`Tensor`):
The prompt embeddings. (batch-expanded)
prompt_embeds_mask (`Tensor`):
The encoder attention mask. (batch-expanded)
negative_prompt_embeds (`Tensor`):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`int`):
The image height calculated from the image latents dimension
image_width (`int`):
The image width calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
width (`int`):
if not provided, updated to image width
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified
with layered pachifier and batch-expanded)
"""
model_name = "qwenimage-layered"
block_classes = [
QwenImageTextInputsStep(),
QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
QwenImageLayeredAdditionalInputsStep(),
]
block_names = ["text_inputs", "additional_inputs"]
@@ -116,7 +220,48 @@ class QwenImageLayeredInputStep(SequentialPipelineBlocks):
# Qwen Image Layered (image2image) core denoise step
# auto_docstring
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoising workflow for QwenImage-Layered img2img task.
Components:
pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
Inputs:
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_embeds_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_embeds_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "qwenimage-layered"
block_classes = [
QwenImageLayeredInputStep(),
@@ -142,9 +287,7 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(
name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step"
),
OutputParam.template("latents"),
]
@@ -162,7 +305,54 @@ LAYERED_AUTO_BLOCKS = InsertableDict(
)
# auto_docstring
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for layered denoising tasks using QwenImage-Layered.
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`)
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
Inputs:
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
prompt (`str`, *optional*):
The prompt or prompts to guide image generation.
use_en_prompt (`bool`, *optional*, defaults to False):
Whether to use English prompt template
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 1024):
Maximum sequence length for prompt encoding.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
layers (`int`, *optional*, defaults to 4):
Number of layers to extract from the image
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`List`):
Generated images.
"""
model_name = "qwenimage-layered"
block_classes = LAYERED_AUTO_BLOCKS.values()
block_names = LAYERED_AUTO_BLOCKS.keys()
@@ -174,5 +364,5 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
@property
def outputs(self):
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"),
OutputParam.template("images"),
]

View File

@@ -131,7 +131,7 @@ class ZImageLoopDenoiser(ModularPipelineBlocks):
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
]
guider_input_names = []

View File

@@ -84,7 +84,6 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -53,7 +53,6 @@ EXAMPLE_DOC_STRING = """
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",

View File

@@ -85,7 +85,6 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL
>>> from diffusers.utils import load_image
>>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
>>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
>>> controlnet = ControlNetModel.from_pretrained(

View File

@@ -459,7 +459,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
>>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
>>> import torch
>>> pipeline = StableDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
... )

View File

@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -36,27 +36,30 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
sigma_min (`float`, *optional*, defaults to 0.3):
sigma_min (`float`, defaults to `0.3`):
Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
sigma_max (`float`, *optional*, defaults to 500):
sigma_max (`float`, defaults to `500`):
Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
sigma_data (`float`, *optional*, defaults to 1.0):
sigma_data (`float`, defaults to `1.0`):
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
sigma_schedule (`str`, *optional*, defaults to `exponential`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
sigma_schedule (`str`, defaults to `"exponential"`):
Sigma schedule to compute the `sigmas`. Must be one of `"exponential"` or `"karras"`. The exponential
schedule was incorporated in [stabilityai/cosxl](https://huggingface.co/stabilityai/cosxl). The Karras
schedule is introduced in the [EDM](https://huggingface.co/papers/2206.00364) paper.
num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
solver_order (`int`, defaults to `2`):
The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
prediction_type (`str`, defaults to `v_prediction`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
prediction_type (`str`, defaults to `"v_prediction"`):
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
rho (`float`, defaults to `7.0`):
The parameter for calculating the Karras sigma schedule from the EDM
[paper](https://huggingface.co/papers/2206.00364).
solver_type (`str`, defaults to `"midpoint"`):
Solver type for the second-order solver. Must be one of `"midpoint"` or `"heun"`. The solver type slightly
affects the sample quality, especially for a small number of steps. It is recommended to use `"midpoint"`.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -65,8 +68,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
The final `sigma` value for the noise schedule during the sampling process. Must be one of `"zero"` or
`"sigma_min"`. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If
`"zero"`, the final sigma is set to 0.
"""
_compatibles = []
@@ -78,16 +82,16 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.3,
sigma_max: float = 500,
sigma_data: float = 1.0,
sigma_schedule: str = "exponential",
sigma_schedule: Literal["exponential", "karras"] = "exponential",
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "v_prediction",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "v_prediction",
rho: float = 7.0,
solver_type: str = "midpoint",
solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
):
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
) -> None:
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
@@ -113,26 +117,40 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
def init_noise_sigma(self) -> float:
"""
The standard deviation of the initial noise distribution.
Returns:
`float`:
The initial noise sigma value computed as `sqrt(sigma_max^2 + 1)`.
"""
return (self.config.sigma_max**2 + 1) ** 0.5
@property
def step_index(self):
def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
Returns:
`int` or `None`:
The current step index, or `None` if not yet initialized.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
Returns:
`int` or `None`:
The begin index, or `None` if not yet set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -161,7 +179,18 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
scaled_sample = sample * c_in
return scaled_sample
def precondition_noise(self, sigma):
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the noise level by computing a normalized timestep representation.
Args:
sigma (`float` or `torch.Tensor`):
The sigma (noise level) value to precondition.
Returns:
`torch.Tensor`:
The preconditioned noise value computed as `atan(sigma) / pi * 2`.
"""
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
@@ -228,12 +257,14 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
def set_timesteps(
self, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
@@ -334,7 +365,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -370,7 +401,19 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
t = t.reshape(sigma.shape)
return t
def _sigma_to_alpha_sigma_t(self, sigma):
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert sigma to alpha and sigma_t values for the diffusion process.
Args:
sigma (`torch.Tensor`):
The sigma (noise level) value.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing `alpha_t` (always 1 since inputs are pre-scaled) and `sigma_t` (same as input
sigma).
"""
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
@@ -536,7 +579,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
@@ -557,7 +600,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -567,20 +610,19 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -702,5 +744,12 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
def __len__(self) -> int:
"""
Returns the number of training timesteps.
Returns:
`int`:
The number of training timesteps configured for the scheduler.
"""
return self.config.num_train_timesteps

View File

@@ -32,22 +32,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "quantization: marks tests for quantization functionality")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
config.addinivalue_line("markers", "slow: mark test as slow")
config.addinivalue_line("markers", "nightly: mark test as nightly")

View File

@@ -1,79 +0,0 @@
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
FasterCacheTesterMixin,
FirstBlockCacheConfigMixin,
FirstBlockCacheTesterMixin,
PyramidAttentionBroadcastConfigMixin,
PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFConfigMixin,
GGUFTesterMixin,
ModelOptCompileTesterMixin,
ModelOptConfigMixin,
ModelOptTesterMixin,
QuantizationCompileTesterMixin,
QuantizationTesterMixin,
QuantoCompileTesterMixin,
QuantoConfigMixin,
QuantoTesterMixin,
TorchAoCompileTesterMixin,
TorchAoConfigMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
"BitsAndBytesConfigMixin",
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptCompileTesterMixin",
"ModelOptConfigMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",
"QuantoConfigMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoCompileTesterMixin",
"TorchAoConfigMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]

View File

@@ -1,181 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_attention,
torch_device,
)
@is_attention
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
model.fuse_qkv_projections()
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break
if has_fused_projections:
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=atol,
rtol=rtol,
msg="Output should not change after fusing projections",
)
model.unfuse_qkv_projections()
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=atol,
rtol=rtol,
msg="Output should match original after unfusing projections",
)
def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")
# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"
# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"
# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
# Set processors using dict
model.set_attn_processor(new_processors)
# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

View File

@@ -1,556 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
def require_cache_mixin(func):
"""Decorator to skip tests if model doesn't use CacheMixin."""
def wrapper(self, *args, **kwargs):
if not issubclass(self.model_class, CacheMixin):
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
return func(self, *args, **kwargs)
return wrapper
class CacheTesterMixin:
"""
Base mixin class providing common test implementations for cache testing.
Cache-specific mixins should:
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
2. Inherit from this mixin
3. Define the cache config to use for tests
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods in test classes:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional overrides:
- cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
"""
@property
def cache_input_key(self):
return "hidden_states"
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def _get_cache_config(self):
"""
Get the cache config for testing.
Should be implemented by subclasses.
"""
raise NotImplementedError("Subclass must implement _get_cache_config")
def _get_hook_names(self):
"""
Get the hook names to check for this cache type.
Should be implemented by subclasses.
Returns a list of hook name strings.
"""
raise NotImplementedError("Subclass must implement _get_hook_names")
def _test_cache_enable_disable_state(self):
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Initially cache should not be enabled
assert not model.is_cache_enabled, "Cache should not be enabled initially."
config = self._get_cache_config()
# Enable cache
model.enable_cache(config)
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
# Disable cache
model.disable_cache()
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
def _test_cache_double_enable_raises_error(self):
"""Test that enabling cache twice raises an error."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
model.enable_cache(config)
# Trying to enable again should raise ValueError
with pytest.raises(ValueError, match="Caching has already been enabled"):
model.enable_cache(config)
# Cleanup
model.disable_cache()
def _test_cache_hooks_registered(self):
"""Test that cache hooks are properly registered and removed."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
hook_names = self._get_hook_names()
model.enable_cache(config)
# Check that at least one hook was registered
hook_count = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count += 1
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
# Disable and verify hooks are removed
model.disable_cache()
hook_count_after = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count_after += 1
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with cache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass (vary input tensor to simulate denoising)
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different inputs (produces approximated output)
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_cache_context_manager(self, atol=1e-5, rtol=0):
"""Test the cache_context context manager properly isolates cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# Run inference in first context
with model.cache_context("context_1"):
output_ctx1 = model(**inputs_dict, return_dict=False)[0]
# Run same inference in second context (cache should be reset)
with model.cache_context("context_2"):
output_ctx2 = model(**inputs_dict, return_dict=False)[0]
# Both contexts should produce the same output (first pass in each)
assert_tensors_close(
output_ctx1,
output_ctx2,
atol=atol,
rtol=rtol,
msg="First pass in different cache contexts should produce the same output.",
)
model.disable_cache()
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@is_cache
class PyramidAttentionBroadcastConfigMixin:
"""
Base mixin providing PyramidAttentionBroadcast cache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default PAB config - can be overridden by subclasses
PAB_CONFIG = {
"spatial_attention_block_skip_range": 2,
}
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
_current_timestep = 500
def _get_cache_config(self):
config_kwargs = self.PAB_CONFIG.copy()
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
return PyramidAttentionBroadcastConfig(**config_kwargs)
def _get_hook_names(self):
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
@is_cache
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
"""
Mixin class for testing PyramidAttentionBroadcast caching on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@require_cache_mixin
def test_pab_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_pab_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_pab_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_pab_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_pab_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_pab_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FirstBlockCacheConfigMixin:
"""
Base mixin providing FirstBlockCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FBC config - can be overridden by subclasses
# Higher threshold makes FBC more aggressive about caching (skips more often)
FBC_CONFIG = {
"threshold": 1.0,
}
def _get_cache_config(self):
return FirstBlockCacheConfig(**self.FBC_CONFIG)
def _get_hook_names(self):
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
@is_cache
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FirstBlockCache on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# FBC requires cache_context to be set for inference
with model.cache_context("fbc_test"):
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass - FBC should skip remaining blocks and use cached residuals
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
with model.cache_context("fbc_test"):
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_fbc_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_fbc_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_fbc_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_fbc_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_fbc_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_fbc_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FasterCacheConfigMixin:
"""
Base mixin providing FasterCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FasterCache config - can be overridden by subclasses
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
}
def _get_cache_config(self, current_timestep_callback=None):
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
if current_timestep_callback is None:
current_timestep_callback = lambda: 1000 # noqa: E731
config_kwargs["current_timestep_callback"] = current_timestep_callback
return FasterCacheConfig(**config_kwargs)
def _get_hook_names(self):
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
@is_cache
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FasterCache on models.
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
and timestep management. Inference tests are skipped at model level - FasterCache should
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FasterCache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
current_timestep = [1000]
config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
model.enable_cache(config)
# First pass with timestep outside skip range - computes and populates cache
current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
# Move timestep inside skip range so subsequent passes use cache
current_timestep[0] = 500
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different inputs
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FasterCache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_faster_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_faster_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_faster_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_faster_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_faster_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_faster_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()

View File

@@ -1,666 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional, Type
import pytest
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import assert_tensors_close, torch_device
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional`, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
def compute_module_persistent_sizes(
model: nn.Module,
dtype: str | torch.device | None = None,
special_dtypes: dict[str, str | torch.device] | None = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
def calculate_expected_num_shards(index_map_path):
"""
Calculate expected number of shards from index file.
Args:
index_map_path: Path to the sharded checkpoint index file
Returns:
int: Expected number of shards
"""
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
def check_device_map_is_respected(model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(param_device), (
f"Expected device {param_device} for {param_name}, got {param.device}"
)
def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
if torch.is_tensor(inputs):
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
if isinstance(inputs, dict):
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
if isinstance(inputs, list):
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
return inputs
class BaseModelTesterConfig:
"""
Base class defining the configuration interface for model testing.
This class defines the contract that all model test classes must implement.
It provides a consistent interface for accessing model configuration, initialization
parameters, and test inputs across all testing mixins.
Required properties (must be implemented by subclasses):
- model_class: The model class to test
Optional properties (can be overridden, have sensible defaults):
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
Required methods (must be implemented by subclasses):
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example usage:
class MyModelTestConfig(BaseModelTesterConfig):
@property
def model_class(self):
return MyModel
@property
def pretrained_model_name_or_path(self):
return "org/my-model"
@property
def output_shape(self):
return (1, 3, 32, 32)
def get_init_dict(self):
return {"in_channels": 3, "out_channels": 3}
def get_dummy_inputs(self):
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
pass
"""
# ==================== Required Properties ====================
@property
def model_class(self) -> Type[nn.Module]:
"""The model class to test. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `model_class` property.")
# ==================== Optional Properties ====================
@property
def pretrained_model_name_or_path(self) -> Optional[str]:
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
return None
@property
def pretrained_model_kwargs(self) -> Dict[str, Any]:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
return None
@property
def model_split_percents(self) -> list:
"""Percentages for model parallelism tests."""
return [0.5, 0.7]
# ==================== Required Methods ====================
def get_init_dict(self) -> Dict[str, Any]:
"""
Returns dict of arguments to initialize the model.
Returns:
Dict[str, Any]: Initialization arguments for the model constructor.
Example:
return {
"in_channels": 3,
"out_channels": 3,
"sample_size": 32,
}
"""
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
def get_dummy_inputs(self) -> Dict[str, Any]:
"""
Returns dict of inputs to pass to the model forward pass.
Returns:
Dict[str, Any]: Input tensors/values for model.forward().
Example:
return {
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
"timestep": torch.tensor([1], device=torch_device),
}
"""
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
This mixin expects the test class to also inherit from BaseModelTesterConfig
(or implement its interface) which provides:
- model_class: The model class to test
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example:
class MyModelTestConfig(BaseModelTesterConfig):
model_class = MyModel
def get_init_dict(self): ...
def get_dummy_inputs(self): ...
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""
@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert param_1.shape == param_2.shape, (
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
if torch_device == "mps" and dtype == torch.bfloat16:
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
model.to(dtype)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
@torch.no_grad()
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
assert_tensors_close(
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
@torch.no_grad()
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
def set_nan_tensor_to_zero(t):
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (list, tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
assert_tensors_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=atol,
rtol=rtol,
msg="Tuple and dict output are not equal",
)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_getattr_is_correct(self, caplog):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
# save some things to test
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)
logger_name = "diffusers.models.modeling_utils"
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5
# no warning should be thrown
assert caplog.text == ""
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert caplog.text == ""
# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
assert model.test_attribute == 5
with pytest.warns(FutureWarning):
assert getattr(model, "test_attribute") == 5
with pytest.raises(AttributeError) as error:
model.does_not_exist
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if fp32_modules and any(
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype
inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype)
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
self._check_dtype_inference_output(output, output_loaded, dtype)
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
"""Check dtype inference output with configurable tolerance."""
assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
f"Variant index file {index_filename} should exist"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
)
@torch.no_grad()
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
from diffusers.utils import constants
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
# Save original values to restore after test
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
model_sequential = model_sequential.to(torch_device)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
torch.manual_seed(0)
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
if original_parallel_workers is not None:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
@torch.no_grad()
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(tmp_path)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
)

View File

@@ -1,166 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import pytest
import torch
from ...testing_utils import (
backend_empty_cache,
is_torch_compile,
require_accelerator,
require_torch_version_greater,
torch_device,
)
@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: compile
Use `pytest -m "not compile"` to skip these tests
"""
@property
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
"""Optional list of (height, width) tuples for dynamic shape testing."""
return None
def setup_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_torch_compile_recompilation_and_graph_break(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch._dynamo.config.cache_size_limit = 10000
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.eval()
group_offload_kwargs = {
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
"use_stream": True,
"non_blocking": True,
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True):
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_works_with_aot(self, tmp_path):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary
_ = model(**inputs_dict)
_ = model(**inputs_dict)

View File

@@ -1,158 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, processor_cls):
return True
return False
@is_ip_adapter
class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
Expected from config mixin:
- model_class: The model class to test
Required properties (must be implemented by subclasses):
- ip_adapter_processor_cls: The IP Adapter processor class to use
Required methods (must be implemented by subclasses):
- create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
- modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: ip_adapter
Use `pytest -m "not ip_adapter"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@property
def ip_adapter_processor_cls(self):
"""IP Adapter processor class to use for testing. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
@torch.no_grad()
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter processors not set correctly"
)
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
"Output should differ with IP Adapter enabled"
)
@pytest.mark.skip(
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
)
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Test scale = 1.0 (full effect)
model.set_ip_adapter_scale(1.0)
torch.manual_seed(0)
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
"Output should differ with different IP Adapter scales"
)
@pytest.mark.skip(
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
)
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter should be unloaded"
)
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"

View File

@@ -1,555 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
import pytest
import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_lora,
is_torch_compile,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_torch_version_greater,
torch_device,
)
if is_peft_available():
from diffusers.loaders.peft import PeftAdapterMixin
def check_if_lora_correctly_set(model) -> bool:
"""
Check if LoRA layers are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if LoRA is correctly set, False otherwise
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
@is_lora
@require_peft_backend
class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: lora
Use `pytest -m "not lora"` to skip these tests
"""
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
@torch.no_grad()
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), (
"Output should differ with LoRA enabled"
)
model.save_lora_adapter(tmp_path)
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), (
"Output should differ with LoRA enabled"
)
assert_tensors_close(
outputs_with_lora,
outputs_with_lora_2,
atol=atol,
rtol=rtol,
msg="Outputs should match before and after save/load",
)
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
"""
Mixin class for testing LoRA hot swapping functionality on models.
Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
and is extensively tested there. The goal of this test is specifically to ensure that
hotswapping with diffusers does not require recompilation.
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest marks: lora, torch_compile
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""
@property
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
"""Optional list of (height, width) tuples for dynamic compilation tests."""
return None
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def teardown_method(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
from peft import LoraConfig
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return lora_config
def _get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def _check_model_hotswap(
self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3
):
"""
Check that hotswapping works on a model.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
# save the adapter checkpoints
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
torch.manual_seed(0)
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0"
)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output1_before,
output1_after,
atol=atol,
rtol=rtol,
msg="Output mismatch after hotswapping to adapter1",
)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with pytest.raises(ValueError, match=re.escape(msg)):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_model(self, tmp_path, rank0, rank1):
self._check_model_hotswap(
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
target_modules.append(self._get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
# variable to represent input sizes that are the same. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)

View File

@@ -1,498 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import glob
import inspect
from functools import wraps
import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes
from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
is_cpu_offload,
is_group_offload,
is_memory,
require_accelerator,
torch_device,
)
from .common import cast_inputs_to_dtype, check_device_map_is_respected
def require_offload_support(func):
"""
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
return func(self, *args, **kwargs)
return wrapper
def require_group_offload_support(func):
"""
Decorator to skip tests if model doesn't support group offloading.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
return func(self, *args, **kwargs)
return wrapper
@is_cpu_offload
class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cpu_offload
Use `pytest -m "not cpu_offload"` to skip these tests
"""
@property
def model_split_percents(self) -> list[float]:
"""List of percentages for splitting model across devices during offloading tests."""
return [0.5, 0.7]
@require_offload_support
@torch.no_grad()
def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(str(tmp_path))
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model.cpu().save_pretrained(str(tmp_path))
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0],
new_output[0],
atol=atol,
rtol=rtol,
msg="Output should match with disk offloading (safetensors)",
)
@is_group_offload
class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: group_offload
Use `pytest -m "not group_offload"` to skip these tests
"""
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
def test_group_offloading(self, record_stream, atol=1e-5, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
), "Group offloading hook should be set"
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading1,
atol=atol,
rtol=rtol,
msg="Output should match with block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading2,
atol=atol,
rtol=rtol,
msg="Output should match with non-blocking block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading3,
atol=atol,
rtol=rtol,
msg="Output should match with leaf-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading4,
atol=atol,
rtol=rtol,
msg="Output should match with leaf-level offloading with stream",
)
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
tmpdir = str(tmp_path)
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=rtol,
msg="Output should match with disk-based group offloading",
)
class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
backend_synchronize(torch_device)
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**config).eval()
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
torch.float8_e4m3fn, torch.bfloat16
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
"Memory footprint should decrease with lower precision storage"
)
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
"Peak memory should be lower with bf16 compute on newer GPUs"
)
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
def test_layerwise_casting_training(self):
def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
model = self.model_class(**self.get_init_dict())
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
model.train()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict, return_dict=False)[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
noise = cast_inputs_to_dtype(noise, torch.float32, compute_dtype)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
test_fn(torch.float16, torch.float32)
test_fn(torch.float8_e4m3fn, torch.float32)
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
"""
Combined mixin class for all memory optimization tests including CPU/disk offloading,
group offloading, and layerwise dtype casting.
This mixin inherits from:
- CPUOffloadTesterMixin: CPU and disk offloading tests
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""

View File

@@ -1,99 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
@torch.no_grad()
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
result_queue.put(("success", output.shape))
except Exception as e:
if rank == 0:
result_queue.put(("error", str(e)))
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
cp_dict = {cp_type: world_size}
ctx = mp.get_context("spawn")
result_queue = ctx.Queue()
mp.spawn(
_context_parallel_worker,
args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
nprocs=world_size,
join=True,
)
status, result = result_queue.get(timeout=60)
assert status == "success", f"Context parallel inference failed: {result}"

File diff suppressed because it is too large Load Diff

View File

@@ -1,265 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_single_file,
nightly,
require_torch_accelerator,
torch_device,
)
from .common import check_device_map_is_respected
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
"""Download a single file checkpoint from the Hub to a temporary directory."""
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
return path
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
"""Download diffusers config files (excluding weights) from a repository."""
path = snapshot_download(
pretrained_model_name_or_path,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
Required properties (must be implemented by subclasses):
- ckpt_path: Path or Hub path to the single file checkpoint
Optional properties:
- torch_dtype: torch dtype to use for testing (default: None)
- alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
Expected from config mixin:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
# ==================== Required Properties ====================
@property
def ckpt_path(self) -> str:
"""Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
# ==================== Optional Properties ====================
@property
def torch_dtype(self) -> torch.dtype | None:
"""torch dtype to use for single file testing."""
return None
@property
def alternate_ckpt_paths(self) -> list[str] | None:
"""List of alternate checkpoint paths for variant testing."""
return None
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
single_file_kwargs = {"device": torch_device}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading: "
f"pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_model_parameters(self, atol=1e-5, rtol=1e-5):
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
single_file_kwargs = {"device": torch_device}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
state_dict = model.state_dict()
state_dict_single_file = model_single_file.state_dict()
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
"Model parameters keys differ between pretrained and single file loading. "
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
)
for key in state_dict.keys():
param = state_dict[key]
param_single_file = state_dict_single_file[key]
assert param.shape == param_single_file.shape, (
f"Parameter shape mismatch for {key}: "
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert_tensors_close(
param, param_single_file, atol=atol, rtol=rtol, msg=f"Parameter values differ for {key}"
)
def test_single_file_loading_local_files_only(self, tmp_path):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load with config parameter
model_single_file = self.model_class.from_single_file(
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
)
# Load pretrained for comparison
pretrained_kwargs = {**self.pretrained_model_kwargs}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
# Compare configs
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
# Cleanup
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
del model
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_loading_with_device_map(self):
single_file_kwargs = {"device_map": torch_device}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
assert model is not None, "Failed to load model with device_map"
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute when loaded with device_map"
assert model.hf_device_map is not None, "hf_device_map should not be None when loaded with device_map"
check_device_map_is_respected(model, model.hf_device_map)

View File

@@ -1,220 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import pytest
import torch
from diffusers.training_utils import EMAModel
from ...testing_utils import (
backend_empty_cache,
is_training,
require_torch_accelerator_with_training,
torch_all_close,
torch_device,
)
@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
Expected from config mixin:
- model_class: The model class to test
- output_shape: Tuple defining the expected output shape
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_training_with_ema(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
init_dict = self.get_init_dict()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
# check enable works
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
# check disable works
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
def test_gradient_checkpointing_is_applied(self, expected_set=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if expected_set is None:
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
init_dict = self.get_init_dict()
model_class_copy = copy.copy(self.model_class)
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if skip is None:
skip = set()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict, return_dict=False)[0]
# run the backwards pass on the model
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
# run the backwards pass on the model
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
assert (loss - loss_2).abs() < loss_tolerance, (
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
if param.grad is None:
continue
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
f"Gradient mismatch for {name}"
)
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
# Test with bfloat16
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()

View File

@@ -13,51 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import unittest
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
# TODO: This standalone function maintains backward compatibility with pipeline tests
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
@@ -67,7 +39,7 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterAttnProcessor(
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
@@ -78,8 +50,11 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
@@ -100,37 +75,57 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
)
del sd
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class FluxTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return FluxTransformer2DModel
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-flux-pipe"
def dummy_input(self):
return self.prepare_dummy_input()
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def output_shape(self) -> tuple[int, int]:
def input_shape(self):
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 4)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
def get_init_dict(self) -> dict[str, int | list[int]]:
"""Return Flux model initialization arguments."""
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -142,40 +137,11 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
inputs_dict = self.dummy_input
return init_dict, inputs_dict
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, embedding_dim), generator=self.generator, device=torch_device
),
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
@@ -196,269 +162,63 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
@property
def ip_adapter_processor_cls(self):
return FluxIPAdapterAttnProcessor
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
return create_flux_ip_adapter_state_dict(model)
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 8
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
@property
def ckpt_path(self):
return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
@property
def alternate_ckpt_paths(self):
return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
@property
def pretrained_model_name_or_path(self):
return "black-forest-labs/FLUX.1-dev"
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
def gguf_filename(self):
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
@property
def gguf_filename(self):
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64), device=torch_device),
"encoder_hidden_states": randn_tensor((1, 512, 4096), device=torch_device),
"pooled_projections": randn_tensor((1, 768), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3), device=torch_device),
"txt_ids": randn_tensor((512, 3), device=torch_device),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
"""FasterCache tests for Flux Transformer."""
# Flux is guidance distilled, so we can test at model level without CFG batch handling
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
"is_guidance_distilled": True,
}
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)

View File

@@ -38,7 +38,6 @@ from diffusers.utils.import_utils import (
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -131,59 +130,6 @@ def torch_all_close(a, b, *args, **kwargs):
return True
def assert_tensors_close(
actual: "torch.Tensor",
expected: "torch.Tensor",
atol: float = 1e-5,
rtol: float = 1e-5,
msg: str = "",
) -> None:
"""
Assert that two tensors are close within tolerance.
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
Provides concise, actionable error messages without dumping full tensors.
Args:
actual: The actual tensor from the computation.
expected: The expected tensor to compare against.
atol: Absolute tolerance.
rtol: Relative tolerance.
msg: Optional message prefix for the assertion error.
Raises:
AssertionError: If tensors have different shapes or values exceed tolerance.
Example:
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
"""
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")
if actual.shape != expected.shape:
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
abs_diff = (actual - expected).abs()
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()
total = actual.numel()
raise AssertionError(
f"{msg}\n"
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
)
def numpy_cosine_similarity_distance(a, b):
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()
@@ -295,6 +241,7 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -335,155 +282,12 @@ def nightly(test_case):
def is_torch_compile(test_case):
"""
Decorator marking a test as a torch.compile test. These tests can be filtered using:
pytest -m "not compile" to skip
pytest -m compile to run only these tests
"""
return pytest.mark.compile(test_case)
Decorator marking a test that runs compile tests in the diffusers CI.
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
def is_single_file(test_case):
"""
Decorator marking a test as a single file loading test. These tests can be filtered using:
pytest -m "not single_file" to skip
pytest -m single_file to run only these tests
"""
return pytest.mark.single_file(test_case)
def is_lora(test_case):
"""
Decorator marking a test as a LoRA test. These tests can be filtered using:
pytest -m "not lora" to skip
pytest -m lora to run only these tests
"""
return pytest.mark.lora(test_case)
def is_ip_adapter(test_case):
"""
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
pytest -m "not ip_adapter" to skip
pytest -m ip_adapter to run only these tests
"""
return pytest.mark.ip_adapter(test_case)
def is_training(test_case):
"""
Decorator marking a test as a training test. These tests can be filtered using:
pytest -m "not training" to skip
pytest -m training to run only these tests
"""
return pytest.mark.training(test_case)
def is_attention(test_case):
"""
Decorator marking a test as an attention test. These tests can be filtered using:
pytest -m "not attention" to skip
pytest -m attention to run only these tests
"""
return pytest.mark.attention(test_case)
def is_memory(test_case):
"""
Decorator marking a test as a memory optimization test. These tests can be filtered using:
pytest -m "not memory" to skip
pytest -m memory to run only these tests
"""
return pytest.mark.memory(test_case)
def is_cpu_offload(test_case):
"""
Decorator marking a test as a CPU offload test. These tests can be filtered using:
pytest -m "not cpu_offload" to skip
pytest -m cpu_offload to run only these tests
"""
return pytest.mark.cpu_offload(test_case)
def is_group_offload(test_case):
"""
Decorator marking a test as a group offload test. These tests can be filtered using:
pytest -m "not group_offload" to skip
pytest -m group_offload to run only these tests
"""
return pytest.mark.group_offload(test_case)
def is_quantization(test_case):
"""
Decorator marking a test as a quantization test. These tests can be filtered using:
pytest -m "not quantization" to skip
pytest -m quantization to run only these tests
"""
return pytest.mark.quantization(test_case)
def is_bitsandbytes(test_case):
"""
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
pytest -m "not bitsandbytes" to skip
pytest -m bitsandbytes to run only these tests
"""
return pytest.mark.bitsandbytes(test_case)
def is_quanto(test_case):
"""
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
pytest -m "not quanto" to skip
pytest -m quanto to run only these tests
"""
return pytest.mark.quanto(test_case)
def is_torchao(test_case):
"""
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
pytest -m "not torchao" to skip
pytest -m torchao to run only these tests
"""
return pytest.mark.torchao(test_case)
def is_gguf(test_case):
"""
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
pytest -m "not gguf" to skip
pytest -m gguf to run only these tests
"""
return pytest.mark.gguf(test_case)
def is_modelopt(test_case):
"""
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
pytest -m "not modelopt" to skip
pytest -m modelopt to run only these tests
"""
return pytest.mark.modelopt(test_case)
def is_context_parallel(test_case):
"""
Decorator marking a test as a context parallel inference test. These tests can be filtered using:
pytest -m "not context_parallel" to skip
pytest -m context_parallel to run only these tests
"""
return pytest.mark.context_parallel(test_case)
def is_cache(test_case):
"""
Decorator marking a test as a cache test. These tests can be filtered using:
pytest -m "not cache" to skip
pytest -m cache to run only these tests
"""
return pytest.mark.cache(test_case)
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
def require_torch(test_case):
@@ -846,19 +650,6 @@ def require_kernels_version_greater_or_equal(kernels_version):
return decorator
def require_modelopt_version_greater_or_equal(modelopt_version):
def decorator(test_case):
correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
version.parse(importlib.metadata.version("modelopt")).base_version
) >= version.parse(modelopt_version)
return pytest.mark.skipif(
not correct_nvidia_modelopt_version,
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

View File

@@ -1,592 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility script to generate test suites for diffusers model classes.
Usage:
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
This will analyze the model file and generate a test file with appropriate
test classes based on the model's mixins and attributes.
"""
import argparse
import ast
import sys
from pathlib import Path
MIXIN_TO_TESTER = {
"ModelMixin": "ModelTesterMixin",
"PeftAdapterMixin": "LoraTesterMixin",
}
ATTRIBUTE_TO_TESTER = {
"_cp_plan": "ContextParallelTesterMixin",
"_supports_gradient_checkpointing": "TrainingTesterMixin",
}
ALWAYS_INCLUDE_TESTERS = [
"ModelTesterMixin",
"MemoryTesterMixin",
"TorchCompileTesterMixin",
]
# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
"AttentionMixin",
"AttentionModuleMixin",
}
OPTIONAL_TESTERS = [
# Quantization testers
("BitsAndBytesTesterMixin", "bnb"),
("QuantoTesterMixin", "quanto"),
("TorchAoTesterMixin", "torchao"),
("GGUFTesterMixin", "gguf"),
("ModelOptTesterMixin", "modelopt"),
# Quantization compile testers
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
("QuantoCompileTesterMixin", "quanto_compile"),
("TorchAoCompileTesterMixin", "torchao_compile"),
("GGUFCompileTesterMixin", "gguf_compile"),
("ModelOptCompileTesterMixin", "modelopt_compile"),
# Cache testers
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
("FirstBlockCacheTesterMixin", "fbc_cache"),
("FasterCacheTesterMixin", "faster_cache"),
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
]
class ModelAnalyzer(ast.NodeVisitor):
def __init__(self):
self.model_classes = []
self.current_class = None
self.imports = set()
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.imports.add(alias.name.split(".")[-1])
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
self.imports.add(alias.name)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef):
base_names = []
for base in node.bases:
if isinstance(base, ast.Name):
base_names.append(base.id)
elif isinstance(base, ast.Attribute):
base_names.append(base.attr)
if "ModelMixin" in base_names:
class_info = {
"name": node.name,
"bases": base_names,
"attributes": {},
"has_forward": False,
"init_params": [],
}
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name):
attr_name = target.id
if attr_name.startswith("_"):
class_info["attributes"][attr_name] = self._get_value(item.value)
elif isinstance(item, ast.FunctionDef):
if item.name == "forward":
class_info["has_forward"] = True
class_info["forward_params"] = self._extract_func_params(item)
elif item.name == "__init__":
class_info["init_params"] = self._extract_func_params(item)
self.model_classes.append(class_info)
self.generic_visit(node)
def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
params = []
args = func_node.args
num_defaults = len(args.defaults)
num_args = len(args.args)
first_default_idx = num_args - num_defaults
for i, arg in enumerate(args.args):
if arg.arg == "self":
continue
param_info = {"name": arg.arg, "type": None, "default": None}
if arg.annotation:
param_info["type"] = self._get_annotation_str(arg.annotation)
default_idx = i - first_default_idx
if default_idx >= 0 and default_idx < len(args.defaults):
param_info["default"] = self._get_value(args.defaults[default_idx])
params.append(param_info)
return params
def _get_annotation_str(self, node) -> str:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Constant):
return repr(node.value)
elif isinstance(node, ast.Subscript):
base = self._get_annotation_str(node.value)
if isinstance(node.slice, ast.Tuple):
args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
else:
args = self._get_annotation_str(node.slice)
return f"{base}[{args}]"
elif isinstance(node, ast.Attribute):
return f"{self._get_annotation_str(node.value)}.{node.attr}"
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = self._get_annotation_str(node.left)
right = self._get_annotation_str(node.right)
return f"{left} | {right}"
elif isinstance(node, ast.Tuple):
return ", ".join(self._get_annotation_str(el) for el in node.elts)
return "Any"
def _get_value(self, node):
if isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.Name):
if node.id == "None":
return None
elif node.id == "True":
return True
elif node.id == "False":
return False
return node.id
elif isinstance(node, ast.List):
return [self._get_value(el) for el in node.elts]
elif isinstance(node, ast.Dict):
return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
return "<complex>"
def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
with open(filepath) as f:
source = f.read()
tree = ast.parse(source)
analyzer = ModelAnalyzer()
analyzer.visit(tree)
return analyzer.model_classes, analyzer.imports
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
testers = list(ALWAYS_INCLUDE_TESTERS)
for base in model_info["bases"]:
if base in MIXIN_TO_TESTER:
tester = MIXIN_TO_TESTER[base]
if tester not in testers:
testers.append(tester)
for attr, tester in ATTRIBUTE_TO_TESTER.items():
if attr in model_info["attributes"]:
value = model_info["attributes"][attr]
if value is not None and value is not False:
if tester not in testers:
testers.append(tester)
if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
if "ContextParallelTesterMixin" not in testers:
testers.append("ContextParallelTesterMixin")
# Include AttentionTesterMixin if the model imports attention-related classes
if imports & ATTENTION_INDICATORS:
testers.append("AttentionTesterMixin")
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
testers.append(tester)
return testers
def generate_config_class(model_info: dict, model_name: str) -> str:
class_name = f"{model_name}TesterConfig"
model_class = model_info["name"]
forward_params = model_info.get("forward_params", [])
init_params = model_info.get("init_params", [])
lines = [
f"class {class_name}:",
" @property",
" def model_class(self):",
f" return {model_class}",
"",
" @property",
" def pretrained_model_name_or_path(self):",
' return "" # TODO: Set Hub repository ID',
"",
" @property",
" def pretrained_model_kwargs(self):",
' return {"subfolder": "transformer"}',
"",
" @property",
" def generator(self):",
' return torch.Generator("cpu").manual_seed(0)',
"",
" def get_init_dict(self) -> dict[str, int | list[int]]:",
]
if init_params:
lines.append(" # __init__ parameters:")
for param in init_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" return {}",
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
]
)
if forward_params:
lines.append(" # forward() parameters:")
for param in forward_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" # TODO: Fill in dummy inputs",
" return {}",
"",
" @property",
" def input_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
"",
" @property",
" def output_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
]
)
return "\n".join(lines)
def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
tester_short = tester.replace("TesterMixin", "")
class_name = f"Test{model_name}{tester_short}"
lines = [f"class {class_name}({config_class}, {tester}):"]
if tester == "TorchCompileTesterMixin":
lines.extend(
[
" @property",
" def different_shapes_for_compilation(self):",
" return [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
elif tester == "IPAdapterTesterMixin":
lines.extend(
[
" @property",
" def ip_adapter_processor_cls(self):",
" return None # TODO: Set processor class",
"",
" def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
" # TODO: Add IP adapter image embeds to inputs",
" return inputs_dict",
"",
" def create_ip_adapter_state_dict(self, model):",
" # TODO: Create IP adapter state dict",
" return {}",
]
)
elif tester == "SingleFileTesterMixin":
lines.extend(
[
" @property",
" def ckpt_path(self):",
' return "" # TODO: Set checkpoint path',
"",
" @property",
" def alternate_ckpt_paths(self):",
" return []",
"",
" @property",
" def pretrained_model_name_or_path(self):",
' return "" # TODO: Set Hub repository ID',
]
)
elif tester == "GGUFTesterMixin":
lines.extend(
[
" @property",
" def gguf_filename(self):",
' return "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in [
"BitsAndBytesCompileTesterMixin",
"QuantoCompileTesterMixin",
"TorchAoCompileTesterMixin",
"ModelOptCompileTesterMixin",
]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester == "GGUFCompileTesterMixin":
lines.extend(
[
" @property",
" def gguf_filename(self):",
' return "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester in [
"PyramidAttentionBroadcastTesterMixin",
"FirstBlockCacheTesterMixin",
"FasterCacheTesterMixin",
]:
lines.append(" pass")
elif tester == "LoraHotSwappingForModelTesterMixin":
lines.extend(
[
" @property",
" def different_shapes_for_compilation(self):",
" return [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
else:
lines.append(" pass")
return "\n".join(lines)
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
testers = determine_testers(model_info, include_optional, imports)
tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
lines = [
"# coding=utf-8",
"# Copyright 2025 HuggingFace Inc.",
"#",
'# Licensed under the Apache License, Version 2.0 (the "License");',
"# you may not use this file except in compliance with the License.",
"# You may obtain a copy of the License at",
"#",
"# http://www.apache.org/licenses/LICENSE-2.0",
"#",
"# Unless required by applicable law or agreed to in writing, software",
'# distributed under the License is distributed on an "AS IS" BASIS,',
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"# See the License for the specific language governing permissions and",
"# limitations under the License.",
"",
"import torch",
"",
f"from diffusers import {model_info['name']}",
"from diffusers.utils.torch_utils import randn_tensor",
"",
"from ...testing_utils import enable_full_determinism, torch_device",
]
if "LoraTesterMixin" in testers:
lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
lines.extend(
[
"from ..testing_utils import (",
*[f" {tester}," for tester in sorted(tester_imports)],
")",
"",
"",
"enable_full_determinism()",
"",
"",
]
)
config_class = f"{model_name}TesterConfig"
lines.append(generate_config_class(model_info, model_name))
lines.append("")
lines.append("")
for tester in testers:
lines.append(generate_test_class(model_name, config_class, tester))
lines.append("")
lines.append("")
if "LoraTesterMixin" in testers:
lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
lines.append("")
lines.append("")
return "\n".join(lines).rstrip() + "\n"
def get_test_output_path(model_filepath: str) -> str:
path = Path(model_filepath)
model_filename = path.stem
if "transformers" in path.parts:
return f"tests/models/transformers/test_models_{model_filename}.py"
elif "unets" in path.parts:
return f"tests/models/unets/test_models_{model_filename}.py"
elif "autoencoders" in path.parts:
return f"tests/models/autoencoders/test_models_{model_filename}.py"
else:
return f"tests/models/test_models_{model_filename}.py"
def main():
parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
parser.add_argument(
"model_filepath",
type=str,
help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
)
parser.add_argument(
"--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
)
parser.add_argument(
"--include",
"-i",
type=str,
nargs="*",
default=[],
choices=[
"bnb",
"quanto",
"torchao",
"gguf",
"modelopt",
"bnb_compile",
"quanto_compile",
"torchao_compile",
"gguf_compile",
"modelopt_compile",
"pab_cache",
"fbc_cache",
"faster_cache",
"single_file",
"ip_adapter",
"all",
],
help="Optional testers to include",
)
parser.add_argument(
"--class-name",
"-c",
type=str,
default=None,
help="Specific model class to generate tests for (default: first model class found)",
)
parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")
args = parser.parse_args()
if not Path(args.model_filepath).exists():
print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
sys.exit(1)
model_classes, imports = analyze_model_file(args.model_filepath)
if not model_classes:
print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
sys.exit(1)
if args.class_name:
model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
if not model_info:
available = [m["name"] for m in model_classes]
print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
sys.exit(1)
else:
model_info = model_classes[0]
if len(model_classes) > 1:
print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
print("Use --class-name to specify a different class", file=sys.stderr)
include_optional = args.include
if "all" in include_optional:
include_optional = [flag for _, flag in OPTIONAL_TESTERS]
generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)
if args.dry_run:
print(generated_code)
else:
output_path = args.output or get_test_output_path(args.model_filepath)
output_dir = Path(output_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.write(generated_code)
print(f"Generated test file: {output_path}")
print(f"Model class: {model_info['name']}")
print(f"Detected attributes: {list(model_info['attributes'].keys())}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,300 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Auto Docstring Generator for Modular Pipeline Blocks
This script scans Python files for classes that have `# auto_docstring` comment above them
and inserts/updates the docstring from the class's `doc` property.
Run from the root of the repo:
python utils/modular_auto_docstring.py [path] [--fix_and_overwrite]
Examples:
# Check for auto_docstring markers (will error if found without proper docstring)
python utils/modular_auto_docstring.py
# Check specific directory
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/
# Fix and overwrite the docstrings
python utils/modular_auto_docstring.py --fix_and_overwrite
Usage in code:
# auto_docstring
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
# docstring will be automatically inserted here
@property
def doc(self):
return "Your docstring content..."
"""
import argparse
import ast
import glob
import importlib
import os
import re
import sys
# All paths are set with the intent you should run this script from the root of the repo
DIFFUSERS_PATH = "src/diffusers"
REPO_PATH = "."
# Pattern to match the auto_docstring comment
AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$")
def setup_diffusers_import():
"""Setup import path to use the local diffusers module."""
src_path = os.path.join(REPO_PATH, "src")
if src_path not in sys.path:
sys.path.insert(0, src_path)
def get_module_from_filepath(filepath: str) -> str:
"""Convert a filepath to a module name."""
filepath = os.path.normpath(filepath)
if filepath.startswith("src" + os.sep):
filepath = filepath[4:]
if filepath.endswith(".py"):
filepath = filepath[:-3]
module_name = filepath.replace(os.sep, ".")
return module_name
def load_module(filepath: str):
"""Load a module from filepath."""
setup_diffusers_import()
module_name = get_module_from_filepath(filepath)
try:
module = importlib.import_module(module_name)
return module
except Exception as e:
print(f"Warning: Could not import module {module_name}: {e}")
return None
def get_doc_from_class(module, class_name: str) -> str:
"""Get the doc property from an instantiated class."""
if module is None:
return None
cls = getattr(module, class_name, None)
if cls is None:
return None
try:
instance = cls()
if hasattr(instance, "doc"):
return instance.doc
except Exception as e:
print(f"Warning: Could not instantiate {class_name}: {e}")
return None
def find_auto_docstring_classes(filepath: str) -> list:
"""
Find all classes in a file that have # auto_docstring comment above them.
Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line)
"""
with open(filepath, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Parse AST to find class locations and their docstrings
content = "".join(lines)
try:
tree = ast.parse(content)
except SyntaxError as e:
print(f"Syntax error in {filepath}: {e}")
return []
# Build a map of class_name -> (class_line, has_docstring, docstring_end_line)
class_info = {}
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
has_docstring = False
docstring_end_line = node.lineno # default to class line
if node.body and isinstance(node.body[0], ast.Expr):
first_stmt = node.body[0]
if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str):
has_docstring = True
docstring_end_line = first_stmt.end_lineno or first_stmt.lineno
class_info[node.name] = (node.lineno, has_docstring, docstring_end_line)
# Now scan for # auto_docstring comments
classes_to_update = []
for i, line in enumerate(lines):
if AUTO_DOCSTRING_PATTERN.match(line):
# Found the marker, look for class definition on next non-empty, non-comment line
j = i + 1
while j < len(lines):
next_line = lines[j].strip()
if next_line and not next_line.startswith("#"):
break
j += 1
if j < len(lines) and lines[j].strip().startswith("class "):
# Extract class name
match = re.match(r"class\s+(\w+)", lines[j].strip())
if match:
class_name = match.group(1)
if class_name in class_info:
class_line, has_docstring, docstring_end_line = class_info[class_name]
classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line))
return classes_to_update
def strip_class_name_line(doc: str, class_name: str) -> str:
"""Remove the 'class ClassName' line from the doc if present."""
lines = doc.strip().split("\n")
if lines and lines[0].strip() == f"class {class_name}":
# Remove the class line and any blank line following it
lines = lines[1:]
while lines and not lines[0].strip():
lines = lines[1:]
return "\n".join(lines)
def format_docstring(doc: str, indent: str = " ") -> str:
"""Format a doc string as a properly indented docstring."""
lines = doc.strip().split("\n")
if len(lines) == 1:
return f'{indent}"""{lines[0]}"""\n'
else:
result = [f'{indent}"""\n']
for line in lines:
if line.strip():
result.append(f"{indent}{line}\n")
else:
result.append("\n")
result.append(f'{indent}"""\n')
return "".join(result)
def process_file(filepath: str, overwrite: bool = False) -> list:
"""
Process a file and find/insert docstrings for # auto_docstring marked classes.
Returns list of classes that need updating.
"""
classes_to_update = find_auto_docstring_classes(filepath)
if not classes_to_update:
return []
if not overwrite:
# Just return the list of classes that need updating
return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]
# Load the module to get doc properties
module = load_module(filepath)
with open(filepath, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Process in reverse order to maintain line numbers
updated = False
for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update):
doc = get_doc_from_class(module, class_name)
if doc is None:
print(f"Warning: Could not get doc for {class_name} in {filepath}")
continue
# Remove the "class ClassName" line since it's redundant in a docstring
doc = strip_class_name_line(doc, class_name)
# Format the new docstring with 4-space indent
new_docstring = format_docstring(doc, " ")
if has_docstring:
# Replace existing docstring (line after class definition to docstring_end_line)
# class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line
lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:]
else:
# Insert new docstring right after class definition line
# class_line is 1-indexed, so lines[class_line-1] is the class line
# Insert at position class_line (which is right after the class line)
lines = lines[:class_line] + [new_docstring] + lines[class_line:]
updated = True
print(f"Updated docstring for {class_name} in {filepath}")
if updated:
with open(filepath, "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines)
return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]
def check_auto_docstrings(path: str = None, overwrite: bool = False):
"""
Check all files for # auto_docstring markers and optionally fix them.
"""
if path is None:
path = DIFFUSERS_PATH
if os.path.isfile(path):
all_files = [path]
else:
all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True)
all_markers = []
for filepath in all_files:
markers = process_file(filepath, overwrite)
all_markers.extend(markers)
if not overwrite and len(all_markers) > 0:
message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers])
raise ValueError(
f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n"
f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them."
)
if overwrite and len(all_markers) > 0:
print(f"\nUpdated {len(all_markers)} docstring(s).")
elif len(all_markers) == 0:
print("No # auto_docstring markers found.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Check and fix # auto_docstring markers in modular pipeline blocks",
)
parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)")
parser.add_argument(
"--fix_and_overwrite",
action="store_true",
help="Whether to fix the docstrings by inserting them from doc property.",
)
args = parser.parse_args()
check_auto_docstrings(args.path, args.fix_and_overwrite)