mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-03 22:31:46 +08:00
Compare commits
10 Commits
security/p
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7db9646053 | ||
|
|
760918e8a5 | ||
|
|
5adc544b79 | ||
|
|
a05c8e9452 | ||
|
|
8070f6ec54 | ||
|
|
3e53a383e1 | ||
|
|
cf6af6b4f8 | ||
|
|
3211cd9df0 | ||
|
|
119daffdf1 | ||
|
|
c346ad5eb8 |
@@ -148,5 +148,6 @@ ComponentSpec(
|
||||
- [ ] Create pipeline class with `default_blocks_name`
|
||||
- [ ] Assemble blocks in `modular_blocks_<model>.py`
|
||||
- [ ] Wire up `__init__.py` with lazy imports
|
||||
- [ ] Add `# auto_docstring` above all assembled blocks (SequentialPipelineBlocks, AutoPipelineBlocks, etc.), run `python utils/modular_auto_docstring.py --fix_and_overwrite`, and verify the generated docstrings — all parameters should have proper descriptions with no "TODO" placeholders indicating missing definitions
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test all workflows for parity with reference
|
||||
|
||||
4
.github/workflows/benchmark.yml
vendored
4
.github/workflows/benchmark.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/${{ env.BASE_PATH }}
|
||||
|
||||
16
.github/workflows/build_docker_images.yml
vendored
16
.github/workflows/build_docker_images.yml
vendored
@@ -25,14 +25,14 @@ jobs:
|
||||
if: github.event_name == 'pull_request'
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Find Changed Dockerfiles
|
||||
id: file_changes
|
||||
uses: jitterbit/get-changed-files@v1
|
||||
uses: jitterbit/get-changed-files@b17fbb00bdc0c0f63fcf166580804b4d2cdc2a42 # v1
|
||||
with:
|
||||
format: "space-delimited"
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -99,16 +99,16 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
username: ${{ env.REGISTRY }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
no-cache: true
|
||||
context: ./docker/${{ matrix.image-name }}
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
|
||||
- name: Post to a Slack channel
|
||||
id: slack
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@a88e7fa2eaee28de5a4d6142381b1fb792349b67 # main
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
|
||||
2
.github/workflows/build_documentation.yml
vendored
2
.github/workflows/build_documentation.yml
vendored
@@ -14,7 +14,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
install_libgl1: true
|
||||
|
||||
6
.github/workflows/build_pr_documentation.yml
vendored
6
.github/workflows/build_pr_documentation.yml
vendored
@@ -17,10 +17,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
build:
|
||||
needs: check-links
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
||||
2
.github/workflows/codeql.yml
vendored
2
.github/workflows/codeql.yml
vendored
@@ -10,7 +10,7 @@ on:
|
||||
jobs:
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
|
||||
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@dc6ca34688e6876c2dd18750719b44d177586c17 # v1
|
||||
permissions:
|
||||
security-events: write
|
||||
packages: read
|
||||
|
||||
46
.github/workflows/nightly_tests.yml
vendored
46
.github/workflows/nightly_tests.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -119,7 +119,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file, examples]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: torch_${{ matrix.module }}_cuda_test_reports
|
||||
path: reports
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -263,7 +263,7 @@ jobs:
|
||||
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: torch_cuda_big_gpu_test_reports
|
||||
path: reports
|
||||
@@ -280,7 +280,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -321,7 +321,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
@@ -355,7 +355,7 @@ jobs:
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -391,7 +391,7 @@ jobs:
|
||||
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: torch_cuda_${{ matrix.config.backend }}_reports
|
||||
path: reports
|
||||
@@ -408,7 +408,7 @@ jobs:
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -441,7 +441,7 @@ jobs:
|
||||
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: torch_cuda_pipeline_level_quant_reports
|
||||
path: reports
|
||||
@@ -466,7 +466,7 @@ jobs:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -474,7 +474,7 @@ jobs:
|
||||
run: mkdir -p combined_reports
|
||||
|
||||
- name: Download all test reports
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
@@ -500,7 +500,7 @@ jobs:
|
||||
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Upload consolidated report
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
with:
|
||||
name: consolidated_test_report
|
||||
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
|
||||
@@ -514,7 +514,7 @@ jobs:
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v6
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
@@ -554,7 +554,7 @@ jobs:
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v6
|
||||
# uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
@@ -570,7 +570,7 @@ jobs:
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v6
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
@@ -610,7 +610,7 @@ jobs:
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v6
|
||||
# uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
|
||||
2
.github/workflows/pr_style_bot.yml
vendored
2
.github/workflows/pr_style_bot.yml
vendored
@@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
style:
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@e000c1c89c65aee188041723456ac3a479416d4c # main
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
secrets:
|
||||
|
||||
4
.github/workflows/ssh-pr-runner.yml
vendored
4
.github/workflows/ssh-pr-runner.yml
vendored
@@ -27,12 +27,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Tailscale # In order to be able to SSH when a test fails
|
||||
uses: huggingface/tailscale-action@main
|
||||
uses: huggingface/tailscale-action@7d53c9737e53934c30290b5524d1c9b4a7c98c8a # main
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
|
||||
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
|
||||
|
||||
4
.github/workflows/trufflehog.yml
vendored
4
.github/workflows/trufflehog.yml
vendored
@@ -8,11 +8,11 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
uses: trufflesecurity/trufflehog@6bd2d14f7a4bc1e569fa3550efa7ec632a4fa67b # main
|
||||
with:
|
||||
extra_args: --results=verified,unknown
|
||||
|
||||
|
||||
4
.github/workflows/typos.yml
vendored
4
.github/workflows/typos.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.42.1
|
||||
uses: crate-ci/typos@65120634e79d8374d1aa2f27e54baa0c364fff5a # v1.42.1
|
||||
|
||||
@@ -8,7 +8,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
package_name: diffusers
|
||||
secrets:
|
||||
|
||||
@@ -112,6 +112,8 @@
|
||||
title: ModularPipeline
|
||||
- local: modular_diffusers/components_manager
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/auto_docstring
|
||||
title: Auto docstring and parameter templates
|
||||
- local: modular_diffusers/custom_blocks
|
||||
title: Building Custom Blocks
|
||||
- local: modular_diffusers/mellon
|
||||
|
||||
157
docs/source/en/modular_diffusers/auto_docstring.md
Normal file
157
docs/source/en/modular_diffusers/auto_docstring.md
Normal file
@@ -0,0 +1,157 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Auto docstring and parameter templates
|
||||
|
||||
Every [`~modular_pipelines.ModularPipelineBlocks`] has a `doc` property that is automatically generated from its `description`, `inputs`, `intermediate_outputs`, `expected_components`, and `expected_configs`. The auto docstring system keeps docstrings in sync with the block's actual interface. Parameter templates provide standardized descriptions for parameters that appear across many pipelines.
|
||||
|
||||
## Auto docstring
|
||||
|
||||
Modular pipeline blocks are composable — you can nest them, chain them in sequences, and rearrange them freely. Their docstrings follow the same pattern. When a [`~modular_pipelines.SequentialPipelineBlocks`] aggregates inputs and outputs from its sub-blocks, the documentation should update automatically without manual rewrites.
|
||||
|
||||
The `# auto_docstring` marker generates docstrings from the block's properties. Add it above a class definition to mark the class for automatic docstring generation.
|
||||
|
||||
```py
|
||||
# auto_docstring
|
||||
class FluxTextEncoderStep(SequentialPipelineBlocks):
|
||||
...
|
||||
```
|
||||
|
||||
Run the following command to generate and insert the docstrings.
|
||||
|
||||
```bash
|
||||
python utils/modular_auto_docstring.py --fix_and_overwrite
|
||||
```
|
||||
|
||||
The utility reads the block's `doc` property and inserts it as the class docstring.
|
||||
|
||||
```py
|
||||
# auto_docstring
|
||||
class FluxTextEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Text input processing step that standardizes text embeddings for the pipeline.
|
||||
|
||||
Inputs:
|
||||
prompt_embeds (`torch.Tensor`) *required*:
|
||||
text embeddings used to guide the image generation.
|
||||
...
|
||||
|
||||
Outputs:
|
||||
prompt_embeds (`torch.Tensor`):
|
||||
text embeddings used to guide the image generation.
|
||||
...
|
||||
"""
|
||||
```
|
||||
|
||||
You can also check without overwriting, or target a specific file or directory.
|
||||
|
||||
```bash
|
||||
# Check that all marked classes have up-to-date docstrings
|
||||
python utils/modular_auto_docstring.py
|
||||
|
||||
# Check a specific file or directory
|
||||
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/flux/
|
||||
```
|
||||
|
||||
If any marked class is missing a docstring, the check fails and lists the classes that need updating.
|
||||
|
||||
```
|
||||
Found the following # auto_docstring markers that need docstrings:
|
||||
- src/diffusers/modular_pipelines/flux/encoders.py: FluxTextEncoderStep at line 42
|
||||
|
||||
Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them.
|
||||
```
|
||||
|
||||
## Parameter templates
|
||||
|
||||
`InputParam` and `OutputParam` define a block's inputs and outputs. Create them directly or use `.template()` for standardized definitions of common parameters like `prompt`, `num_inference_steps`, or `latents`.
|
||||
|
||||
### InputParam
|
||||
|
||||
[`~modular_pipelines.InputParam`] describes a single input to a block.
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `name` | `str` | Name of the parameter |
|
||||
| `type_hint` | `Any` | Type annotation (e.g., `str`, `torch.Tensor`) |
|
||||
| `default` | `Any` | Default value (if not set, parameter has no default) |
|
||||
| `required` | `bool` | Whether the parameter is required |
|
||||
| `description` | `str` | Human-readable description |
|
||||
| `kwargs_type` | `str` | Group name for related parameters (e.g., `"denoiser_input_fields"`) |
|
||||
| `metadata` | `dict` | Arbitrary additional information |
|
||||
|
||||
#### Creating InputParam directly
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import InputParam
|
||||
|
||||
InputParam(
|
||||
name="guidance_scale",
|
||||
type_hint=float,
|
||||
default=7.5,
|
||||
description="Scale for classifier-free guidance.",
|
||||
)
|
||||
```
|
||||
|
||||
#### Using a template
|
||||
|
||||
```py
|
||||
InputParam.template("prompt")
|
||||
# Equivalent to:
|
||||
# InputParam(name="prompt", type_hint=str, required=True,
|
||||
# description="The prompt or prompts to guide image generation.")
|
||||
```
|
||||
|
||||
Templates set `name`, `type_hint`, `default`, `required`, and `description` automatically. Override any field or add context with the `note` parameter.
|
||||
|
||||
```py
|
||||
# Override the default value
|
||||
InputParam.template("num_inference_steps", default=28)
|
||||
|
||||
# Add a note to the description
|
||||
InputParam.template("prompt_embeds", note="batch-expanded")
|
||||
# description becomes: "text embeddings used to guide the image generation. ... (batch-expanded)"
|
||||
```
|
||||
|
||||
### OutputParam
|
||||
|
||||
[`~modular_pipelines.OutputParam`] describes a single output from a block.
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `name` | `str` | Name of the parameter |
|
||||
| `type_hint` | `Any` | Type annotation |
|
||||
| `description` | `str` | Human-readable description |
|
||||
| `kwargs_type` | `str` | Group name for related parameters |
|
||||
| `metadata` | `dict` | Arbitrary additional information |
|
||||
|
||||
#### Creating OutputParam directly
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import OutputParam
|
||||
|
||||
OutputParam(name="image_latents", type_hint=torch.Tensor, description="Encoded image latents.")
|
||||
```
|
||||
|
||||
#### Using a template
|
||||
|
||||
```py
|
||||
OutputParam.template("latents")
|
||||
|
||||
# Add a note to the description
|
||||
OutputParam.template("prompt_embeds", note="batch-expanded")
|
||||
```
|
||||
|
||||
## Available templates
|
||||
|
||||
`INPUT_PARAM_TEMPLATES` and `OUTPUT_PARAM_TEMPLATES` are defined in [modular_pipeline_utils.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/modular_pipelines/modular_pipeline_utils.py). They include common parameters like `prompt`, `image`, `num_inference_steps`, `latents`, `prompt_embeds`, and more. Refer to the source for the full list of available template names.
|
||||
|
||||
@@ -100,7 +100,7 @@ accelerate launch train_text_to_image_sdxl.py \
|
||||
|
||||
The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support SDXL training. This guide will focus on the code that is unique to the SDXL training script.
|
||||
|
||||
It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll a function to [generate the timesteps weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.
|
||||
It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll create a function to [generate the timesteps weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.
|
||||
|
||||
Within the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L572) function, in addition to loading a tokenizer, the script loads a second tokenizer and text encoder because the SDXL architecture uses two of each:
|
||||
|
||||
@@ -250,5 +250,5 @@ print(f'Inference time is {time()-start} sec after compilation')
|
||||
|
||||
Congratulations on training a SDXL model! To learn more about how to use your new model, the following guides may be helpful:
|
||||
|
||||
- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use it's refiner model, and the different types of micro-conditionings.
|
||||
- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use its refiner model, and the different types of micro-conditionings.
|
||||
- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized SDXL model with just a few example images. These two training techniques can even be combined!
|
||||
@@ -111,7 +111,7 @@ It conditions on a monocular depth estimate of the original image.
|
||||
[Paper](https://huggingface.co/papers/2302.08113)
|
||||
|
||||
MultiDiffusion Panorama defines a new generation process over a pre-trained diffusion model. This process binds together multiple diffusion generation methods that can be readily applied to generate high quality and diverse images. Results adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.
|
||||
MultiDiffusion Panorama allows to generate high-quality images at arbitrary aspect ratios (e.g., panoramas).
|
||||
MultiDiffusion Panorama allows you to generate high-quality images at arbitrary aspect ratios (e.g., panoramas).
|
||||
|
||||
## Fine-tuning your own models
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ print(np.abs(image).sum())
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because it's *state* has changed.
|
||||
The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because its *state* has changed.
|
||||
|
||||
```py
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
@@ -178,13 +178,12 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["image_processor"] = [
|
||||
"IPAdapterMaskProcessor",
|
||||
"InpaintProcessor",
|
||||
"IPAdapterMaskProcessor",
|
||||
"PixArtImageProcessor",
|
||||
"VaeImageProcessor",
|
||||
"VaeImageProcessorLDM3D",
|
||||
]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -396,6 +395,7 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["training_utils"] = ["EMAModel"]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_scipy_available()):
|
||||
|
||||
@@ -423,7 +423,9 @@ def dispatch_attention_fn(
|
||||
**attention_kwargs,
|
||||
"_parallel_config": parallel_config,
|
||||
}
|
||||
if is_torch_version(">=", "2.5.0"):
|
||||
# Equivalent to `is_torch_version(">=", "2.5.0")` — use module-level constant to avoid
|
||||
# Dynamo tracing into the lru_cache-wrapped `is_torch_version` during torch.compile.
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
kwargs["enable_gqa"] = enable_gqa
|
||||
|
||||
if _AttentionBackendRegistry._checks_enabled:
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
@@ -21,6 +23,9 @@ from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
@@ -129,6 +134,13 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
|
||||
|
||||
if not 0.0 <= eta <= 1.0:
|
||||
logger.warning(
|
||||
f"`eta` should be between 0 and 1 (inclusive), but received {eta}. "
|
||||
"A value of 0 corresponds to DDIM and 1 corresponds to DDPM. "
|
||||
"Unexpected results may occur for values outside this range."
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
|
||||
@@ -347,7 +347,17 @@ def lru_cache_unless_export(maxsize=128, typed=False):
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
if torch.compiler.is_exporting():
|
||||
compiler = getattr(torch, "compiler", None)
|
||||
is_exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
|
||||
is_compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
|
||||
|
||||
# Fallback for older builds where compiler.is_compiling is unavailable.
|
||||
if not is_compiling:
|
||||
dynamo = getattr(torch, "_dynamo", None)
|
||||
if dynamo is not None and hasattr(dynamo, "is_compiling"):
|
||||
is_compiling = dynamo.is_compiling()
|
||||
|
||||
if is_exporting or is_compiling:
|
||||
return fn(*args, **kwargs)
|
||||
return cached(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -13,24 +13,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderDC
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderDC
|
||||
|
||||
def get_autoencoder_dc_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
@@ -56,33 +66,29 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_dc_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderDC."""
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
|
||||
class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderDC."""
|
||||
|
||||
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderDC."""
|
||||
|
||||
@@ -13,24 +13,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKLWan
|
||||
|
||||
def get_autoencoder_kl_wan_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
@@ -39,54 +49,40 @@ class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.T
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
image = randn_tensor(
|
||||
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
|
||||
)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def dummy_input_tiling(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (128, 128)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderKLWan."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input_tiling
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Gradient checkpointing has not been implemented yet")
|
||||
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_forward_with_norm_groups(self):
|
||||
|
||||
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKLWan."""
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKLWan."""
|
||||
|
||||
@@ -145,3 +145,138 @@ class AutoencoderTesterMixin:
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
), "Without slicing outputs should match with the outputs when slicing is manually disabled."
|
||||
|
||||
|
||||
class NewAutoencoderTesterMixin:
|
||||
@staticmethod
|
||||
def _accepts_generator(model):
|
||||
model_sig = inspect.signature(model.forward)
|
||||
accepts_generator = "generator" in model_sig.parameters
|
||||
return accepts_generator
|
||||
|
||||
@staticmethod
|
||||
def _accepts_norm_num_groups(model_class):
|
||||
model_sig = inspect.signature(model_class.__init__)
|
||||
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
|
||||
return accepts_norm_groups
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not hasattr(model, "use_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
with torch.no_grad():
|
||||
torch.manual_seed(0)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_tiling, DecoderOutput):
|
||||
output_without_tiling = output_without_tiling.sample
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_tiling = model(**inputs_dict)[0]
|
||||
if isinstance(output_with_tiling, DecoderOutput):
|
||||
output_with_tiling = output_with_tiling.sample
|
||||
|
||||
assert (output_without_tiling.cpu() - output_with_tiling.cpu()).max() < 0.5, (
|
||||
"VAE tiling should not affect the inference results"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling_2 = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_tiling_2, DecoderOutput):
|
||||
output_without_tiling_2 = output_without_tiling_2.sample
|
||||
|
||||
assert torch.allclose(output_without_tiling.cpu(), output_without_tiling_2.cpu()), (
|
||||
"Without tiling outputs should match with the outputs when tiling is manually disabled."
|
||||
)
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
if not hasattr(self.model_class, "enable_slicing"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
if not hasattr(model, "use_slicing"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
with torch.no_grad():
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_slicing, DecoderOutput):
|
||||
output_without_slicing = output_without_slicing.sample
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_slicing = model(**inputs_dict)[0]
|
||||
if isinstance(output_with_slicing, DecoderOutput):
|
||||
output_with_slicing = output_with_slicing.sample
|
||||
|
||||
assert (output_without_slicing.cpu() - output_with_slicing.cpu()).max() < 0.5, (
|
||||
"VAE slicing should not affect the inference results"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_slicing_2 = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_slicing_2, DecoderOutput):
|
||||
output_without_slicing_2 = output_without_slicing_2.sample
|
||||
|
||||
assert torch.allclose(output_without_slicing.cpu(), output_without_slicing_2.cpu()), (
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user