Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 13:34:27 +08:00)

Compare commits: 17 commits, `shm-size...ssh-into-c`
| Author | SHA1 | Date |
|---|---|---|
|  | 6cc9aa1e5b |  |
|  | 1441b1156f |  |
|  | 588fb5c105 |  |
|  | 5829c8c6c6 |  |
|  | 2862617b2b |  |
|  | eb24e4bdb2 |  |
|  | e02ec27e51 |  |
|  | a41e4c506b |  |
|  | 12625c1c9c |  |
|  | 6de06fc3aa |  |
|  | c1dc2ae619 |  |
|  | e15a8e7f17 |  |
|  | c2fbf8da02 |  |
|  | 0f09b01ab3 |  |
|  | f6cfe0a1e5 |  |
|  | e87bf62940 |  |
|  | 3b37fefee9 |  |
2 changes: .github/ISSUE_TEMPLATE/bug-report.yml (vendored)

@@ -73,7 +73,7 @@ body:
       - ControlNet @sayakpaul @yiyixuxu @DN6
       - T2I Adapter @sayakpaul @yiyixuxu @DN6
       - IF @DN6
-      - Text-to-Video / Video-to-Video @DN6 @sayakpaul
+      - Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w
       - Wuerstchen @DN6
       - Other: @yiyixuxu @DN6
       - Improving generation quality: @asomoza
1 change: .github/PULL_REQUEST_TEMPLATE.md (vendored)

@@ -49,6 +49,7 @@ Core library:

 Integrations:

 - deepspeed: HF Trainer/Accelerate: @SunMarc
+- PEFT: @sayakpaul @BenjaminBossan

 HF projects:
100 changes: .github/workflows/push_check.yml (vendored, file deleted)

@@ -1,100 +0,0 @@
name: Slow Test Memory Checks

on:
  push:
    branches: [ shm-size ]

env:
  DIFFUSERS_IS_CI: yes
  HF_HUB_ENABLE_HF_TRANSFER: 1
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  PYTEST_TIMEOUT: 600
  RUN_SLOW: yes
  PIPELINE_USAGE_CUTOFF: 50000

jobs:
  setup_torch_cuda_pipeline_matrix:
    name: Setup Torch Pipelines CUDA Slow Tests Matrix
    runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
    container:
      image: diffusers/diffusers-pytorch-cpu
    outputs:
      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Fetch Pipeline Matrix
        id: fetch_pipeline_matrix
        run: |
          matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
          echo $matrix
          echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
      - name: Pipeline Tests Artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v2
        with:
          name: test-pipelines.json
          path: reports

  torch_pipelines_cuda_tests:
    name: Torch Pipelines CUDA Slow Tests
    needs: setup_torch_cuda_pipeline_matrix
    strategy:
      max-parallel: 4
      fail-fast: false
      matrix:
        module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
    runs-on: [single-gpu, nvidia-gpu, t4, ci]
    container:
      image: diffusers/diffusers-pytorch-cuda
      options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: NVIDIA-SMI
        run: |
          nvidia-smi
      - name: Install dependencies
        run: |
          apt-get update && apt-get install libsndfile1-dev libgl1 -y
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e [quality,test]
          python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
          python -m uv pip install hf_transfer
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Slow PyTorch CUDA checkpoint tests on Ubuntu
        env:
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
          CUBLAS_WORKSPACE_CONFIG: :16:8
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
            -s -v -k "not Flax and not Onnx" \
            --make-reports=tests_pipeline_${{ matrix.module }}_cuda \
            tests/pipelines/${{ matrix.module }}
      - name: Failure short reports
        if: ${{ failure() }}
        run: |
          cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
          cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v2
        with:
          name: pipeline_${{ matrix.module }}_test_reports
          path: reports
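The deleted workflow built its test matrix by capturing the stdout of `utils/fetch_torch_cuda_pipeline_test_matrix.py` into `$GITHUB_OUTPUT`, which the second job consumed via `fromJson(...)`. A minimal sketch of that contract; the module list here is a hypothetical placeholder, not the script's real popularity-based selection against `PIPELINE_USAGE_CUTOFF`:

```python
import json

# Hypothetical stand-in for the real selection logic, which ranks pipeline
# test modules by usage before applying PIPELINE_USAGE_CUTOFF.
def fetch_pipeline_matrix() -> list:
    return ["stable_diffusion", "controlnet", "latte"]

if __name__ == "__main__":
    # The workflow captured this single JSON line and wrote it to
    # $GITHUB_OUTPUT so that `matrix: module: ${{ fromJson(...) }}`
    # could fan it out over GPU runners.
    print(json.dumps(fetch_pipeline_matrix()))
```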
39 changes: .github/workflows/ssh-pr-runner.yml (vendored, new file)

@@ -0,0 +1,39 @@
name: SSH into PR runners

on:
  workflow_dispatch:
    inputs:
      docker_image:
        description: 'Name of the Docker image'
        required: true

env:
  IS_GITHUB_CI: "1"
  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
  HF_HOME: /mnt/cache
  DIFFUSERS_IS_CI: yes
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  RUN_SLOW: yes

jobs:
  ssh_runner:
    name: "SSH"
    runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
    container:
      image: ${{ github.event.inputs.docker_image }}
      options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --privileged

    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Tailscale # In order to be able to SSH when a test fails
        uses: huggingface/tailscale-action@main
        with:
          authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
          slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
          slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
          waitForSSH: true
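Because the new workflow is `workflow_dispatch`-only, it must be started from the Actions UI or via the REST API. A sketch of triggering it programmatically; the token variable and `ref` are assumptions, the endpoint is GitHub's standard workflow-dispatch API:

```python
import os

import requests

# Assumes a personal access token with `workflow` scope in GITHUB_TOKEN.
resp = requests.post(
    "https://api.github.com/repos/huggingface/diffusers/actions/workflows/ssh-pr-runner.yml/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
    },
    # `docker_image` matches the required input declared in the workflow above.
    json={"ref": "main", "inputs": {"docker_image": "diffusers/diffusers-pytorch-cpu"}},
)
resp.raise_for_status()  # GitHub returns 204 No Content on success
```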
2 changes: .github/workflows/ssh-runner.yml (vendored)

@@ -1,4 +1,4 @@
-name: SSH into runners
+name: SSH into GPU runners

 on:
   workflow_dispatch:
@@ -16,24 +16,24 @@ RUN apt install -y bash \
     ca-certificates \
     libsndfile1-dev \
     libgl1 \
-    python3.9 \
-    python3.9-dev \
+    python3.10 \
+    python3.10-dev \
     python3-pip \
-    python3.9-venv && \
+    python3.10-venv && \
     rm -rf /var/lib/apt/lists

 # make sure to use venv
-RUN python3.9 -m venv /opt/venv
+RUN python3.10 -m venv /opt/venv
 ENV PATH="/opt/venv/bin:$PATH"

 # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3.9 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
-    python3.9 -m uv pip install --no-cache-dir \
+RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+    python3.10 -m uv pip install --no-cache-dir \
         torch \
         torchvision \
         torchaudio \
         invisible_watermark && \
-    python3.9 -m pip install --no-cache-dir \
+    python3.10 -m pip install --no-cache-dir \
         accelerate \
         datasets \
         hf-doc-builder \
@@ -16,6 +16,7 @@ RUN apt install -y bash \
     ca-certificates \
     libsndfile1-dev \
     python3.10 \
     python3.10-dev \
     python3-pip \
+    libgl1 \
     python3.10-venv && \
@@ -17,6 +17,7 @@ RUN apt install -y bash \
     libsndfile1-dev \
+    libgl1 \
     python3.10 \
     python3.10-dev \
     python3-pip \
     python3.10-venv && \
     rm -rf /var/lib/apt/lists
@@ -17,6 +17,7 @@ RUN apt install -y bash \
     libsndfile1-dev \
+    libgl1 \
     python3.10 \
     python3.10-dev \
     python3-pip \
     python3.10-venv && \
     rm -rf /var/lib/apt/lists
@@ -332,6 +332,8 @@
       title: Latent Consistency Models
     - local: api/pipelines/latent_diffusion
       title: Latent Diffusion
+    - local: api/pipelines/latte
+      title: Latte
     - local: api/pipelines/ledits_pp
       title: LEDITS++
     - local: api/pipelines/lumina
75 changes: docs/source/en/api/pipelines/latte.md (new file)

@@ -0,0 +1,75 @@
<!-- # Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Latte

![latte text-to-video]()

[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.

The abstract from the paper is:

*We propose a novel Latent Diffusion Transformer, namely Latte, for video generation. Latte first extracts spatio-temporal tokens from input videos and then adopts a series of Transformer blocks to model video distribution in the latent space. In order to model a substantial number of tokens extracted from videos, four efficient variants are introduced from the perspective of decomposing the spatial and temporal dimensions of input videos. To improve the quality of generated videos, we determine the best practices of Latte through rigorous experimental analysis, including video clip patch embedding, model variants, timestep-class information injection, temporal positional embedding, and learning strategies. Our comprehensive evaluation demonstrates that Latte achieves state-of-the-art performance across four standard video generation datasets, i.e., FaceForensics, SkyTimelapse, UCF101, and Taichi-HD. In addition, we extend Latte to text-to-video generation (T2V) task, where Latte achieves comparable results compared to recent T2V models. We strongly believe that Latte provides valuable insights for future research on incorporating Transformers into diffusion models for video generation.*

**Highlights**: Latte is a latent diffusion transformer proposed as a backbone for modeling different modalities (trained for text-to-video generation here). It achieves state-of-the-art performance across four standard video benchmarks: [FaceForensics](https://arxiv.org/abs/1803.09179), [SkyTimelapse](https://arxiv.org/abs/1709.07592), [UCF101](https://arxiv.org/abs/1212.0402) and [Taichi-HD](https://arxiv.org/abs/2003.00196). To prepare and download the datasets for evaluation, please refer to the [Latte datasets and evaluation guide](https://github.com/Vchitect/Latte/blob/main/docs/datasets_evaluation.md).

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

### Inference

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

First, load the pipeline:

```python
import torch
from diffusers import LattePipeline

pipeline = LattePipeline.from_pretrained(
    "maxin-cn/Latte-1", torch_dtype=torch.float16
).to("cuda")
```

Then change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:

```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```

Finally, compile the components and run inference:

```python
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)

video = pipeline(prompt="A dog wearing sunglasses floating in space, surreal, nebulae in background").frames[0]
```

The [benchmark](https://gist.github.com/a-r-r-o-w/4e1694ca46374793c0361d740a99ff19) results on an 80GB A100 machine are:

```
Without torch.compile(): Average inference time: 16.246 seconds.
With torch.compile(): Average inference time: 14.573 seconds.
```
## LattePipeline

[[autodoc]] LattePipeline
  - all
  - __call__
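The pipeline call in the new doc returns a list of frames; a natural follow-up, not part of the committed file, is exporting them with the `export_to_gif` utility that diffusers ships (the output filename is an arbitrary choice):

```python
from diffusers.utils import export_to_gif

# `video` is the list of PIL frames returned by the LattePipeline call above.
export_to_gif(video, "latte_output.gif")
```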
@@ -71,7 +71,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             **kwargs:
                 Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
-                cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
+                cache_dir, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.

             alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
                 would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
@@ -86,7 +86,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
         """
         # Default kwargs from DiffusionPipeline
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -124,7 +123,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             config_dict = DiffusionPipeline.load_config(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
@@ -160,7 +158,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
                 else snapshot_download(
                     pretrained_model_name_or_path,
                     cache_dir=cache_dir,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
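For context, this community pipeline is loaded via `custom_pipeline`; a usage sketch under the documented pattern, with the model ids and alpha chosen for illustration:

```python
from diffusers import DiffusionPipeline

# Load the community checkpoint-merger pipeline on top of a base model.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="checkpoint_merger",
)
# Per the docstring above, a higher alpha means the first checkpoint
# influences the merged result less.
merged = pipe.merge(
    ["CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"],
    interp="sigmoid",
    alpha=0.4,
)
```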
@@ -267,7 +267,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
     def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -283,7 +282,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
                 weights_name=weight_name,
                 cache_dir=cache_dir,
                 force_download=force_download,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -783,7 +783,6 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -795,7 +794,6 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
             else snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -783,7 +783,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -795,7 +794,6 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
             else snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -695,7 +695,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
     @validate_hf_hub_args
     def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -707,7 +706,6 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
             else snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -1195,7 +1195,7 @@ def main(args):

     # Resolve the c parameter for the Pseudo-Huber loss
     if args.huber_c is None:
-        args.huber_c = 0.00054 * args.resolution * math.sqrt(unet.config.in_channels)
+        args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)

     # Get current number of discretization steps N according to our discretization curriculum
     current_discretization_steps = get_discretization_steps(
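The `c` parameter resolved above feeds the Pseudo-Huber objective used in consistency training; a minimal sketch of that loss, with the tensor shapes and `in_channels=4` chosen purely for illustration:

```python
import math

import torch

def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor, c: float) -> torch.Tensor:
    # sqrt((pred - target)^2 + c^2) - c: behaves like L2 near zero error
    # and like L1 for large errors, with c controlling the transition point.
    return torch.sqrt((pred - target) ** 2 + c**2) - c

# c is resolved exactly as in the hunk above.
resolution, in_channels = 512, 4
c = 0.00054 * resolution * math.sqrt(in_channels)
loss = pseudo_huber_loss(torch.randn(2, 4, 64, 64), torch.randn(2, 4, 64, 64), c).mean()
```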
@@ -310,9 +310,6 @@ class ConfigMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -343,7 +340,6 @@ class ConfigMixin:
         local_dir = kwargs.pop("local_dir", None)
         local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -386,7 +382,6 @@ class ConfigMixin:
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
-            resume_download=resume_download,
             local_files_only=local_files_only,
             token=token,
             user_agent=user_agent,
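The practical upshot of this family of hunks: `resume_download` disappears from the whole call chain because `huggingface_hub` now resumes interrupted downloads by default, so callers simply stop passing it. A sketch, with an illustrative repo and filename:

```python
from huggingface_hub import hf_hub_download

# No resume_download argument: interrupted downloads resume automatically.
path = hf_hub_download(
    repo_id="runwayml/stable-diffusion-v1-5",
    filename="v1-5-pruned-emaonly.safetensors",
)
```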
@@ -90,9 +90,7 @@ class IPAdapterMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -135,7 +133,6 @@ class IPAdapterMixin:
         # Load the main state dict first.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -171,7 +168,6 @@ class IPAdapterMixin:
                 weights_name=weight_name,
                 cache_dir=cache_dir,
                 force_download=force_download,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
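`load_ip_adapter` is the public entry point these kwargs serve; typical usage, with checkpoint names following the documented `h94/IP-Adapter` layout:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Download and attach the IP-Adapter weights; kwargs like cache_dir and
# proxies flow through the loader shown in the hunk above.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.set_ip_adapter_scale(0.6)
```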
@@ -170,9 +170,7 @@ class LoraLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -194,7 +192,6 @@ class LoraLoaderMixin:
         # UNet and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -235,7 +232,6 @@ class LoraLoaderMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -261,7 +257,6 @@ class LoraLoaderMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
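`load_lora_weights` sits on top of this same download path, falling back from the safetensors weight name to the `.bin` one exactly as the two branches above show. A usage sketch; the LoRA repo id and weight name are illustrative:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# First tries LORA_WEIGHT_NAME_SAFE (*.safetensors), then LORA_WEIGHT_NAME (*.bin).
pipe.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors")
```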
@@ -1427,9 +1422,7 @@ class SD3LoraLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1450,7 +1443,6 @@ class SD3LoraLoaderMixin:
         # UNet and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -1481,7 +1473,6 @@ class SD3LoraLoaderMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -1503,7 +1494,6 @@ class SD3LoraLoaderMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -242,7 +242,6 @@ def _download_diffusers_model_config_from_hub(
     revision,
     proxies,
     force_download=None,
-    resume_download=None,
     local_files_only=None,
     token=None,
 ):
@@ -253,7 +252,6 @@ def _download_diffusers_model_config_from_hub(
         revision=revision,
         proxies=proxies,
         force_download=force_download,
-        resume_download=resume_download,
         local_files_only=local_files_only,
         token=token,
         allow_patterns=allow_patterns,
@@ -288,9 +286,7 @@ class FromSingleFileMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -352,7 +348,6 @@ class FromSingleFileMixin:
             deprecate("original_config_file", "1.0.0", deprecation_message)
             original_config = original_config_file

-        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
@@ -382,7 +377,6 @@ class FromSingleFileMixin:

         checkpoint = load_single_file_checkpoint(
             pretrained_model_link_or_path,
-            resume_download=resume_download,
             force_download=force_download,
             proxies=proxies,
             token=token,
@@ -412,7 +406,6 @@ class FromSingleFileMixin:
                     revision=revision,
                     proxies=proxies,
                     force_download=force_download,
-                    resume_download=resume_download,
                     local_files_only=local_files_only,
                     token=token,
                 )
@@ -435,7 +428,6 @@ class FromSingleFileMixin:
                     revision=revision,
                     proxies=proxies,
                     force_download=force_download,
-                    resume_download=resume_download,
                     local_files_only=False,
                     token=token,
                 )
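`from_single_file` is the user-facing API these kwargs belong to; a usage sketch, with the checkpoint URL illustrative of the single-file format this mixin handles:

```python
from diffusers import StableDiffusionPipeline

# Loads a monolithic .safetensors checkpoint instead of a diffusers folder layout;
# download kwargs (proxies, token, force_download) pass through the hunks above.
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
)
```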
@@ -137,9 +137,7 @@ class FromOriginalModelMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -188,7 +186,6 @@ class FromOriginalModelMixin:
                 "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
             )

-        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
@@ -203,7 +200,6 @@ class FromOriginalModelMixin:
         else:
             checkpoint = load_single_file_checkpoint(
                 pretrained_model_link_or_path_or_dict,
-                resume_download=resume_download,
                 force_download=force_download,
                 proxies=proxies,
                 token=token,
@@ -313,7 +313,6 @@ def _is_model_weights_in_cached_folder(cached_folder, name):

 def load_single_file_checkpoint(
     pretrained_model_link_or_path,
-    resume_download=False,
     force_download=False,
     proxies=None,
     token=None,
@@ -331,7 +330,6 @@ def load_single_file_checkpoint(
             weights_name=weights_name,
             force_download=force_download,
             cache_dir=cache_dir,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -38,7 +38,6 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", None)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", None)
     token = kwargs.pop("token", None)
@@ -72,7 +71,6 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
                     weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -93,7 +91,6 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
                     weights_name=weight_name or TEXT_INVERSION_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -308,9 +305,7 @@ class TextualInversionLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
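`load_textual_inversion` wraps `load_textual_inversion_state_dicts` from the previous hunk; its documented usage looks like this (the concept repo is the standard docs example):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Adds the learned <cat-toy> token to the tokenizer and text encoder.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipe("A <cat-toy> train", num_inference_steps=50).images[0]
```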
@@ -97,9 +97,7 @@ class UNet2DConditionLoadersMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -140,7 +138,6 @@ class UNet2DConditionLoadersMixin:
         """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -174,7 +171,6 @@ class UNet2DConditionLoadersMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -194,7 +190,6 @@ class UNet2DConditionLoadersMixin:
                     weights_name=weight_name or LORA_WEIGHT_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -191,7 +191,6 @@ def _fetch_index_file(
     cache_dir,
     variant,
     force_download,
-    resume_download,
     proxies,
     local_files_only,
     token,
@@ -216,7 +215,6 @@ def _fetch_index_file(
             weights_name=index_file_in_repo,
             cache_dir=cache_dir,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -245,9 +245,7 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +294,6 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -316,7 +313,6 @@ class FlaxModelMixin(PushToHubMixin):
             cache_dir=cache_dir,
             return_unused_kwargs=True,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -362,7 +358,6 @@ class FlaxModelMixin(PushToHubMixin):
                 cache_dir=cache_dir,
                 force_download=force_download,
                 proxies=proxies,
-                resume_download=resume_download,
                 local_files_only=local_files_only,
                 token=token,
                 user_agent=user_agent,
@@ -434,9 +434,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -518,7 +515,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         from_flax = kwargs.pop("from_flax", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", None)
@@ -619,7 +615,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             return_unused_kwargs=True,
             return_commit_hash=True,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -641,7 +636,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             cache_dir=cache_dir,
             variant=variant,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -663,7 +657,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     weights_name=FLAX_WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    resume_download=resume_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
@@ -685,7 +678,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         index_file,
                         cache_dir=cache_dir,
                         proxies=proxies,
-                        resume_download=resume_download,
                         local_files_only=local_files_only,
                         token=token,
                         user_agent=user_agent,
@@ -700,7 +692,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                             weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                             cache_dir=cache_dir,
                             force_download=force_download,
-                            resume_download=resume_download,
                             proxies=proxies,
                             local_files_only=local_files_only,
                             token=token,
@@ -724,7 +715,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         weights_name=_add_variant(WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
                         force_download=force_download,
-                        resume_download=resume_download,
                         proxies=proxies,
                         local_files_only=local_files_only,
                         token=token,
@@ -1177,7 +1167,6 @@ class LegacyModelMixin(ModelMixin):

         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -1200,7 +1189,6 @@ class LegacyModelMixin(ModelMixin):
             return_unused_kwargs=True,
             return_commit_hash=True,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -139,6 +139,18 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         for module in self.children():
             fn_recursive_feed_forward(module, chunk_size, dim)

+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
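The new method is the counterpart of `enable_forward_chunking`, whose tail is visible as context at the top of the hunk. A hedged sketch of toggling it, assuming the matching enable method on `SD3Transformer2DModel` and an illustrative checkpoint id:

```python
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)
# Trade speed for memory by chunking each feed-forward over one dimension...
transformer.enable_forward_chunking(chunk_size=1, dim=0)
# ...and restore default behavior with the method added in this hunk,
# which recursively resets chunk_size to None.
transformer.disable_forward_chunking()
```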
@@ -391,8 +391,8 @@ class AuraFlowPipeline(DiffusionPipeline):
         sigmas: List[float] = None,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
+        height: Optional[int] = 1024,
+        width: Optional[int] = 1024,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
@@ -415,9 +415,9 @@ class AuraFlowPipeline(DiffusionPipeline):
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
             height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 512 by default.
+                The height in pixels of the generated image. This is set to 1024 by default for best results.
             width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 512 by default.
+                The width in pixels of the generated image. This is set to 1024 by default for best results.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
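With the new defaults, a plain call renders at 1024x1024 without passing `height` or `width`. A usage sketch; the checkpoint id is an assumption:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
# height/width now default to 1024, so neither needs to be passed explicitly.
image = pipe(prompt="a photo of a capybara wearing a top hat", guidance_scale=3.5).images[0]
```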
@@ -18,6 +18,7 @@ from collections import OrderedDict
 from huggingface_hub.utils import validate_hf_hub_args

 from ..configuration_utils import ConfigMixin
+from .aura_flow import AuraFlowPipeline
 from .controlnet import (
     StableDiffusionControlNetImg2ImgPipeline,
     StableDiffusionControlNetInpaintPipeline,
@@ -45,6 +46,7 @@ from .kandinsky2_2 import (
     KandinskyV22Pipeline,
 )
 from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
+from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
 from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
 from .pag import (
     StableDiffusionControlNetPAGPipeline,
@@ -63,6 +65,7 @@ from .stable_diffusion import (
 )
 from .stable_diffusion_3 import (
     StableDiffusion3Img2ImgPipeline,
+    StableDiffusion3InpaintPipeline,
     StableDiffusion3Pipeline,
 )
 from .stable_diffusion_xl import (
@@ -94,6 +97,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
         ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
+        ("auraflow", AuraFlowPipeline),
+        ("kolors", KolorsPipeline),
     ]
 )
@@ -110,6 +115,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
         ("lcm", LatentConsistencyModelImg2ImgPipeline),
+        ("kolors", KolorsImg2ImgPipeline),
     ]
 )
@@ -117,6 +123,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
     [
         ("stable-diffusion", StableDiffusionInpaintPipeline),
         ("stable-diffusion-xl", StableDiffusionXLInpaintPipeline),
+        ("stable-diffusion-3", StableDiffusion3InpaintPipeline),
         ("if", IFInpaintingPipeline),
         ("kandinsky", KandinskyInpaintCombinedPipeline),
         ("kandinsky22", KandinskyV22InpaintCombinedPipeline),
@@ -256,9 +263,7 @@ class AutoPipelineForText2Image(ConfigMixin):
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -333,7 +338,6 @@ class AutoPipelineForText2Image(ConfigMixin):
         """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -342,7 +346,6 @@ class AutoPipelineForText2Image(ConfigMixin):
         load_config_kwargs = {
             "cache_dir": cache_dir,
             "force_download": force_download,
-            "resume_download": resume_download,
             "proxies": proxies,
             "token": token,
             "local_files_only": local_files_only,
@@ -547,9 +550,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -624,7 +625,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
         """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -633,7 +633,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
         load_config_kwargs = {
             "cache_dir": cache_dir,
             "force_download": force_download,
-            "resume_download": resume_download,
             "proxies": proxies,
             "token": token,
             "local_files_only": local_files_only,
@@ -843,9 +842,7 @@ class AutoPipelineForInpainting(ConfigMixin):
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -920,7 +917,6 @@ class AutoPipelineForInpainting(ConfigMixin):
         """
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
         local_files_only = kwargs.pop("local_files_only", False)
@@ -929,7 +925,6 @@ class AutoPipelineForInpainting(ConfigMixin):
         load_config_kwargs = {
             "cache_dir": cache_dir,
             "force_download": force_download,
-            "resume_download": resume_download,
             "proxies": proxies,
             "token": token,
             "local_files_only": local_files_only,
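Registering the new entries in these mappings is what lets the auto classes resolve them from a model index; for example, once `("kolors", KolorsPipeline)` is in `AUTO_TEXT2IMAGE_PIPELINES_MAPPING`, loading works like this (the repo id is an assumption):

```python
import torch
from diffusers import AutoPipelineForText2Image

# "kolors" resolves to KolorsPipeline through the mapping added above.
pipe = AutoPipelineForText2Image.from_pretrained(
    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
image = pipe("a cute corgi wearing a space suit").images[0]
```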
@@ -254,9 +254,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -316,7 +314,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
         ```
         """
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -332,7 +329,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
             config_dict = cls.load_config(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -363,7 +359,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
             cached_folder = snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -435,7 +435,6 @@ def _load_empty_model(
         return_unused_kwargs=True,
         return_commit_hash=True,
         force_download=kwargs.pop("force_download", False),
-        resume_download=kwargs.pop("resume_download", None),
         proxies=kwargs.pop("proxies", None),
         local_files_only=kwargs.pop("local_files_only", False),
         token=kwargs.pop("token", None),
@@ -454,7 +453,6 @@ def _load_empty_model(
             cached_folder,
             subfolder=name,
             force_download=kwargs.pop("force_download", False),
-            resume_download=kwargs.pop("resume_download", None),
            proxies=kwargs.pop("proxies", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("token", None),
@@ -544,7 +542,6 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
                 torch_dtype=torch_dtype,
                 cached_folder=kwargs.get("cached_folder", None),
                 force_download=kwargs.get("force_download", None),
-                resume_download=kwargs.get("resume_download", None),
                 proxies=kwargs.get("proxies", None),
                 local_files_only=kwargs.get("local_files_only", None),
                 token=kwargs.get("token", None),
@@ -533,9 +533,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -625,7 +623,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         ```
         """
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
@@ -702,7 +699,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             cached_folder = cls.download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
@@ -842,7 +838,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 torch_dtype=torch_dtype,
                 cached_folder=cached_folder,
                 force_download=force_download,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -910,7 +905,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
             load_kwargs = {
                 "cache_dir": cache_dir,
-                "resume_download": resume_download,
                 "force_download": force_download,
                 "proxies": proxies,
                 "local_files_only": local_files_only,
@@ -1216,9 +1210,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1271,7 +1263,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

         """
         cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
@@ -1311,7 +1302,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 revision=revision,
                 proxies=proxies,
                 force_download=force_download,
-                resume_download=resume_download,
                 token=token,
             )
@@ -1500,7 +1490,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             cached_folder = snapshot_download(
                 pretrained_model_name,
                 cache_dir=cache_dir,
-                resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
@@ -1523,7 +1512,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             for connected_pipe_repo_id in connected_pipes:
                 download_kwargs = {
                     "cache_dir": cache_dir,
-                    "resume_download": resume_download,
                     "force_download": force_download,
                     "proxies": proxies,
                     "local_files_only": local_files_only,
@@ -22,6 +22,7 @@ import torch

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate, logging
+from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
-            algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
-            implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
-            recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
-            Stable Diffusion.
+            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
+            type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
+            `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
+            paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
+            sampling like in Stable Diffusion.
         solver_type (`str`, defaults to `midpoint`):
             Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
             sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0

         # settings for DPM-Solver
-        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
+        if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
             if algorithm_type == "deis":
                 self.register_to_config(algorithm_type="dpmsolver++")
             else:
@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             else:
                 raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

-        if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
+        if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
             raise ValueError(
                 f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
             )
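The widened guard keeps `final_sigmas_type="zero"` available for both `dpmsolver++` variants while still rejecting it for plain `dpmsolver`. A sketch of the passing and failing configurations (values illustrative):

```python
from diffusers import DPMSolverSinglestepScheduler

# OK: the sde variant now passes the widened guard.
DPMSolverSinglestepScheduler(algorithm_type="sde-dpmsolver++", final_sigmas_type="zero")

# Plain `dpmsolver` still rejects `final_sigmas_type="zero"`.
try:
    DPMSolverSinglestepScheduler(algorithm_type="dpmsolver", final_sigmas_type="zero")
except ValueError as e:
    print(e)
```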
@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
         )
         # DPM-Solver++ needs to solve an integral of the data prediction model.
-        if self.config.algorithm_type == "dpmsolver++":
+        if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
             if self.config.prediction_type == "epsilon":
                 # DPM-Solver and DPM-Solver++ only need the "mean" output.
-                if self.config.variance_type in ["learned_range"]:
+                if self.config.variance_type in ["learned", "learned_range"]:
                     model_output = model_output[:, :3]
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 x0_pred = self._threshold_sample(x0_pred)

             return x0_pred

         # DPM-Solver needs to solve an integral of the noise prediction model.
         elif self.config.algorithm_type == "dpmsolver":
             if self.config.prediction_type == "epsilon":
                 # DPM-Solver and DPM-Solver++ only need the "mean" output.
-                if self.config.variance_type in ["learned_range"]:
-                    model_output = model_output[:, :3]
-                return model_output
+                if self.config.variance_type in ["learned", "learned_range"]:
+                    epsilon = model_output[:, :3]
+                else:
+                    epsilon = model_output
             elif self.config.prediction_type == "sample":
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 epsilon = (sample - alpha_t * model_output) / sigma_t
-                return epsilon
             elif self.config.prediction_type == "v_prediction":
                 sigma = self.sigmas[self.step_index]
                 alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
                 epsilon = alpha_t * model_output + sigma_t * sample
-                return epsilon
             else:
                 raise ValueError(
                     f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                     " `v_prediction` for the DPMSolverSinglestepScheduler."
                 )

+            if self.config.thresholding:
+                alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
+                x0_pred = (sample - sigma_t * epsilon) / alpha_t
+                x0_pred = self._threshold_sample(x0_pred)
+                epsilon = (sample - alpha_t * x0_pred) / sigma_t
+
+            return epsilon

     def dpm_solver_first_order_update(
         self,
         model_output: torch.Tensor,
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
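The reworked `dpmsolver` branch now funnels every prediction type into an `epsilon` value before optional thresholding, instead of returning early. A standalone restatement of that conversion, assuming the usual variance-preserving parameterization where `alpha_t**2 + sigma_t**2 == 1` (the function name is mine, not the scheduler's):

```python
import torch

def to_epsilon(model_output, sample, alpha_t, sigma_t, prediction_type):
    # Mirrors the three branches above: the model may predict noise, the clean
    # sample, or v; each is converted to an equivalent noise prediction.
    if prediction_type == "epsilon":
        return model_output
    if prediction_type == "sample":
        return (sample - alpha_t * model_output) / sigma_t
    if prediction_type == "v_prediction":
        return alpha_t * model_output + sigma_t * sample
    raise ValueError(f"unknown prediction_type: {prediction_type}")
```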
@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
         elif self.config.algorithm_type == "dpmsolver":
             x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            x_t = (
+                (sigma_t / sigma_s * torch.exp(-h)) * sample
+                + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+            )
         return x_t

     def singlestep_dpm_solver_second_order_update(
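The new `sde-dpmsolver++` branch is the stochastic counterpart of the deterministic first-order update: the data prediction is blended with fresh Gaussian noise whose coefficient `sigma_t * sqrt(1 - exp(-2h))` keeps the marginal variance consistent. A runnable restatement with made-up scalars (none of these values come from the diff):

```python
import torch

h = torch.tensor(0.5)                      # log-SNR gap lambda_t - lambda_s (made up)
sigma_s, sigma_t, alpha_t = 0.8, 0.6, 0.8  # assumed schedule values
sample = torch.randn(4)
model_output = torch.randn(4)              # x0 prediction, post convert_model_output
noise = torch.randn(4)

x_t = (
    (sigma_t / sigma_s * torch.exp(-h)) * sample
    + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
)
```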
@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         model_output_list: List[torch.Tensor],
         *args,
         sample: torch.Tensor = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
                 - (sigma_t * (torch.exp(h) - 1.0)) * D0
                 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
             )
+        elif self.config.algorithm_type == "sde-dpmsolver++":
+            assert noise is not None
+            if self.config.solver_type == "midpoint":
+                x_t = (
+                    (sigma_t / sigma_s1 * torch.exp(-h)) * sample
+                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+                    + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+                )
+            elif self.config.solver_type == "heun":
+                x_t = (
+                    (sigma_t / sigma_s1 * torch.exp(-h)) * sample
+                    + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+                    + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+                    + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+                )
         return x_t

     def singlestep_dpm_solver_third_order_update(
@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         *args,
         sample: torch.Tensor = None,
         order: int = None,
+        noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             )

         if order == 1:
-            return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
+            return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
         elif order == 2:
-            return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
+            return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
         elif order == 3:
             return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
         else:
@@ -894,6 +929,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.Tensor,
         timestep: int,
         sample: torch.Tensor,
+        generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             self.model_outputs[i] = self.model_outputs[i + 1]
         self.model_outputs[-1] = model_output

+        if self.config.algorithm_type == "sde-dpmsolver++":
+            noise = randn_tensor(
+                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+            )
+        else:
+            noise = None
+
         order = self.order_list[self.step_index]

         # For img2img denoising might start with order>1 which is not possible
@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         if order == 1:
             self.sample = sample

-        prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
+        prev_sample = self.singlestep_dpm_solver_update(
+            self.model_outputs, sample=self.sample, order=order, noise=noise
+        )

         # upon completion increase step index by one
         self._step_index += 1

         if not return_dict:
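Taken together, the `step()` changes mean the stochastic variant draws its per-step noise internally, seeded by the new `generator` argument. A minimal loop under that assumption — the zero `model_output` is a stand-in for a real denoiser call, and the tensor shapes are arbitrary:

```python
import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(algorithm_type="sde-dpmsolver++")
scheduler.set_timesteps(num_inference_steps=20)

generator = torch.Generator().manual_seed(0)  # makes the stochastic steps reproducible
sample = torch.randn(1, 4, 8, 8)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # placeholder for the denoiser
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
```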
@@ -121,9 +121,7 @@ class SchedulerMixin(PushToHubMixin):
         force_download (`bool`, *optional*, defaults to `False`):
             Whether or not to force the (re-)download of the model weights and configuration files, overriding the
             cached versions if they exist.
-        resume_download:
-            Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-            of Diffusers.
         proxies (`Dict[str, str]`, *optional*):
             A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -102,9 +102,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
         force_download (`bool`, *optional*, defaults to `False`):
             Whether or not to force the (re-)download of the model weights and configuration files, overriding the
             cached versions if they exist.
-        resume_download:
-            Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-            of Diffusers.
         proxies (`Dict[str, str]`, *optional*):
             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -199,7 +199,6 @@ def get_cached_module_file(
     module_file: str,
     cache_dir: Optional[Union[str, os.PathLike]] = None,
     force_download: bool = False,
-    resume_download: Optional[bool] = None,
     proxies: Optional[Dict[str, str]] = None,
     token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
@@ -226,9 +225,7 @@ def get_cached_module_file(
             cache should not be used.
         force_download (`bool`, *optional*, defaults to `False`):
             Whether or not to force to (re-)download the configuration files and override the cached versions if they
             exist.
-        resume_download:
-            Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-            of Diffusers.
         proxies (`Dict[str, str]`, *optional*):
             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -309,7 +306,6 @@ def get_cached_module_file(
             cache_dir=cache_dir,
             force_download=force_download,
             proxies=proxies,
-            resume_download=resume_download,
             local_files_only=local_files_only,
             token=token,
         )
@@ -366,7 +362,6 @@ def get_cached_module_file(
             f"{module_needed}.py",
             cache_dir=cache_dir,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             token=token,
             revision=revision,
@@ -382,7 +377,6 @@ def get_class_from_dynamic_module(
     class_name: Optional[str] = None,
     cache_dir: Optional[Union[str, os.PathLike]] = None,
     force_download: bool = False,
-    resume_download: Optional[bool] = None,
     proxies: Optional[Dict[str, str]] = None,
     token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
@@ -419,9 +413,6 @@ def get_class_from_dynamic_module(
         force_download (`bool`, *optional*, defaults to `False`):
             Whether or not to force to (re-)download the configuration files and override the cached versions if they
             exist.
-        resume_download:
-            Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 of
-            Diffusers.
         proxies (`Dict[str, str]`, *optional*):
             A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -458,7 +449,6 @@ def get_class_from_dynamic_module(
         module_file,
         cache_dir=cache_dir,
         force_download=force_download,
-        resume_download=resume_download,
         proxies=proxies,
         token=token,
         revision=revision,
@@ -271,7 +271,8 @@ if cache_version < 1:
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        split_index = -2 if weights_name.endswith(".index.json") else -1
+        splits = splits[:split_index] + [variant] + splits[split_index:]
         weights_name = ".".join(splits)

     return weights_name
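The new `split_index` handles sharded-index names: for ordinary weight files the variant still lands before the extension, while for `*.index.json` files it lands before the `.index.json` suffix. A standalone restatement with the two cases spelled out (the function name is mine, for illustration):

```python
from typing import Optional

def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    # Same logic as `_add_variant` above, outside the module for illustration.
    if variant is not None:
        splits = weights_name.split(".")
        split_index = -2 if weights_name.endswith(".index.json") else -1
        splits = splits[:split_index] + [variant] + splits[split_index:]
        weights_name = ".".join(splits)
    return weights_name

assert add_variant("diffusion_pytorch_model.safetensors", "fp16") == "diffusion_pytorch_model.fp16.safetensors"
assert (
    add_variant("diffusion_pytorch_model.safetensors.index.json", "fp16")
    == "diffusion_pytorch_model.safetensors.fp16.index.json"
)
```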
@@ -286,7 +287,6 @@ def _get_model_file(
     cache_dir: Optional[str] = None,
     force_download: bool = False,
     proxies: Optional[Dict] = None,
-    resume_download: Optional[bool] = None,
     local_files_only: bool = False,
     token: Optional[str] = None,
     user_agent: Optional[Union[Dict, str]] = None,
@@ -324,7 +324,6 @@ def _get_model_file(
                 cache_dir=cache_dir,
                 force_download=force_download,
                 proxies=proxies,
-                resume_download=resume_download,
                 local_files_only=local_files_only,
                 token=token,
                 user_agent=user_agent,
@@ -349,7 +348,6 @@ def _get_model_file(
                 cache_dir=cache_dir,
                 force_download=force_download,
                 proxies=proxies,
-                resume_download=resume_download,
                 local_files_only=local_files_only,
                 token=token,
                 user_agent=user_agent,
@@ -417,7 +415,6 @@ def _get_checkpoint_shard_files(
     index_filename,
     cache_dir=None,
     proxies=None,
-    resume_download=False,
     local_files_only=False,
     token=None,
     user_agent=None,
@@ -475,7 +472,6 @@ def _get_checkpoint_shard_files(
         cached_folder = snapshot_download(
             pretrained_model_name_or_path,
             cache_dir=cache_dir,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
@@ -40,6 +40,7 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.training_utils import EMAModel
 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -915,6 +916,43 @@ class ModelTesterMixin:

         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

+    @require_torch_gpu
+    def test_sharded_checkpoints_with_variant(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        variant = "fp16"
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
+            # testing if loading works with the variant when the checkpoint is sharded should be
+            # enough.
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
+            index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
     @require_torch_gpu
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
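What the new test exercises, as user-facing API: save a sharded checkpoint under a variant tag and reload it by that tag. A sketch in which the model class is real but the repo id, local path, and shard size are illustrative:

```python
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained("some-org/some-model", subfolder="unet")  # illustrative source
model.save_pretrained("./sharded-unet", max_shard_size="500MB", variant="fp16")
# Produces fp16-tagged shards plus `diffusion_pytorch_model.safetensors.fp16.index.json`.
reloaded = UNet2DConditionModel.from_pretrained("./sharded-unet", variant="fp16")
```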
@@ -194,16 +194,20 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(prediction_type=prediction_type)

     def test_solver_order_and_type(self):
-        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
+        for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
             for solver_type in ["midpoint", "heun"]:
                 for order in [1, 2, 3]:
                     for prediction_type in ["epsilon", "sample"]:
-                        self.check_over_configs(
-                            solver_order=order,
-                            solver_type=solver_type,
-                            prediction_type=prediction_type,
-                            algorithm_type=algorithm_type,
-                        )
+                        if algorithm_type == "sde-dpmsolver++":
+                            if order == 3:
+                                continue
+                        else:
+                            self.check_over_configs(
+                                solver_order=order,
+                                solver_type=solver_type,
+                                prediction_type=prediction_type,
+                                algorithm_type=algorithm_type,
+                            )
                         sample = self.full_loop(
                             solver_order=order,
                             solver_type=solver_type,
@@ -76,6 +76,7 @@ def fetch_pipeline_modules_to_test():
     test_modules = []
     for pipeline_name in pipeline_objects:
         module = getattr(diffusers, pipeline_name)

        test_module = module.__module__.split(".")[-2].strip()
        test_modules.append(test_module)