Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-19 01:56:12 +08:00)

Comparing 2 commits: main...modular-me

| Author | SHA1 | Date |
|---|---|---|
| | 929414d0ea | |
| | d2ded13cd6 | |

docs/source/en/conceptual/contribution.md
@@ -1 +0,0 @@

CONTRIBUTING.md (new file, 506 lines)
@@ -0,0 +1,506 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# How to contribute to Diffusers 🧨
|
||||
|
||||
We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it!
|
||||
|
||||
Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. <a href="https://discord.gg/G7tWnz98XR"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=Discord&logoColor=white"></a>
|
||||
|
||||
Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.
|
||||
|
||||
We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.
|
||||
|
||||
## Overview
|
||||
|
||||
You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
|
||||
the core library.
|
||||
|
||||
In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
|
||||
|
||||
* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
|
||||
* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
|
||||
* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
|
||||
* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
|
||||
* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
|
||||
* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
|
||||
* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).
|
||||
* 8. Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).
|
||||
* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).
|
||||
|
||||
As said before, **all contributions are valuable to the community**.
|
||||
In the following, we will explain each contribution a bit more in detail.
|
||||
|
||||
For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr).
|
||||
|
||||
### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
|
||||
|
||||
Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):
|
||||
- Reports of training or inference experiments in an attempt to share knowledge
|
||||
- Presentation of personal projects
|
||||
- Questions about non-official training examples
|
||||
- Project proposals
|
||||
- General feedback
|
||||
- Paper summaries
|
||||
- Asking for help on personal projects that build on top of the Diffusers library
|
||||
- General questions
|
||||
- Ethical questions regarding diffusion models
|
||||
- ...
|
||||
|
||||
Every question that is asked on the forum or on Discord actively encourages the community to publicly
|
||||
share knowledge and might very well help a beginner in the future who has the same question you're
|
||||
having. Please do pose any questions you might have.
|
||||
In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
|
||||
|
||||
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
|
||||
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
|
||||
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||
|
||||
**NOTE about channels**:
|
||||
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that were posted some time ago.
|
||||
In addition, questions and answers posted in the forum can easily be linked to.
|
||||
In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.
|
||||
While it will most likely take less time for you to get an answer to your question on Discord, your
|
||||
question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.
|
||||
|
||||
### 2. Opening new issues on the GitHub issues tab
|
||||
|
||||
The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of
|
||||
the problems they encounter. So thank you for reporting an issue.
|
||||
|
||||
Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.
|
||||
|
||||
In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
|
||||
|
||||
**Please consider the following guidelines when opening a new issue**:
|
||||
- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).
|
||||
- Please never report a new issue on another (related) issue. If another issue is highly related, please
|
||||
open a new issue nevertheless and link to the related issue.
|
||||
- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.
|
||||
- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c "import diffusers; print(diffusers.__version__)"` prints a version that matches or is higher than the latest Diffusers release.
|
||||
- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues.
|
||||
|
||||
New issues usually include the following.
|
||||
|
||||
#### 2.1. Reproducible, minimal bug reports
|
||||
|
||||
A bug report should always have a reproducible code snippet and be as minimal and concise as possible.
|
||||
This means in more detail:
|
||||
- Narrow the bug down as much as you can, **do not just dump your whole code file**.
|
||||
- Format your code.
|
||||
- Do not include any external libraries beyond those Diffusers itself depends on.
|
||||
- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
|
||||
- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
|
||||
- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
|
||||
- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
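To make this concrete, a minimal reproduction might look like the sketch below. The tiny checkpoint name is only an example of a small, publicly hosted test model; substitute whatever actually triggers your bug.

```python
# Minimal, self-contained reproduction sketch: fixed seed, tiny public checkpoint, few steps.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
generator = torch.Generator("cpu").manual_seed(0)
image = pipe("a photo of an astronaut", num_inference_steps=2, generator=generator).images[0]
# Paste the full traceback here, together with the output of `diffusers-cli env`.
```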
|
||||
|
||||
For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
|
||||
|
||||
You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).
|
||||
|
||||
#### 2.2. Feature requests
|
||||
|
||||
A world-class feature request addresses the following points:
|
||||
|
||||
1. Motivation first:
|
||||
* Is it related to a problem/frustration with the library? If so, please explain
|
||||
why. Providing a code snippet that demonstrates the problem is best.
|
||||
* Is it related to something you would need for a project? We'd love to hear
|
||||
about it!
|
||||
* Is it something you worked on and think could benefit the community?
|
||||
Awesome! Tell us what problem it solved for you.
|
||||
2. Write a *full paragraph* describing the feature;
|
||||
3. Provide a **code snippet** that demonstrates its future use (see the sketch after this list);
|
||||
4. In case this is related to a paper, please attach a link;
|
||||
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
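As an illustration of point 3 above, a feature request snippet usually sketches the proposed usage. The method shown below is entirely made up for illustration and does not exist in Diffusers:

```python
# Hypothetical "future use" snippet for a feature request -- `enable_cached_text_embeddings`
# is an invented name used purely to show what a proposal could look like.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipe.enable_cached_text_embeddings()  # proposed API: reuse prompt embeddings across calls
image = pipe("a watercolor painting of a fox").images[0]
```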
|
||||
|
||||
You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
|
||||
|
||||
#### 2.3 Feedback
|
||||
|
||||
Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed.
|
||||
If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
|
||||
|
||||
You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
|
||||
|
||||
#### 2.4 Technical questions
|
||||
|
||||
Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on
|
||||
why this part of the code is difficult to understand.
|
||||
|
||||
You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).
|
||||
|
||||
#### 2.5 Proposal to add a new model, scheduler, or pipeline
|
||||
|
||||
If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:
|
||||
|
||||
* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.
|
||||
* Link to any of its open-source implementations.
|
||||
* Link to the model weights if they are available.
|
||||
|
||||
If you are willing to contribute the model yourself, let us know so we can best guide you. Also, don't forget
to tag the original author of the component (model, scheduler, pipeline, etc.) by their GitHub handle if you can find it.
|
||||
|
||||
You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).
|
||||
|
||||
### 3. Answering issues on the GitHub issues tab
|
||||
|
||||
Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.
|
||||
Some tips to give a high-quality answer to an issue:
|
||||
- Be as concise and minimal as possible.
|
||||
- Stay on topic. An answer to the issue should concern the issue and only the issue.
|
||||
- Provide links to code, papers, or other sources that prove or encourage your point.
|
||||
- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.
|
||||
|
||||
Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great
help to the maintainers if you can answer such issues, encouraging the author of the issue to be
more precise, providing the link to a duplicate issue, or redirecting them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
|
||||
|
||||
If you have verified that the bug report is correct and requires a correction in the source code,
|
||||
please have a look at the next sections.
|
||||
|
||||
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
|
||||
|
||||
### 4. Fixing a "Good first issue"
|
||||
|
||||
*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already
|
||||
explains how a potential solution should look so that it is easier to fix.
|
||||
If the issue hasn't been closed and you would like to try to fix it, you can just leave a message saying "I would like to try this issue". There are usually three scenarios:
|
||||
- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.
|
||||
- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.
|
||||
- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.
|
||||
|
||||
|
||||
### 5. Contribute to the documentation
|
||||
|
||||
A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly
|
||||
valuable contribution**.
|
||||
|
||||
Contributing to the documentation can take many forms:

- Correcting spelling or grammatical errors.
- Correcting incorrect docstring formatting. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it.
- Correcting the shape or dimensions of a docstring input or output tensor.
- Clarifying documentation that is hard to understand or incorrect.
- Updating outdated code examples.
- Translating the documentation to another language.
|
||||
|
||||
Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected or adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).
|
||||
|
||||
Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.
|
||||
|
||||
|
||||
### 6. Contribute a community pipeline
|
||||
|
||||
[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user.
|
||||
Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview).
|
||||
We support two types of pipelines:
|
||||
|
||||
- Official Pipelines
|
||||
- Community Pipelines
|
||||
|
||||
Both official and community pipelines follow the same design and consist of the same type of components.
|
||||
|
||||
Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code
|
||||
resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
|
||||
In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested.
|
||||
They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution.
|
||||
|
||||
The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all
|
||||
possible ways diffusion models can be used for inference, but some of them may be of interest to the community.
|
||||
Officially released diffusion pipelines, such as Stable Diffusion, are added to the core `src/diffusers/pipelines` package, which ensures
high-quality maintenance, no backward-breaking code changes, and testing.
More bleeding-edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library.
|
||||
|
||||
To add a community pipeline, one should add a `<name-of-the-community>.py` file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline.
|
||||
|
||||
An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400).
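For orientation, a community pipeline file is typically a single class deriving from `DiffusionPipeline`. The sketch below is a minimal, illustrative skeleton (the file and class names are placeholders, not an existing pipeline):

```python
# examples/community/my_pipeline.py -- illustrative skeleton of a community pipeline.
import torch
from diffusers import DiffusionPipeline


class MyCommunityPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()
        # Registered modules are saved and loaded automatically with the pipeline.
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
        # Start from Gaussian noise and iteratively denoise it with the scheduler.
        sample = torch.randn(
            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
        )
        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.scheduler.timesteps:
            noise_pred = self.unet(sample, t).sample
            sample = self.scheduler.step(noise_pred, t, sample).prev_sample
        return sample
```

Once merged, users would load such a file by passing its name via the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`, as described in the community examples README.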
|
||||
|
||||
Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors.
|
||||
|
||||
Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the
|
||||
core package.
|
||||
|
||||
### 7. Contribute to training examples
|
||||
|
||||
Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).
|
||||
|
||||
We support two types of training examples:
|
||||
|
||||
- Official training examples
|
||||
- Research training examples
|
||||
|
||||
Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.
|
||||
The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community.
|
||||
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
|
||||
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
|
||||
|
||||
Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
|
||||
training examples, it is required to clone the repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
```
|
||||
|
||||
as well as to install all additional dependencies required for training:
|
||||
|
||||
```bash
|
||||
cd diffusers
|
||||
pip install -r examples/<your-example-folder>/requirements.txt
|
||||
```
|
||||
|
||||
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
|
||||
|
||||
Training examples of the Diffusers library should adhere to the following philosophy:
|
||||
- All the code necessary to run the examples should be found in a single Python file.
|
||||
- One should be able to run the example from the command line with `python <your-example>.py --args`.
|
||||
- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.
|
||||
|
||||
To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of what they should look like.
|
||||
We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated
|
||||
with Diffusers.
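To make the single-file layout described above concrete, here is a heavily condensed sketch (not a real Diffusers example script) of the typical argparse plus Accelerate wiring:

```python
# Toy sketch of an example script's structure: argument parsing, Accelerator setup, training loop.
import argparse

import torch
from accelerate import Accelerator


def parse_args():
    parser = argparse.ArgumentParser(description="Toy training example")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--max_train_steps", type=int, default=100)
    return parser.parse_args()


def main(args):
    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)  # stand-in for the actual diffusion model
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    dataloader = torch.utils.data.DataLoader(torch.randn(64, 4), batch_size=8)

    # `prepare` handles device placement and distributed wrapping.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for step, batch in enumerate(dataloader):
        loss = (model(batch) - batch).pow(2).mean()
        accelerator.backward(loss)  # use this instead of loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step + 1 >= args.max_train_steps:
            break


if __name__ == "__main__":
    main(parse_args())
```

A user would then run it from the command line as `python <your-example>.py --learning_rate 1e-4 --max_train_steps 100`, matching the philosophy above.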
|
||||
Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include:
|
||||
- An example command on how to run the example script as shown, for example, [here](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).
- A link to some training results (logs, models, etc.) that show what the user can expect as shown, for example, [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
|
||||
- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
|
||||
|
||||
If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
|
||||
|
||||
### 8. Fixing a "Good second issue"
|
||||
|
||||
*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are
|
||||
usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
|
||||
The issue description usually gives less guidance on how to fix the issue and requires
|
||||
a decent understanding of the library by the interested contributor.
|
||||
If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.
|
||||
Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged.
|
||||
|
||||
### 9. Adding pipelines, models, schedulers
|
||||
|
||||
Pipelines, models, and schedulers are the most important pieces of the Diffusers library.
|
||||
They provide easy access to state-of-the-art diffusion technologies and thus allow the community to
|
||||
build powerful generative AI applications.
|
||||
|
||||
By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem.
|
||||
|
||||
Diffusers has a couple of open feature requests for all three components - feel free to look through them
if you don't know yet which specific component you would like to add:
|
||||
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
|
||||
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
|
||||
|
||||
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that
|
||||
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
|
||||
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
|
||||
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
|
||||
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
|
||||
|
||||
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
|
||||
original author directly on the PR so that they can follow the progress and potentially help with questions.
|
||||
|
||||
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
|
||||
|
||||
## How to write a good issue
|
||||
|
||||
**The better your issue is written, the higher the chances that it will be quickly resolved.**
|
||||
|
||||
1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).
|
||||
2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers".
|
||||
3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data.
|
||||
4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.
|
||||
5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.
|
||||
6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.
|
||||
7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.
|
||||
|
||||
## How to write a good PR
|
||||
|
||||
1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.
|
||||
2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.
|
||||
3. If helpful, try to add a code snippet that displays an example of how your addition can be used.
|
||||
4. The title of your pull request should be a summary of its contribution.
|
||||
5. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
6. To indicate a work in progress please prefix the title with `[WIP]`. These
|
||||
are useful to avoid duplicated work, and to differentiate it from PRs ready
|
||||
to be merged;
|
||||
7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).
|
||||
8. Make sure existing tests pass;
|
||||
9. Add high-coverage tests. No quality testing = no merge.
|
||||
- If you are adding new `@slow` tests, make sure they pass using
|
||||
`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
|
||||
CircleCI does not run the slow tests, but GitHub Actions does every night!
|
||||
10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example, and the sketch after this list for the general format.
|
||||
11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.
If you are an external contributor, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
to this dataset.
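To illustrate point 10 above, public methods typically use the markdown-friendly docstring layout sketched here (the method and its arguments are invented for illustration):

```python
def add_watermark(self, image, strength: float = 0.5):
    r"""
    Sketch of the docstring layout used across the library; the method itself is illustrative only.

    Args:
        image (`PIL.Image.Image`):
            The image to watermark.
        strength (`float`, *optional*, defaults to `0.5`):
            How visible the watermark should be, between `0` and `1`.

    Returns:
        `PIL.Image.Image`: The watermarked image.
    """
    ...
```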
|
||||
|
||||
## How to open a PR
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)):
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/diffusers) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote:
|
||||
|
||||
```bash
|
||||
$ git clone git@github.com:<your GitHub handle>/diffusers.git
|
||||
$ cd diffusers
|
||||
$ git remote add upstream https://github.com/huggingface/diffusers.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes:
|
||||
|
||||
```bash
|
||||
$ git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
**Do not** work on the `main` branch.
|
||||
|
||||
4. Set up a development environment by running the following command in a virtual environment:
|
||||
|
||||
```bash
|
||||
$ pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the
|
||||
library.
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this:
|
||||
|
||||
```bash
|
||||
$ pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
Before you run the tests, please make sure you install the dependencies required for testing. You can do so
|
||||
with this command:
|
||||
|
||||
```bash
|
||||
$ pip install -e ".[test]"
|
||||
```
|
||||
|
||||
You can also run the full test suite with the following command, but it takes
|
||||
a beefy machine to produce a result in a decent amount of time now that
|
||||
Diffusers has grown a lot. Here is the command for it:
|
||||
|
||||
```bash
|
||||
$ make test
|
||||
```
|
||||
|
||||
🧨 Diffusers relies on `ruff` and `isort` to format its source code
|
||||
consistently. After you make changes, apply automatic style corrections and code verifications
|
||||
that can't be automated in one go with:
|
||||
|
||||
```bash
|
||||
$ make style
|
||||
```
|
||||
|
||||
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
|
||||
control runs in CI, however, you can also run the same checks with:
|
||||
|
||||
```bash
|
||||
$ make quality
|
||||
```
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
$ git add modified_file.py
|
||||
$ git commit -m "A descriptive message about your changes."
|
||||
```
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
$ git pull upstream main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
$ git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
6. Once you are satisfied, go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
|
||||
the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).
|
||||
|
||||
We like `pytest` and `pytest-xdist` because running tests in parallel is faster. From the root of the
repository, here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
```
|
||||
|
||||
In fact, that's how `make test` is implemented!
|
||||
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
you're working on.
|
||||
|
||||
By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to
|
||||
`yes` to run them. This will download many gigabytes of models — make sure you
|
||||
have enough disk space and a good Internet connection, or a lot of patience!
|
||||
|
||||
```bash
|
||||
$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
```
|
||||
|
||||
`unittest` is fully supported, here's how to run tests with it:
|
||||
|
||||
```bash
|
||||
$ python -m unittest discover -s tests -t . -v
|
||||
$ python -m unittest discover -s examples -t examples -v
|
||||
```
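If you add tests, fast tests and slow tests usually live in the same module; slow tests are gated behind the `RUN_SLOW` environment variable. A minimal sketch, assuming the `slow` marker is imported from `diffusers.utils.testing_utils` as in existing tests:

```python
# Illustrative test layout: a fast unit test plus a slow, checkpoint-downloading test.
import unittest

from diffusers import DDPMScheduler
from diffusers.utils.testing_utils import slow


class SchedulerFastTests(unittest.TestCase):
    def test_default_config(self):
        # Fast test: no downloads, runs on every CI job.
        scheduler = DDPMScheduler()
        self.assertEqual(scheduler.config.num_train_timesteps, 1000)


@slow
class PipelineSlowTests(unittest.TestCase):
    def test_full_inference(self):
        # Slow test: only collected when RUN_SLOW is set; may download real checkpoints.
        ...
```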
|
||||
|
||||
### Syncing forked main with upstream (HuggingFace) main
|
||||
|
||||
When syncing the main branch of a forked repository, please follow these steps to avoid pinging the upstream repository, which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs:
1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.
|
||||
2. If a PR is absolutely necessary, use the following steps after checking out your branch:
|
||||
```bash
|
||||
$ git checkout -b your-branch-for-syncing
|
||||
$ git pull --squash --no-commit upstream main
|
||||
$ git commit -m '<your message without GitHub references>'
|
||||
$ git push --set-upstream origin your-branch-for-syncing
|
||||
```
|
||||
|
||||
### Style guide
|
||||
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
@@ -99,9 +99,3 @@ image.save("chroma-single-file.png")
|
||||
[[autodoc]] ChromaImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChromaInpaintPipeline
|
||||
|
||||
[[autodoc]] ChromaInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -35,11 +35,5 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
|
||||
## Flux2Pipeline
|
||||
|
||||
[[autodoc]] Flux2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Flux2KleinPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,22 +1,14 @@
|
||||
# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein]
|
||||
# DreamBooth training example for FLUX.2 [dev]
|
||||
|
||||
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
The `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scripts shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein).
|
||||
|
||||
> [!NOTE]
|
||||
> **Model Variants**
|
||||
>
|
||||
> We support two FLUX model families:
|
||||
> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive.
|
||||
> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware.
|
||||
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. Below we provide some tips and tricks to reduce memory consumption during training.
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
|
||||
|
||||
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
|
||||
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
|
||||
@@ -25,7 +17,7 @@ The `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scr
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
@@ -96,32 +88,23 @@ snapshot_download(
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
As mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
|
||||
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
|
||||
|
||||
## Memory Optimizations
|
||||
> [!NOTE]
> Many of these techniques complement each other and can be used together to further reduce memory consumption.
> However, some techniques may be mutually exclusive, so be sure to check before launching a training run.
|
||||
|
||||
### Remote Text Encoder
|
||||
FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding.
|
||||
|
||||
> [!NOTE]
|
||||
> To enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) or pass a token with `--hub_token`.
|
||||
|
||||
### FSDP Text Encoder
|
||||
FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
|
||||
This way, it distributes the memory cost across multiple nodes.
|
||||
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.
|
||||
|
||||
### Latent Caching
|
||||
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
|
||||
|
||||
### QLoRA: Low Precision Training with Quantization
|
||||
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
|
||||
- **FP8 training** with `torchao`:
|
||||
@@ -131,29 +114,22 @@ enable FP8 training by passing `--do_fp8_training`.
|
||||
- **NF4 training** with `bitsandbytes`:
|
||||
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
|
||||
|
||||
### Gradient Checkpointing and Accumulation
|
||||
* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass.
By passing a value > 1 you can reduce the number of backward/update passes and hence also the memory requirements.
* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) is stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
|
||||
|
||||
### 8-bit-Adam Optimizer
|
||||
When training with `AdamW` (this doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.
|
||||
Make sure to install `bitsandbytes` if you want to do so.
|
||||
|
||||
### Image Resolution
|
||||
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all the images in the train/validation dataset are resized to it.
Note that by default, images are resized to a resolution of 512, which is good to keep in mind in case you're accustomed to training on higher resolutions.
|
||||
|
||||
### Precision of saved LoRA layers
|
||||
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
|
||||
|
||||
## Training Examples
|
||||
|
||||
### FLUX.2 [dev] Training
|
||||
To perform DreamBooth with LoRA on FLUX.2 [dev], run:
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
@@ -185,84 +161,13 @@ accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### FLUX 2 [klein] Training
|
||||
|
||||
FLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1.
|
||||
|
||||
> [!NOTE]
|
||||
> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. The Qwen VL text encoder must be loaded locally, but offloading is still supported.
|
||||
|
||||
**FLUX 2 [klein] 4B:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-klein-4B"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2-klein-4b"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2_klein.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--use_8bit_adam \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="adamW" \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
**FLUX 2 [klein] 9B:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-klein-9B"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2-klein-9b"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2_klein.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--use_8bit_adam \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="adamW" \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases.
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
### FSDP on the transformer
|
||||
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
|
||||
@@ -284,6 +189,12 @@ fsdp_config:
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
```
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Prodigy Optimizer
|
||||
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence.
By using Prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
|
||||
@@ -295,6 +206,8 @@ to use prodigy, first make sure to install the prodigyopt library: `pip install
|
||||
> [!TIP]
|
||||
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
|
||||
|
||||
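For reference, here is a minimal sketch of how a Prodigy optimizer is typically constructed (assuming the `prodigyopt` package; the exact keyword arguments wired up by the training script may differ):

```py
import torch
from prodigyopt import Prodigy

# Stand-in for the trainable LoRA parameters.
params = [torch.nn.Parameter(torch.randn(4, 4))]

# Prodigy adapts the step size from past gradients, so lr is usually left at 1.0
# (which is why the tip above recommends --learning_rate=1.0).
optimizer = Prodigy(params, lr=1.0, weight_decay=1e-2)

loss = params[0].pow(2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
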
To perform DreamBooth with LoRA, run:

```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -358,10 +271,13 @@ the exact modules for LoRA training. Here are some examples of target modules yo
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.

## Training Image-to-Image

Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image (I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.

**Important**

To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:

@@ -418,6 +334,5 @@ we've added aspect ratio bucketing support which allows training on images with
To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:

`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"`

Since Flux.2 finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

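To make the bucketing behavior concrete, below is a minimal, illustrative sketch of how such a semicolon-separated list could be parsed and how an image might be matched to the bucket with the closest aspect ratio (this is not the training script's exact implementation):

```py
def parse_buckets(spec: str) -> list[tuple[int, int]]:
    # "672,1568;688,1504;..." -> [(672, 1568), (688, 1504), ...]
    buckets = []
    for pair in spec.split(";"):
        h, w = pair.split(",")
        buckets.append((int(h), int(w)))
    return buckets


def closest_bucket(height: int, width: int, buckets: list[tuple[int, int]]) -> tuple[int, int]:
    # Pick the bucket whose aspect ratio is closest to the image's aspect ratio.
    ratio = height / width
    return min(buckets, key=lambda hw: abs(hw[0] / hw[1] - ratio))


buckets = parse_buckets("672,1568;1024,1024;1568,672")
print(closest_bucket(900, 1200, buckets))  # -> (1024, 1024)
```
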
@@ -1,262 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "dog"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein"
    script_path = "examples/dreambooth/train_dreambooth_lora_flux2_klein.py"
    transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"

    def test_dreambooth_lora_flux2(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --max_sequence_length 8
                --text_encoder_out_layers 1
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --max_sequence_length 8
                --text_encoder_out_layers 1
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --max_sequence_length 8
                --text_encoder_out_layers 1
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only params of
            # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
            starts_with_transformer = all(
                key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --max_sequence_length 8
                --checkpointing_steps=2
                --text_encoder_out_layers 1
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                --max_sequence_length 8
                --text_encoder_out_layers 1
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                --max_sequence_length 8
                --text_encoder_out_layers 1
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --max_sequence_length 8
                --text_encoder_out_layers 1
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
@@ -127,7 +127,7 @@ def save_model_card(
            )

    model_description = f"""
# Flux.2 DreamBooth LoRA - {repo_id}
# Flux DreamBooth LoRA - {repo_id}

<Gallery />


@@ -44,7 +44,7 @@ CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
parser.add_argument("--dit_filename", default="flux2-dev.safetensors", type=str)
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
@@ -385,9 +385,9 @@ def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) ->


def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
    if model_type == "flux2-dev":
    if model_type == "test" or model_type == "dummy-flux2":
        config = {
            "model_id": "black-forest-labs/FLUX.2-dev",
            "model_id": "diffusers-internal-dev/dummy-flux2",
            "diffusers_config": {
                "patch_size": 1,
                "in_channels": 128,
@@ -405,53 +405,6 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
        }
        rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "klein-4b":
        config = {
            "model_id": "diffusers-internal-dev/dummy0115",
            "diffusers_config": {
                "patch_size": 1,
                "in_channels": 128,
                "num_layers": 5,
                "num_single_layers": 20,
                "attention_head_dim": 128,
                "num_attention_heads": 24,
                "joint_attention_dim": 7680,
                "timestep_guidance_channels": 256,
                "mlp_ratio": 3.0,
                "axes_dims_rope": (32, 32, 32, 32),
                "rope_theta": 2000,
                "eps": 1e-6,
                "guidance_embeds": False,
            },
        }
        rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP

    elif model_type == "klein-9b":
        config = {
            "model_id": "diffusers-internal-dev/dummy0115",
            "diffusers_config": {
                "patch_size": 1,
                "in_channels": 128,
                "num_layers": 8,
                "num_single_layers": 24,
                "attention_head_dim": 128,
                "num_attention_heads": 32,
                "joint_attention_dim": 12288,
                "timestep_guidance_channels": 256,
                "mlp_ratio": 3.0,
                "axes_dims_rope": (32, 32, 32, 32),
                "rope_theta": 2000,
                "eps": 1e-6,
                "guidance_embeds": False,
            },
        }
        rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
        special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP

    else:
        raise ValueError(f"Unknown model_type: {model_type}. Choose from: flux2-dev, klein-4b, klein-9b")

    return config, rename_dict, special_keys_remap


@@ -494,14 +447,7 @@ def main(args):

    if args.dit:
        original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)

        if "klein-4b" in args.dit_filename:
            model_type = "klein-4b"
        elif "klein-9b" in args.dit_filename:
            model_type = "klein-9b"
        else:
            model_type = "flux2-dev"
        transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, model_type)
        transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
        if not args.full_pipe:
            dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
            transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
@@ -519,15 +465,8 @@ def main(args):
            "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
        )

        if_distilled = "base" not in args.dit_filename

        pipe = Flux2Pipeline(
            vae=vae,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            if_distilled=if_distilled,
            vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
        )
        pipe.save_pretrained(args.output_path)


@@ -460,7 +460,6 @@ else:
        "BriaFiboPipeline",
        "BriaPipeline",
        "ChromaImg2ImgPipeline",
        "ChromaInpaintPipeline",
        "ChromaPipeline",
        "ChronoEditPipeline",
        "CLIPImageProjection",
@@ -481,7 +480,6 @@ else:
        "EasyAnimateControlPipeline",
        "EasyAnimateInpaintPipeline",
        "EasyAnimatePipeline",
        "Flux2KleinPipeline",
        "Flux2Pipeline",
        "FluxControlImg2ImgPipeline",
        "FluxControlInpaintPipeline",
@@ -1188,7 +1186,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        BriaFiboPipeline,
        BriaPipeline,
        ChromaImg2ImgPipeline,
        ChromaInpaintPipeline,
        ChromaPipeline,
        ChronoEditPipeline,
        CLIPImageProjection,
@@ -1209,7 +1206,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        EasyAnimateControlPipeline,
        EasyAnimateInpaintPipeline,
        EasyAnimatePipeline,
        Flux2KleinPipeline,
        Flux2Pipeline,
        FluxControlImg2ImgPipeline,
        FluxControlInpaintPipeline,

@@ -40,9 +40,6 @@ from .single_file_utils import (
    convert_hunyuan_video_transformer_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ltx2_audio_vae_to_diffusers,
    convert_ltx2_transformer_to_diffusers,
    convert_ltx2_vae_to_diffusers,
    convert_ltx_transformer_checkpoint_to_diffusers,
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
@@ -179,18 +176,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
    "ZImageControlNetModel": {
        "checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
    },
    "LTX2VideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLLTX2Video": {
        "checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "AutoencoderKLLTX2Audio": {
        "checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
        "default_subfolder": "audio_vae",
    },
}



@@ -112,8 +112,7 @@ CHECKPOINT_KEY_NAMES = {
        "model.diffusion_model.transformer_blocks.27.scale_shift_table",
        "patchify_proj.weight",
        "transformer_blocks.27.scale_shift_table",
        "vae.decoder.last_scale_shift_table",  # 0.9.1, 0.9.5, 0.9.7, 0.9.8
        "vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight",  # 0.9.0
        "vae.per_channel_statistics.mean-of-means",
    ],
    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -148,11 +147,6 @@ CHECKPOINT_KEY_NAMES = {
        "net.pos_embedder.dim_spatial_range",
    ],
    "flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
    "ltx2": [
        "model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
        "vae.per_channel_statistics.mean-of-means",
        "audio_vae.per_channel_statistics.mean-of-means",
    ],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -234,7 +228,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
    "z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
    "z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
    "ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
}

# Use to configure model sample size when original config is provided
@@ -803,9 +796,6 @@ def infer_diffusers_model_type(checkpoint):
    elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
        model_type = "z-image-turbo-controlnet"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
        model_type = "ltx2-dev"

    else:
        model_type = "v1"

@@ -3930,161 +3920,3 @@ def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwa
|
||||
return converted_state_dict
|
||||
else:
|
||||
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
|
||||
|
||||
|
||||
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Transformer prefix
|
||||
"model.diffusion_model.": "",
|
||||
# Input Patchify Projections
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
# Modulation Parameters
|
||||
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
|
||||
# substrings of the other modulation parameters below
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
# Transformer Blocks
|
||||
# Per-Block Cross Attention Modulation Parameters
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
if key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
update_state_dict_inplace(converted_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Video VAE prefix
|
||||
"vae.": "",
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# Decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in LTX_2_0_VIDEO_VAE_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
update_state_dict_inplace(converted_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in LTX_2_0_VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
# Audio VAE prefix
|
||||
"audio_vae.": "",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in LTX_2_0_AUDIO_VAE_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
update_state_dict_inplace(converted_state_dict, key, new_key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -235,10 +235,6 @@ class _AttentionBackendRegistry:
|
||||
def get_active_backend(cls):
|
||||
return cls._active_backend, cls._backends[cls._active_backend]
|
||||
|
||||
@classmethod
|
||||
def set_active_backend(cls, backend: str):
|
||||
cls._active_backend = backend
|
||||
|
||||
@classmethod
|
||||
def list_backends(cls):
|
||||
return list(cls._backends.keys())
|
||||
@@ -298,12 +294,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
|
||||
_maybe_download_kernel_for_backend(backend)
|
||||
|
||||
old_backend = _AttentionBackendRegistry._active_backend
|
||||
_AttentionBackendRegistry.set_active_backend(backend)
|
||||
_AttentionBackendRegistry._active_backend = backend
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_AttentionBackendRegistry.set_active_backend(old_backend)
|
||||
_AttentionBackendRegistry._active_backend = old_backend
|
||||
|
||||
|
||||
def dispatch_attention_fn(
|
||||
@@ -352,7 +348,6 @@ def dispatch_attention_fn(
|
||||
check(**kwargs)
|
||||
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
|
||||
|
||||
return backend_fn(**kwargs)
|
||||
|
||||
|
||||
@@ -1578,6 +1573,8 @@ def _templated_context_parallel_attention(
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("Attention mask is not yet supported for templated attention.")
|
||||
if is_causal:
|
||||
raise ValueError("Causal attention is not yet supported for templated attention.")
|
||||
if enable_gqa:
|
||||
|
||||
@@ -599,7 +599,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import (
|
||||
AttentionBackendName,
|
||||
_AttentionBackendRegistry,
|
||||
_check_attention_backend_requirements,
|
||||
_maybe_download_kernel_for_backend,
|
||||
)
|
||||
@@ -608,16 +607,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
|
||||
parallel_config_set = False
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
processor = module.processor
|
||||
if getattr(processor, "_parallel_config", None) is not None:
|
||||
parallel_config_set = True
|
||||
break
|
||||
|
||||
backend = backend.lower()
|
||||
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
||||
@@ -625,17 +614,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
||||
|
||||
backend = AttentionBackendName(backend)
|
||||
if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
|
||||
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
|
||||
raise ValueError(
|
||||
f"Context parallelism is enabled but current attention backend '{backend.value}' "
|
||||
f"does not support context parallelism. "
|
||||
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`."
|
||||
)
|
||||
|
||||
_check_attention_backend_requirements(backend)
|
||||
_maybe_download_kernel_for_backend(backend)
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
@@ -644,9 +626,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
continue
|
||||
processor._attention_backend = backend
|
||||
|
||||
# Important to set the active backend so that it propagates gracefully throughout.
|
||||
_AttentionBackendRegistry.set_active_backend(backend)
|
||||
|
||||
def reset_attention_backend(self) -> None:
|
||||
"""
|
||||
Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
|
||||
@@ -1559,7 +1538,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
|
||||
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
|
||||
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
|
||||
f"calling `model.enable_parallelism()`."
|
||||
f"calling `enable_parallelism()`."
|
||||
)
|
||||
|
||||
# All modules use the same attention processor and backend. We don't need to
|
||||
|
||||
@@ -585,13 +585,7 @@ class Flux2PosEmbed(nn.Module):
|
||||
|
||||
|
||||
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 256,
|
||||
embedding_dim: int = 6144,
|
||||
bias: bool = False,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
@@ -599,24 +593,20 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
else:
|
||||
self.guidance_embedder = None
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
||||
|
||||
if guidance is not None and self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
return time_guidance_emb
|
||||
else:
|
||||
return timesteps_emb
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
|
||||
return time_guidance_emb
|
||||
|
||||
|
||||
class Flux2Modulation(nn.Module):
|
||||
@@ -708,7 +698,6 @@ class Flux2Transformer2DModel(
|
||||
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
||||
rope_theta: int = 2000,
|
||||
eps: float = 1e-6,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -719,10 +708,7 @@ class Flux2Transformer2DModel(
|
||||
|
||||
# 2. Combined timestep + guidance embedding
|
||||
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
||||
in_channels=timestep_guidance_channels,
|
||||
embedding_dim=self.inner_dim,
|
||||
bias=False,
|
||||
guidance_embeds=guidance_embeds,
|
||||
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
||||
)
|
||||
|
||||
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
||||
@@ -829,9 +815,7 @@ class Flux2Transformer2DModel(
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = self.time_guidance_embed(timestep, guidance)
|
||||
|
||||
|
||||
@@ -761,14 +761,11 @@ class QwenImageTransformer2DModel(
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
# Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
|
||||
_cp_plan = {
|
||||
"transformer_blocks.0": {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"transformer_blocks.*": {
|
||||
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"pos_embed": {
|
||||
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
|
||||
@@ -23,18 +23,10 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass(frozen=True)
|
||||
class MellonParam:
|
||||
"""
|
||||
Parameter definition for Mellon nodes.
|
||||
Parameter definition for Mellon nodes.
|
||||
|
||||
Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with
|
||||
MellonParam(name="...", label="...", type="...").
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Custom param
|
||||
MellonParam(name="my_param", label="My Param", type="float", default=0.5)
|
||||
# Output in Mellon node definition:
|
||||
# "my_param": {"label": "My Param", "type": "float", "default": 0.5}
|
||||
```
|
||||
Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with MellonParam(name="...",
|
||||
label="...", type="...").
|
||||
"""
|
||||
|
||||
name: str
|
||||
@@ -59,32 +51,14 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def image(cls) -> "MellonParam":
|
||||
"""
|
||||
Image input parameter.
|
||||
|
||||
Mellon node definition:
|
||||
"image": {"label": "Image", "type": "image", "display": "input"}
|
||||
"""
|
||||
return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"])
|
||||
|
||||
@classmethod
|
||||
def images(cls) -> "MellonParam":
|
||||
"""
|
||||
Images output parameter.
|
||||
|
||||
Mellon node definition:
|
||||
"images": {"label": "Images", "type": "image", "display": "output"}
|
||||
"""
|
||||
return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"])
|
||||
|
||||
@classmethod
|
||||
def control_image(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
Control image parameter for ControlNet.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"control_image": {"label": "Control Image", "type": "image", "display": "input"}
|
||||
"""
|
||||
return cls(
|
||||
name="control_image",
|
||||
label="Control Image",
|
||||
@@ -95,25 +69,10 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def latents(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
Latents parameter.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"latents": {"label": "Latents", "type": "latents", "display": "input"}
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"latents": {"label": "Latents", "type": "latents", "display": "output"}
|
||||
"""
|
||||
return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"])
|
||||
|
||||
@classmethod
|
||||
def image_latents(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
Image latents parameter for img2img workflows.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"image_latents": {"label": "Image Latents", "type": "latents", "display": "input"}
|
||||
"""
|
||||
return cls(
|
||||
name="image_latents",
|
||||
label="Image Latents",
|
||||
@@ -124,12 +83,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
First frame latents for video generation.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"first_frame_latents": {"label": "First Frame Latents", "type": "latents", "display": "input"}
|
||||
"""
|
||||
return cls(
|
||||
name="first_frame_latents",
|
||||
label="First Frame Latents",
|
||||
@@ -140,16 +93,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def image_latents_with_strength(cls) -> "MellonParam":
|
||||
"""
|
||||
Image latents with strength-based onChange behavior. When connected, shows strength slider; when disconnected,
|
||||
shows height/width.
|
||||
|
||||
Mellon node definition:
|
||||
"image_latents": {
|
||||
"label": "Image Latents", "type": "latents", "display": "input", "onChange": {"false": ["height",
|
||||
"width"], "true": ["strength"]}
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="image_latents",
|
||||
label="Image Latents",
|
||||
@@ -162,34 +105,16 @@ class MellonParam:
|
||||
@classmethod
|
||||
def latents_preview(cls) -> "MellonParam":
|
||||
"""
|
||||
Latents preview output for visualizing latents in the UI.
|
||||
|
||||
Mellon node definition:
|
||||
"latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"}
|
||||
`Latents Preview` is a special output parameter that is used to preview the latents in the UI.
|
||||
"""
|
||||
return cls(name="latents_preview", label="Latents Preview", type="latent", display="output")
|
||||
|
||||
@classmethod
|
||||
def embeddings(cls, display: str = "output") -> "MellonParam":
|
||||
"""
|
||||
Text embeddings parameter.
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"}
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "input"}
|
||||
"""
|
||||
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_embeds(cls, display: str = "output") -> "MellonParam":
|
||||
"""
|
||||
Image embeddings parameter for IP-Adapter workflows.
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"image_embeds": {"label": "Image Embeddings", "type": "image_embeds", "display": "output"}
|
||||
"""
|
||||
return cls(
|
||||
name="image_embeds",
|
||||
label="Image Embeddings",
|
||||
@@ -200,15 +125,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
|
||||
"""
|
||||
ControlNet conditioning scale slider.
|
||||
|
||||
Mellon node definition (default=0.5):
|
||||
"controlnet_conditioning_scale": {
|
||||
"label": "Controlnet Conditioning Scale", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0,
|
||||
"step": 0.01
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="controlnet_conditioning_scale",
|
||||
label="Controlnet Conditioning Scale",
|
||||
@@ -222,15 +138,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def control_guidance_start(cls, default: float = 0.0) -> "MellonParam":
|
||||
"""
|
||||
Control guidance start timestep.
|
||||
|
||||
Mellon node definition (default=0.0):
|
||||
"control_guidance_start": {
|
||||
"label": "Control Guidance Start", "type": "float", "default": 0.0, "min": 0.0, "max": 1.0, "step":
|
||||
0.01
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="control_guidance_start",
|
||||
label="Control Guidance Start",
|
||||
@@ -244,14 +151,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def control_guidance_end(cls, default: float = 1.0) -> "MellonParam":
|
||||
"""
|
||||
Control guidance end timestep.
|
||||
|
||||
Mellon node definition (default=1.0):
|
||||
"control_guidance_end": {
|
||||
"label": "Control Guidance End", "type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="control_guidance_end",
|
||||
label="Control Guidance End",
|
||||
@@ -265,12 +164,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def prompt(cls, default: str = "") -> "MellonParam":
|
||||
"""
|
||||
Text prompt input as textarea.
|
||||
|
||||
Mellon node definition (default=""):
|
||||
"prompt": {"label": "Prompt", "type": "string", "default": "", "display": "textarea"}
|
||||
"""
|
||||
return cls(
|
||||
name="prompt",
|
||||
label="Prompt",
|
||||
@@ -282,12 +175,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def negative_prompt(cls, default: str = "") -> "MellonParam":
|
||||
"""
|
||||
Negative prompt input as textarea.
|
||||
|
||||
Mellon node definition (default=""):
|
||||
"negative_prompt": {"label": "Negative Prompt", "type": "string", "default": "", "display": "textarea"}
|
||||
"""
|
||||
return cls(
|
||||
name="negative_prompt",
|
||||
label="Negative Prompt",
|
||||
@@ -299,12 +186,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def strength(cls, default: float = 0.5) -> "MellonParam":
|
||||
"""
|
||||
Denoising strength for img2img.
|
||||
|
||||
Mellon node definition (default=0.5):
|
||||
"strength": {"label": "Strength", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}
|
||||
"""
|
||||
return cls(
|
||||
name="strength",
|
||||
label="Strength",
|
||||
@@ -318,15 +199,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
|
||||
"""
|
||||
CFG guidance scale slider.
|
||||
|
||||
Mellon node definition (default=5.0):
|
||||
"guidance_scale": {
|
||||
"label": "Guidance Scale", "type": "float", "display": "slider", "default": 5.0, "min": 1.0, "max":
|
||||
30.0, "step": 0.1
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="guidance_scale",
|
||||
label="Guidance Scale",
|
||||
@@ -340,12 +212,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def height(cls, default: int = 1024) -> "MellonParam":
|
||||
"""
|
||||
Image height in pixels.
|
||||
|
||||
Mellon node definition (default=1024):
|
||||
"height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8}
|
||||
"""
|
||||
return cls(
|
||||
name="height",
|
||||
label="Height",
|
||||
@@ -358,26 +224,12 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def width(cls, default: int = 1024) -> "MellonParam":
|
||||
"""
|
||||
Image width in pixels.
|
||||
|
||||
Mellon node definition (default=1024):
|
||||
"width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8}
|
||||
"""
|
||||
return cls(
|
||||
name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def seed(cls, default: int = 0) -> "MellonParam":
|
||||
"""
|
||||
Random seed with randomize button.
|
||||
|
||||
Mellon node definition (default=0):
|
||||
"seed": {
|
||||
"label": "Seed", "type": "int", "default": 0, "min": 0, "max": 4294967295, "display": "random"
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="seed",
|
||||
label="Seed",
|
||||
@@ -391,14 +243,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
|
||||
"""
|
||||
Number of denoising steps slider.
|
||||
|
||||
Mellon node definition (default=25):
|
||||
"num_inference_steps": {
|
||||
"label": "Steps", "type": "int", "default": 25, "min": 1, "max": 100, "display": "slider"
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="num_inference_steps",
|
||||
label="Steps",
|
||||
@@ -412,12 +256,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
"""
|
||||
Number of video frames slider.
|
||||
|
||||
Mellon node definition (default=81):
|
||||
"num_frames": {"label": "Frames", "type": "int", "default": 81, "min": 1, "max": 480, "display": "slider"}
|
||||
"""
|
||||
return cls(
|
||||
name="num_frames",
|
||||
label="Frames",
|
||||
@@ -431,12 +269,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def layers(cls, default: int = 4) -> "MellonParam":
|
||||
"""
|
||||
Number of layers slider (for layered diffusion).
|
||||
|
||||
Mellon node definition (default=4):
|
||||
"layers": {"label": "Layers", "type": "int", "default": 4, "min": 1, "max": 10, "display": "slider"}
|
||||
"""
|
||||
return cls(
|
||||
name="layers",
|
||||
label="Layers",
|
||||
@@ -450,24 +282,15 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
"""
|
||||
Video output parameter.
|
||||
|
||||
Mellon node definition:
|
||||
"videos": {"label": "Videos", "type": "video", "display": "output"}
|
||||
"""
|
||||
return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"])
|
||||
|
||||
@classmethod
|
||||
def vae(cls) -> "MellonParam":
|
||||
"""
|
||||
VAE model input.
|
||||
VAE model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(
|
||||
name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
|
||||
@@ -476,13 +299,10 @@ class MellonParam:
|
||||
@classmethod
|
||||
def image_encoder(cls) -> "MellonParam":
|
||||
"""
|
||||
Image encoder model input.
|
||||
Image Encoder model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"image_encoder": {"label": "Image Encoder", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(
|
||||
name="image_encoder",
|
||||
@@ -495,39 +315,30 @@ class MellonParam:
|
||||
@classmethod
|
||||
def unet(cls) -> "MellonParam":
|
||||
"""
|
||||
Denoising model (UNet/Transformer) input.
|
||||
Denoising model (UNet/Transformer) info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def scheduler(cls) -> "MellonParam":
|
||||
"""
|
||||
Scheduler model input.
|
||||
Scheduler model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id'. Use
|
||||
components.get_one(model_id) to retrieve the actual scheduler.
|
||||
Contains keys like 'model_id', 'repo_id' etc. Use components.get_one(model_id) to retrieve the actual
|
||||
scheduler.
|
||||
"""
|
||||
return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def controlnet(cls) -> "MellonParam":
|
||||
"""
|
||||
ControlNet model input.
|
||||
ControlNet model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"controlnet": {"label": "ControlNet Model", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(
|
||||
name="controlnet",
|
||||
@@ -540,17 +351,12 @@ class MellonParam:
|
||||
@classmethod
|
||||
def text_encoders(cls) -> "MellonParam":
|
||||
"""
|
||||
Text encoders dict input (multiple encoders).
|
||||
Dict of text encoder model info dicts.
|
||||
|
||||
Mellon node definition:
|
||||
"text_encoders": {"label": "Text Encoders", "type": "diffusers_auto_models", "display": "input"}
|
||||
|
||||
Note: The value received is a dict of model info dicts:
|
||||
{
|
||||
'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
|
||||
'repo_id': '...'
|
||||
}
|
||||
Use components.get_one(model_id) to retrieve each model.
|
||||
Structure: {
|
||||
'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
|
||||
'repo_id': '...'
|
||||
} Use components.get_one(model_id) to retrieve each model.
|
||||
"""
|
||||
return cls(
|
||||
name="text_encoders",
|
||||
@@ -563,20 +369,15 @@ class MellonParam:
|
||||
@classmethod
|
||||
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
ControlNet bundle containing model and processed control inputs. Output from ControlNet node, input to Denoise
|
||||
node.
|
||||
ControlNet bundle containing model info and processed control inputs.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "input"}
|
||||
Structure: {
|
||||
'controlnet': {'model_id': ..., ...}, # controlnet model info dict 'control_image': ..., # processed
|
||||
control image/embeddings 'controlnet_conditioning_scale': ..., ... # other inputs expected by denoise
|
||||
blocks
|
||||
}
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "output"}
|
||||
|
||||
Note: The value is a dict containing:
|
||||
{
|
||||
'controlnet': {'model_id': ..., ...}, # controlnet model info 'control_image': ..., # processed control
|
||||
image/embeddings 'controlnet_conditioning_scale': ..., # and other denoise block inputs
|
||||
}
|
||||
Output from Controlnet node, input to Denoise node.
|
||||
"""
|
||||
return cls(
|
||||
name="controlnet_bundle",
|
||||
@@ -588,25 +389,10 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def ip_adapter(cls) -> "MellonParam":
|
||||
"""
|
||||
IP-Adapter input.
|
||||
|
||||
Mellon node definition:
|
||||
"ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"}
|
||||
"""
|
||||
return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input")
|
||||
|
||||
@classmethod
|
||||
def guider(cls) -> "MellonParam":
|
||||
"""
|
||||
Custom guider input. When connected, hides the guidance_scale slider.
|
||||
|
||||
Mellon node definition:
|
||||
"guider": {
|
||||
"label": "Guider", "type": "custom_guider", "display": "input", "onChange": {false: ["guidance_scale"],
|
||||
true: []}
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="guider",
|
||||
label="Guider",
|
||||
@@ -617,12 +403,6 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def doc(cls) -> "MellonParam":
|
||||
"""
|
||||
Documentation output for inspecting the underlying modular pipeline.
|
||||
|
||||
Mellon node definition:
|
||||
"doc": {"label": "Doc", "type": "string", "display": "output"}
|
||||
"""
|
||||
return cls(name="doc", label="Doc", type="string", display="output")
|
||||
|
||||
|
||||
@@ -635,7 +415,6 @@ DEFAULT_NODE_SPECS = {
|
||||
MellonParam.height(),
|
||||
MellonParam.seed(),
|
||||
MellonParam.num_inference_steps(),
|
||||
MellonParam.num_frames(),
|
||||
MellonParam.guidance_scale(),
|
||||
MellonParam.strength(),
|
||||
MellonParam.image_latents_with_strength(),
|
||||
@@ -890,9 +669,6 @@ class MellonPipelineConfig:
|
||||
@property
|
||||
def node_params(self) -> Dict[str, Any]:
|
||||
"""Lazily compute node_params from node_specs."""
|
||||
if self.node_specs is None:
|
||||
return self._node_params
|
||||
|
||||
params = {}
|
||||
for node_type, spec in self.node_specs.items():
|
||||
if spec is None:
|
||||
@@ -935,8 +711,7 @@ class MellonPipelineConfig:
|
||||
Note: The mellon_params are already in Mellon format when loading from JSON.
|
||||
"""
|
||||
instance = cls.__new__(cls)
|
||||
instance.node_specs = None
|
||||
instance._node_params = data.get("node_params", {})
|
||||
instance.node_params = data.get("node_params", {})
|
||||
instance.label = data.get("label", "")
|
||||
instance.default_repo = data.get("default_repo", "")
|
||||
instance.default_dtype = data.get("default_dtype", "")
|
||||
|
||||
@@ -130,7 +130,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -155,7 +155,7 @@ else:
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline", "ChromaInpaintPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
|
||||
_import_structure["cogvideo"] = [
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -598,7 +598,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .chronoedit import ChronoEditPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
@@ -678,7 +678,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .flux2 import Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
|
||||
@@ -52,7 +52,6 @@ from .flux import (
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
@@ -165,8 +164,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux-control", FluxControlPipeline),
|
||||
("flux-controlnet", FluxControlNetPipeline),
|
||||
("flux-kontext", FluxKontextPipeline),
|
||||
("flux2-klein", Flux2KleinPipeline),
|
||||
("flux2", Flux2Pipeline),
|
||||
("lumina", LuminaPipeline),
|
||||
("lumina2", Lumina2Pipeline),
|
||||
("chroma", ChromaPipeline),
|
||||
@@ -205,8 +202,6 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux-controlnet", FluxControlNetImg2ImgPipeline),
|
||||
("flux-control", FluxControlImg2ImgPipeline),
|
||||
("flux-kontext", FluxKontextPipeline),
|
||||
("flux2-klein", Flux2KleinPipeline),
|
||||
("flux2", Flux2Pipeline),
|
||||
("qwenimage", QwenImageImg2ImgPipeline),
|
||||
("qwenimage-edit", QwenImageEditPipeline),
|
||||
("qwenimage-edit-plus", QwenImageEditPlusPipeline),
|
||||
|
||||
@@ -24,7 +24,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
|
||||
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
|
||||
_import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -34,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_chroma import ChromaPipeline
|
||||
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
|
||||
from .pipeline_chroma_inpainting import ChromaInpaintPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
File diff suppressed because it is too large
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -32,7 +31,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flux2 import Flux2Pipeline
|
||||
from .pipeline_flux2_klein import Flux2KleinPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -725,8 +725,8 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -975,7 +975,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids, # B, text_seq_len, 4
|
||||
img_ids=latent_image_ids, # B, image_seq_len, 4
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
joint_attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
|
||||
@@ -1,918 +0,0 @@
|
||||
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import Flux2KleinPipeline
|
||||
|
||||
>>> pipe = Flux2KleinPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
||||
>>> # Refer to the pipeline documentation for more details.
|
||||
>>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0]
|
||||
>>> image.save("flux2_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
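As a quick, hedged sanity check of the interpolation above (illustrative values only), `mu` can be evaluated for a 1024x1024 generation, whose 128x128 latent packs into 64*64 = 4096 image tokens:

```py
# Assumes compute_empirical_mu as defined above.
image_seq_len = 4096  # 1024x1024 image -> 128x128 latent -> 64x64 packed tokens
for steps in (10, 28, 50, 200):
    print(steps, round(compute_empirical_mu(image_seq_len=image_seq_len, num_steps=steps), 3))
# mu moves linearly from the 10-step fit toward the 200-step fit as `steps` grows.
```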
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
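A hedged usage sketch of the helper above, assuming the scheduler's `set_timesteps` accepts a `sigmas` argument (as `FlowMatchEulerDiscreteScheduler` does in recent diffusers releases): either a step count or one custom schedule is supplied, never both custom options at once.

```py
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Default spacing derived from a step count:
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=28, device="cpu")

# Or a custom sigma schedule; the step count is then taken from its length:
custom_sigmas = np.linspace(1.0, 1 / 28, 28)
timesteps, n = retrieve_timesteps(scheduler, device="cpu", sigmas=custom_sigmas)

# Passing both `timesteps` and `sigmas` raises a ValueError, as guarded above.
```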
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
r"""
|
||||
The Flux2 Klein pipeline for text-to-image generation.
|
||||
|
||||
Reference:
|
||||
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
|
||||
|
||||
Args:
|
||||
transformer ([`Flux2Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLFlux2`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3ForCausalLM`]):
|
||||
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
|
||||
tokenizer (`Qwen2TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLFlux2,
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
transformer: Flux2Transformer2DModel,
|
||||
is_distilled: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
self.register_to_config(is_distilled=is_distilled)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be
# divisible by the patch size, so the VAE scale factor is multiplied by the patch size to account for this.
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 512
|
||||
self.default_sample_size = 128
|
||||
|
||||
@staticmethod
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: List[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
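To make the layer-stacking above concrete: the selected hidden states, each of shape `(B, L, D)`, are stacked into `(B, 3, L, D)` and then folded into the feature axis, giving `(B, L, 3*D)`. A shape-only sketch with stand-in tensors (no real Qwen3 weights involved; the hidden size is assumed for illustration):

```py
import torch

B, L, D = 2, 512, 4096  # assumed hidden size; shapes are illustrative only
hidden_states = {k: torch.randn(B, L, D) for k in (9, 18, 27)}

out = torch.stack([hidden_states[k] for k in (9, 18, 27)], dim=1)  # (B, 3, L, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(B, L, 3 * D)       # (B, L, 3 * D)
print(prompt_embeds.shape)  # torch.Size([2, 512, 12288])
```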
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
|
||||
def _prepare_text_ids(
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: Optional[torch.Tensor] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
|
||||
def _prepare_latent_ids(
|
||||
latents: torch.Tensor, # (B, C, H, W)
|
||||
):
|
||||
r"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Position IDs tensor of shape (B, H*W, 4). All batches share the same coordinate structure: T=0,
H=[0..H-1], W=[0..W-1], L=0.
|
||||
"""
|
||||
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1) # [0] - time dimension
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1) # [0] - layer dimension
|
||||
|
||||
# Create position IDs: (H*W, 4)
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
|
||||
# Expand to batch: (B, H*W, 4)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
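A hedged shape check for the helper above. The class name is used only to reach the static method; the equivalent helper also lives on `Flux2Pipeline` in the `main` branch, per the "Copied from" marker.

```py
import torch

latents = torch.zeros(2, 128, 4, 6)                    # (B, C, H, W)
ids = Flux2KleinPipeline._prepare_latent_ids(latents)  # (B, H*W, 4)
print(ids.shape)   # torch.Size([2, 24, 4])
print(ids[0, :3])  # tensor([[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0]])
```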
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
|
||||
def _prepare_image_ids(
|
||||
image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
||||
scale: int = 10,
|
||||
):
|
||||
r"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
This function creates a unique coordinate for every pixel/patch across all input latents, which may have
different dimensions.
|
||||
|
||||
Args:
|
||||
image_latents (List[torch.Tensor]):
|
||||
A list of image latent feature tensors, typically of shape (C, H, W).
|
||||
scale (int, optional):
|
||||
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
||||
latent is: 'scale + scale * i'. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
The combined coordinate tensor of shape (1, N_total, 4), where N_total is the sum of (H * W) over all
input latents.
|
||||
|
||||
Coordinate Components (Dimension 4):
|
||||
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
||||
- H (Height): The row index within that latent image.
|
||||
- W (Width): The column index within that latent image.
|
||||
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
||||
"""
|
||||
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
# create time offset for each reference image
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
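A hedged example of the time offsets described above: two reference latents of different sizes get T-coordinates 10 and 20 with the default `scale`.

```py
import torch

ref_latents = [torch.randn(1, 128, 4, 4), torch.randn(1, 128, 2, 3)]
ids = Flux2KleinPipeline._prepare_image_ids(ref_latents, scale=10)
print(ids.shape)              # torch.Size([1, 22, 4]) -> 16 + 6 positions
print(ids[0, 0], ids[0, 16])  # T=10 for the first reference, T=20 for the second
```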
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
|
||||
def _patchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
|
||||
def _unpatchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
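The 2x2 patchify/unpatchify pair above is a lossless rearrangement: spatial detail is moved into channels, `(B, C, H, W) -> (B, 4C, H/2, W/2)`, and moved back. A hedged round-trip check:

```py
import torch

x = torch.randn(1, 32, 8, 8)
packed = Flux2KleinPipeline._patchify_latents(x)           # (1, 128, 4, 4)
restored = Flux2KleinPipeline._unpatchify_latents(packed)  # (1, 32, 8, 8)
assert torch.equal(restored, x)
```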
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
|
||||
def _pack_latents(latents):
|
||||
"""
|
||||
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
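Packing and id-based unpacking are likewise inverses for a single latent, because `_prepare_latent_ids` emits positions in the same row-major order that `_pack_latents` uses to flatten. A hedged round-trip check:

```py
import torch

lat = torch.randn(1, 128, 4, 4)
ids = Flux2KleinPipeline._prepare_latent_ids(lat)  # (1, 16, 4)
tokens = Flux2KleinPipeline._pack_latents(lat)     # (1, 16, 128)
restored = Flux2KleinPipeline._unpack_latents_with_ids(tokens, ids)
assert torch.equal(restored, lat)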
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: Tuple[int] = (9, 18, 27),
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
hidden_states_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self._prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
||||
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
image_latents = self._patchify_latents(image_latents)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
||||
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
||||
|
||||
return image_latents
|
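The running-statistics normalization above is undone just before VAE decoding at the end of `__call__` (`latents * latents_bn_std + latents_bn_mean`). A hedged sketch with stand-in statistics (the real values come from `vae.bn.running_mean` / `running_var`):

```py
import torch

image_latents = torch.randn(1, 128, 32, 32)
bn_mean = torch.zeros(1, 128, 1, 1)                   # stand-in for vae.bn.running_mean
bn_std = torch.sqrt(torch.ones(1, 128, 1, 1) + 1e-5)  # stand-in for sqrt(running_var + eps)

normalized = (image_latents - bn_mean) / bn_std       # applied when encoding condition images
recovered = normalized * bn_std + bn_mean             # inverse, applied before decoding
assert torch.allclose(recovered, image_latents)
```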
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_latents_channels,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator: torch.Generator,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(device)
|
||||
|
||||
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
||||
return latents, latent_ids
|
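A hedged walk-through of the rounding and packing arithmetic above for a 1024x1024 request with `vae_scale_factor=8`:

```py
height = width = 1024
vae_scale_factor = 8

lat_h = 2 * (height // (vae_scale_factor * 2))  # 128
lat_w = 2 * (width // (vae_scale_factor * 2))   # 128
packed_tokens = (lat_h // 2) * (lat_w // 2)     # 4096 tokens after 2x2 packing
print(lat_h, lat_w, packed_tokens)              # 128 128 4096
```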
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: List[torch.Tensor],
|
||||
batch_size,
|
||||
generator: torch.Generator,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latent = self._encode_vae_image(image=image, generator=generator)
image_latents.append(image_latent)  # (1, 128, 32, 32)
|
||||
|
||||
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||
|
||||
# Pack each latent and concatenate
|
||||
packed_latents = []
|
||||
for latent in image_latents:
|
||||
# latent: (1, 128, 32, 32)
|
||||
packed = self._pack_latents(latent) # (1, 1024, 128)
|
||||
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
||||
packed_latents.append(packed)
|
||||
|
||||
# Concatenate all reference tokens along sequence dimension
|
||||
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
||||
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
||||
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.to(device)
|
||||
|
||||
return image_latents, image_latent_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if guidance_scale > 1.0 and self.config.is_distilled:
|
||||
logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and not self.config.is_distilled
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: Optional[float] = 4.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: Tuple[int] = (9, 18, 27),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
||||
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
||||
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
||||
latents as `image`, but if passing latents directly it is not encoded again.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models,
|
||||
`guidance_scale` is ignored.
|
||||
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline.
|
||||
If not provided, will be generated from "".
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
text_encoder_out_layers (`Tuple[int]`):
|
||||
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare text embeddings
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
text_encoder_out_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
negative_prompt = ""
|
||||
if prompt is not None and isinstance(prompt, list):
|
||||
negative_prompt = [negative_prompt] * len(prompt)
|
||||
negative_prompt_embeds, negative_text_ids = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
text_encoder_out_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
# 4. process images
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
condition_images = None
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
|
||||
condition_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
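For reference, the floor-to-multiple bookkeeping above keeps each condition image's dimensions divisible by `vae_scale_factor * 2`; a hedged numeric example with `vae_scale_factor=8`:

```py
multiple_of = 8 * 2
for w, h in [(1000, 750), (1024, 1024)]:
    print((w // multiple_of) * multiple_of, (h // multiple_of) * multiple_of)
# 992 736
# 1024 1024
```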
||||
|
||||
# 5. prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_ids = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_latents_channels=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
image_latents = None
|
||||
image_latent_ids = None
|
||||
if condition_images is not None:
|
||||
image_latents, image_latent_ids = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop
|
||||
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
||||
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
||||
self.scheduler.set_begin_index(0)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
latent_image_ids = latent_ids
|
||||
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
|
||||
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
|
||||
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input, # (B, image_seq_len, C)
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids, # B, text_seq_len, 4
|
||||
img_ids=latent_image_ids, # B, image_seq_len, 4
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred[:, : latents.size(1), :]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=negative_text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
neg_noise_pred = neg_noise_pred[:, : latents.size(1), :]
|
||||
noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpatchify_latents(latents)
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return Flux2PipelineOutput(images=image)
|
||||
@@ -260,10 +260,10 @@ class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
|
||||
all_text.append(text)
|
||||
|
||||
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device)
|
||||
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device)
|
||||
|
||||
self.text_encoder.to(device)
|
||||
generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length)
|
||||
generated_ids = generated_ids.to(device)
|
||||
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
||||
output_text = self.text_processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
|
||||
|
||||
import math
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -36,30 +36,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
sigma_min (`float`, defaults to `0.3`):
|
||||
sigma_min (`float`, *optional*, defaults to 0.3):
|
||||
Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
|
||||
sigma_max (`float`, defaults to `500`):
|
||||
sigma_max (`float`, *optional*, defaults to 500):
|
||||
Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
|
||||
sigma_data (`float`, defaults to `1.0`):
|
||||
sigma_data (`float`, *optional*, defaults to 1.0):
|
||||
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
|
||||
sigma_schedule (`str`, defaults to `"exponential"`):
|
||||
Sigma schedule to compute the `sigmas`. Must be one of `"exponential"` or `"karras"`. The exponential
|
||||
schedule was incorporated in [stabilityai/cosxl](https://huggingface.co/stabilityai/cosxl). The Karras
|
||||
schedule is introduced in the [EDM](https://huggingface.co/papers/2206.00364) paper.
|
||||
num_train_timesteps (`int`, defaults to `1000`):
|
||||
sigma_schedule (`str`, *optional*, defaults to `exponential`):
|
||||
Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). The other acceptable value is "exponential". The exponential
|
||||
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
solver_order (`int`, defaults to `2`):
|
||||
solver_order (`int`, defaults to 2):
|
||||
The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
|
||||
prediction_type (`str`, defaults to `"v_prediction"`):
|
||||
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
|
||||
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
|
||||
prediction_type (`str`, defaults to `v_prediction`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://huggingface.co/papers/2210.02303) paper).
|
||||
rho (`float`, defaults to `7.0`):
|
||||
The parameter for calculating the Karras sigma schedule from the EDM
|
||||
[paper](https://huggingface.co/papers/2206.00364).
|
||||
solver_type (`str`, defaults to `"midpoint"`):
|
||||
Solver type for the second-order solver. Must be one of `"midpoint"` or `"heun"`. The solver type slightly
|
||||
affects the sample quality, especially for a small number of steps. It is recommended to use `"midpoint"`.
|
||||
solver_type (`str`, defaults to `midpoint`):
|
||||
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
||||
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
||||
lower_order_final (`bool`, defaults to `True`):
|
||||
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
||||
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
||||
@@ -68,9 +65,8 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
||||
steps, but sometimes may result in blurring.
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. Must be one of `"zero"` or
|
||||
`"sigma_min"`. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If
|
||||
`"zero"`, the final sigma is set to 0.
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -82,16 +78,16 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_min: float = 0.3,
|
||||
sigma_max: float = 500,
|
||||
sigma_data: float = 1.0,
|
||||
sigma_schedule: Literal["exponential", "karras"] = "exponential",
|
||||
sigma_schedule: str = "exponential",
|
||||
num_train_timesteps: int = 1000,
|
||||
solver_order: int = 2,
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "v_prediction",
|
||||
prediction_type: str = "v_prediction",
|
||||
rho: float = 7.0,
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
|
||||
) -> None:
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
):
|
||||
if solver_type not in ["midpoint", "heun"]:
|
||||
if solver_type in ["logrho", "bh1", "bh2"]:
|
||||
self.register_to_config(solver_type="midpoint")
|
||||
@@ -117,40 +113,26 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self) -> float:
|
||||
"""
|
||||
The standard deviation of the initial noise distribution.
|
||||
|
||||
Returns:
|
||||
`float`:
|
||||
The initial noise sigma value computed as `sqrt(sigma_max^2 + 1)`.
|
||||
"""
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
return (self.config.sigma_max**2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self) -> Optional[int]:
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The current step index, or `None` if not yet initialized.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self) -> Optional[int]:
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The begin index, or `None` if not yet set.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
@@ -179,18 +161,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the noise level by computing a normalized timestep representation.
|
||||
|
||||
Args:
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The sigma (noise level) value to precondition.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The preconditioned noise value computed as `atan(sigma) / pi * 2`.
|
||||
"""
|
||||
def precondition_noise(self, sigma):
|
||||
if not isinstance(sigma, torch.Tensor):
|
||||
sigma = torch.tensor([sigma])
|
||||
|
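For reference, the `atan(sigma) / pi * 2` mapping mentioned in the removed docstring squashes any sigma in `[0, inf)` into a timestep in `[0, 1)`. A hedged numeric check using this scheduler's default `sigma_min`/`sigma_max`:

```py
import math

for sigma in (0.3, 1.0, 500.0):
    print(sigma, round(math.atan(sigma) / math.pi * 2, 3))
# ~0.186 at sigma_min, exactly 0.5 at sigma=1, ~0.999 at sigma_max
```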
||||
@@ -257,14 +228,12 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(
|
||||
self, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None
|
||||
) -> None:
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`, *optional*):
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
@@ -365,7 +334,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -401,19 +370,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert sigma to alpha and sigma_t values for the diffusion process.

        Args:
            sigma (`torch.Tensor`):
                The sigma (noise level) value.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing `alpha_t` (always 1 since inputs are pre-scaled) and `sigma_t` (same as the input
                sigma).
        """
    def _sigma_to_alpha_sigma_t(self, sigma):
        alpha_t = torch.tensor(1)  # Inputs are pre-scaled before going into unet, so alpha_t = 1
        sigma_t = sigma

@@ -579,7 +536,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        return step_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

@@ -600,7 +557,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
@@ -610,19 +567,20 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int` or `torch.Tensor`):
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, defaults to `True`):
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if self.num_inference_steps is None:
            raise ValueError(
@@ -744,12 +702,5 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
        return c_in

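The `c_in` term above is the familiar EDM/Karras-style input scaling: the noisy sample is multiplied by `1 / sqrt(sigma^2 + sigma_data^2)` before being fed to the model. A tiny sketch of that factor, with `sigma_data` passed in explicitly since in the scheduler it comes from `self.config.sigma_data`:

```python
import torch


def c_in_sketch(sigma: torch.Tensor, sigma_data: float) -> torch.Tensor:
    # EDM-style input preconditioning factor: 1 / sqrt(sigma^2 + sigma_data^2).
    return 1 / (sigma**2 + sigma_data**2) ** 0.5


# Example (sigma_data=1.0 is a placeholder value, not necessarily the scheduler's default):
# scaled_sample = sample * c_in_sketch(sigma, sigma_data=1.0)
```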
    def __len__(self) -> int:
        """
        Returns the number of training timesteps.

        Returns:
            `int`:
                The number of training timesteps configured for the scheduler.
        """
    def __len__(self):
        return self.config.num_train_timesteps

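Taken together, the methods touched in this file follow the standard `diffusers` scheduler workflow: call `set_timesteps` once before inference, scale the model input at each timestep, and advance the sample with `step`. A minimal sketch of that loop, assuming the public `CosineDPMSolverMultistepScheduler` API and using a dummy tensor in place of a real denoiser output:

```python
import torch

from diffusers import CosineDPMSolverMultistepScheduler

scheduler = CosineDPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=25, device="cpu")

sample = torch.randn(1, 4, 64, 64)  # placeholder latents
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)  # applies the c_in scaling shown earlier
    model_output = torch.zeros_like(model_input)          # stand-in for a real model call
    sample = scheduler.step(model_output, t, sample).prev_sample

print(len(scheduler))  # number of training timesteps, as returned by __len__
```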
@@ -632,21 +632,6 @@ class ChromaImg2ImgPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class ChromaInpaintPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class ChromaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -947,21 +932,6 @@ class EasyAnimatePipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class Flux2KleinPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class Flux2Pipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

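The classes touched in the hunks above follow the usual dummy-object pattern in `diffusers`: when optional backends are missing, a placeholder class is exposed whose constructor and loaders raise an informative error instead of an opaque failure at import time. A hedged sketch of the same pattern with a made-up class name, assuming the `DummyObject` and `requires_backends` helpers exported from `diffusers.utils`:

```python
from diffusers.utils import DummyObject, requires_backends


class MyPlaceholderPipeline(metaclass=DummyObject):  # hypothetical name, for illustration only
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # Raises a helpful error if torch/transformers are not installed.
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
```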
@@ -1,183 +0,0 @@
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM

from diffusers import (
    AutoencoderKLFlux2,
    FlowMatchEulerDiscreteScheduler,
    Flux2KleinPipeline,
    Flux2Transformer2DModel,
)

from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist


class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Flux2KleinPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained(
            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
        )

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "a dog is dancing",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }
        return inputs

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        )
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        )

    def test_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

    def test_image_input(self):
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)

        inputs["image"] = Image.new("RGB", (64, 64))
        image = pipe(**inputs).images.flatten()
        generated_slice = np.concatenate([image[:8], image[-8:]])
        # fmt: off
        expected_slice = np.array(
            [
                0.8255048, 0.66054785, 0.6643694, 0.67462724, 0.5494932, 0.3480271, 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916
            ]
        )
        # fmt: on
        assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass
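For context, `test_fused_qkv_projections` above verifies two things: fusing and unfusing the attention projections leaves the outputs (nearly) unchanged, and the fused `to_qkv` modules actually appear on the transformer. The check it uses lives in the shared test utilities (`check_qkv_fused_layers_exist`); a rough sketch of the underlying idea, with an illustrative helper name rather than the actual implementation:

```python
import torch.nn as nn


def has_fused_qkv(model: nn.Module, fused_names=("to_qkv",)) -> bool:
    # After fuse_qkv_projections(), attention blocks are expected to expose a
    # combined projection (e.g. `to_qkv`) instead of separate q/k/v layers.
    return any(hasattr(module, name) for module in model.modules() for name in fused_names)
```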