mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-10 22:44:38 +08:00
Compare commits
91 Commits
sdxl-inpat
...
fix-widget
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa39fd7cb6 | ||
|
|
43672b4a22 | ||
|
|
9df3d84382 | ||
|
|
c751449011 | ||
|
|
c1e8bdf1d4 | ||
|
|
78b87dc25a | ||
|
|
0af12f1f8a | ||
|
|
6e123688dc | ||
|
|
f0a588b8e2 | ||
|
|
fa31704420 | ||
|
|
9d79991da0 | ||
|
|
7d865ac9c6 | ||
|
|
fb02316db8 | ||
|
|
98a2b3d2d8 | ||
|
|
2026ec0a02 | ||
|
|
3706aa3305 | ||
|
|
d4f10ea362 | ||
|
|
3aba99af8f | ||
|
|
6683f97959 | ||
|
|
4e7b0cb396 | ||
|
|
35b81fffae | ||
|
|
e0d8c910e9 | ||
|
|
a3d31e3a3e | ||
|
|
84c403aedb | ||
|
|
f4b0b26f7e | ||
|
|
89459a5d56 | ||
|
|
008d9818a2 | ||
|
|
2d43094ffc | ||
|
|
7c05b975b7 | ||
|
|
fe574c8b29 | ||
|
|
90b9479903 | ||
|
|
df76a39e1b | ||
|
|
3369bc810a | ||
|
|
7fe47596af | ||
|
|
59d1caa238 | ||
|
|
c022e52923 | ||
|
|
4039815276 | ||
|
|
5b186b7128 | ||
|
|
ab0459f2b7 | ||
|
|
9c7cc36011 | ||
|
|
325f6c53ed | ||
|
|
43979c2890 | ||
|
|
9ea6ac1b07 | ||
|
|
2c34c7d6dd | ||
|
|
bffadde126 | ||
|
|
35a969d297 | ||
|
|
c5ff469d0e | ||
|
|
bcecfbc873 | ||
|
|
6269045c5b | ||
|
|
6ca9c4af05 | ||
|
|
0532cece97 | ||
|
|
22b45304bf | ||
|
|
457abdf2cf | ||
|
|
ff43dba7ea | ||
|
|
5433962992 | ||
|
|
df476d9f63 | ||
|
|
3e71a20650 | ||
|
|
bf40d7d82a | ||
|
|
32ff4773d4 | ||
|
|
288ceebea5 | ||
|
|
9221da4063 | ||
|
|
57fde871e1 | ||
|
|
68e962395c | ||
|
|
781775ea56 | ||
|
|
fa3c86beaf | ||
|
|
7d0a47f387 | ||
|
|
67b3d3267e | ||
|
|
4e77056885 | ||
|
|
a0c54828a1 | ||
|
|
8d891e6e1b | ||
|
|
cce1fe2d41 | ||
|
|
d816bcb5e8 | ||
|
|
6976cab7ca | ||
|
|
fcbed3fa79 | ||
|
|
b98b314b7a | ||
|
|
74558ff65b | ||
|
|
49644babd3 | ||
|
|
56b3b21693 | ||
|
|
9cef07da5a | ||
|
|
2d94c7838e | ||
|
|
a81334e3f0 | ||
|
|
d704a730cd | ||
|
|
49db233b35 | ||
|
|
93ea26f272 | ||
|
|
f5dfe2a8b0 | ||
|
|
4836cfad98 | ||
|
|
1ccbfbb663 | ||
|
|
29dfe22a8e | ||
|
|
56806cdbfd | ||
|
|
8ccc76ab37 | ||
|
|
c46711e895 |
1
.github/workflows/push_tests_fast.yml
vendored
1
.github/workflows/push_tests_fast.yml
vendored
@@ -98,6 +98,7 @@ jobs:
|
||||
- name: Run example PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_examples' }}
|
||||
run: |
|
||||
python -m pip install peft
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
examples
|
||||
|
||||
@@ -162,6 +162,25 @@ class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
def benchmark(self, args):
|
||||
flush()
|
||||
|
||||
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
|
||||
|
||||
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
|
||||
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
|
||||
benchmark_info = BenchmarkInfo(time=time, memory=memory)
|
||||
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
flush()
|
||||
csv_dict = generate_csv_dict(
|
||||
pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
|
||||
)
|
||||
filepath = self.get_result_filepath(args)
|
||||
write_to_csv(filepath, csv_dict)
|
||||
print(f"Logs written to: {filepath}")
|
||||
flush()
|
||||
|
||||
|
||||
class ImageToImageBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = AutoPipelineForImage2Image
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
title: Train a diffusion model
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: Inference with PEFT
|
||||
- local: tutorials/fast_diffusion
|
||||
title: Accelerate inference of text-to-image diffusion models
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- sections:
|
||||
@@ -198,6 +200,8 @@
|
||||
title: Outputs
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- local: api/loaders/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: api/loaders/lora
|
||||
title: LoRA
|
||||
- local: api/loaders/single_file
|
||||
@@ -242,14 +246,12 @@
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/alt_diffusion
|
||||
title: AltDiffusion
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
- local: api/pipelines/animatediff
|
||||
title: AnimateDiff
|
||||
- local: api/pipelines/attend_and_excite
|
||||
title: Attend-and-Excite
|
||||
- local: api/pipelines/audio_diffusion
|
||||
title: Audio Diffusion
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
@@ -264,12 +266,6 @@
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/cycle_diffusion
|
||||
title: Cycle Diffusion
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
@@ -300,26 +296,14 @@
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: Paint by Example
|
||||
- local: api/pipelines/paradigms
|
||||
title: Parallel Sampling of Diffusion Models
|
||||
- local: api/pipelines/pix2pix_zero
|
||||
title: Pix2Pix Zero
|
||||
- local: api/pipelines/pixart
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pndm
|
||||
title: PNDM
|
||||
- local: api/pipelines/repaint
|
||||
title: RePaint
|
||||
- local: api/pipelines/score_sde_ve
|
||||
title: Score SDE VE
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/spectrogram_diffusion
|
||||
title: Spectrogram Diffusion
|
||||
- sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
@@ -354,26 +338,16 @@
|
||||
title: Stable Diffusion
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
- local: api/pipelines/stochastic_karras_ve
|
||||
title: Stochastic Karras VE
|
||||
- local: api/pipelines/model_editing
|
||||
title: Text-to-image model editing
|
||||
- local: api/pipelines/text_to_video
|
||||
title: Text-to-video
|
||||
- local: api/pipelines/text_to_video_zero
|
||||
title: Text2Video-Zero
|
||||
- local: api/pipelines/unclip
|
||||
title: unCLIP
|
||||
- local: api/pipelines/latent_diffusion_uncond
|
||||
title: Unconditional Latent Diffusion
|
||||
- local: api/pipelines/unidiffuser
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/versatile_diffusion
|
||||
title: Versatile Diffusion
|
||||
- local: api/pipelines/vq_diffusion
|
||||
title: VQ Diffusion
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Pipelines
|
||||
|
||||
25
docs/source/en/api/loaders/ip_adapter.md
Normal file
25
docs/source/en/api/loaders/ip_adapter.md
Normal file
@@ -0,0 +1,25 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# IP-Adapter
|
||||
|
||||
[IP-Adapter](https://hf.co/papers/2308.06721) is a lightweight adapter that enables prompting a diffusion model with an image. This method decouples the cross-attention layers of the image and text features. The image features are generated from an image encoder. Files generated from IP-Adapter are only ~100MBs.
|
||||
|
||||
<Tip>
|
||||
|
||||
Learn how to load an IP-Adapter checkpoint and image in the [IP-Adapter](../../using-diffusers/loading_adapters#ip-adapter) loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
@@ -49,12 +49,12 @@ make_image_grid([original_image, mask_image, image], rows=1, cols=3)
|
||||
|
||||
## AsymmetricAutoencoderKL
|
||||
|
||||
[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL
|
||||
[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.vae.DecoderOutput
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
|
||||
@@ -54,4 +54,4 @@ image
|
||||
|
||||
## AutoencoderTinyOutput
|
||||
|
||||
[[autodoc]] models.autoencoder_tiny.AutoencoderTinyOutput
|
||||
[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput
|
||||
|
||||
@@ -36,11 +36,11 @@ model = AutoencoderKL.from_single_file(url)
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.vae.DecoderOutput
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
|
||||
## FlaxAutoencoderKL
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AltDiffusion
|
||||
|
||||
AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://huggingface.co/papers/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*In this work, we present a conceptually simple and effective method to train a strong bilingual/multilingual multimodal representation model. Starting from the pre-trained multimodal representation model CLIP released by OpenAI, we altered its text encoder with a pre-trained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k-CN, COCO-CN and XTD. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding. Our models and code are available at [this https URL](https://github.com/FlagAI-Open/FlagAI).*
|
||||
|
||||
## Tips
|
||||
|
||||
`AltDiffusion` is conceptually the same as [Stable Diffusion](./stable_diffusion/overview).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AltDiffusionPipeline
|
||||
|
||||
[[autodoc]] AltDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AltDiffusionImg2ImgPipeline
|
||||
|
||||
[[autodoc]] AltDiffusionImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AltDiffusionPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput
|
||||
- all
|
||||
- __call__
|
||||
42
docs/source/en/api/pipelines/amused.md
Normal file
42
docs/source/en/api/pipelines/amused.md
Normal file
@@ -0,0 +1,42 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# aMUSEd
|
||||
|
||||
Amused is a lightweight text to image model based off of the [muse](https://arxiv.org/pdf/2301.00704.pdf) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
|
||||
|
||||
Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes.
|
||||
|
||||
| Model | Params |
|
||||
|-------|--------|
|
||||
| [amused-256](https://huggingface.co/amused/amused-256) | 603M |
|
||||
| [amused-512](https://huggingface.co/amused/amused-512) | 608M |
|
||||
|
||||
## AmusedPipeline
|
||||
|
||||
[[autodoc]] AmusedPipeline
|
||||
- __call__
|
||||
- all
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
[[autodoc]] AmusedImg2ImgPipeline
|
||||
- __call__
|
||||
- all
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
[[autodoc]] AmusedInpaintPipeline
|
||||
- __call__
|
||||
- all
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
@@ -38,16 +38,21 @@ The following example demonstrates how to use a *MotionAdapter* checkpoint with
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
# Load the motion adapter
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
# load SD 1.5 based finetuned model
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
clip_sample=False,
|
||||
timestep_spacing="linspace",
|
||||
beta_schedule="linear",
|
||||
steps_offset=1,
|
||||
)
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
@@ -70,6 +75,7 @@ output = pipe(
|
||||
)
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
|
||||
```
|
||||
|
||||
Here are some sample outputs:
|
||||
@@ -88,7 +94,7 @@ Here are some sample outputs:
|
||||
|
||||
<Tip>
|
||||
|
||||
AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples.
|
||||
AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the AnimateDiff checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -98,18 +104,25 @@ Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-mo
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
# Load the motion adapter
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
# load SD 1.5 based finetuned model
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
|
||||
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
|
||||
pipe.load_lora_weights(
|
||||
"guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out"
|
||||
)
|
||||
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
clip_sample=False,
|
||||
beta_schedule="linear",
|
||||
timestep_spacing="linspace",
|
||||
steps_offset=1,
|
||||
)
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
@@ -132,6 +145,7 @@ output = pipe(
|
||||
)
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
|
||||
```
|
||||
|
||||
<table>
|
||||
@@ -160,21 +174,30 @@ Then you can use the following code to combine Motion LoRAs.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
|
||||
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
# Load the motion adapter
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
# load SD 1.5 based finetuned model
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
|
||||
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
|
||||
|
||||
pipe.load_lora_weights("diffusers/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
|
||||
pipe.load_lora_weights("diffusers/animatediff-motion-lora-pan-left", adapter_name="pan-left")
|
||||
pipe.load_lora_weights(
|
||||
"diffusers/animatediff-motion-lora-zoom-out", adapter_name="zoom-out",
|
||||
)
|
||||
pipe.load_lora_weights(
|
||||
"diffusers/animatediff-motion-lora-pan-left", adapter_name="pan-left",
|
||||
)
|
||||
pipe.set_adapters(["zoom-out", "pan-left"], adapter_weights=[1.0, 1.0])
|
||||
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
clip_sample=False,
|
||||
timestep_spacing="linspace",
|
||||
beta_schedule="linear",
|
||||
steps_offset=1,
|
||||
)
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
@@ -197,6 +220,7 @@ output = pipe(
|
||||
)
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
|
||||
```
|
||||
|
||||
<table>
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Audio Diffusion
|
||||
|
||||
[Audio Diffusion](https://github.com/teticio/audio-diffusion) is by Robert Dargavel Smith, and it leverages the recent advances in image generation from diffusion models by converting audio samples to and from Mel spectrogram images.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AudioDiffusionPipeline
|
||||
[[autodoc]] AudioDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AudioPipelineOutput
|
||||
[[autodoc]] pipelines.AudioPipelineOutput
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
|
||||
## Mel
|
||||
[[autodoc]] Mel
|
||||
@@ -1,33 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Cycle Diffusion
|
||||
|
||||
Cycle Diffusion is a text guided image-to-image generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://huggingface.co/papers/2210.05559) by Chen Henry Wu, Fernando De la Torre.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs. The code is publicly available at [this https URL](https://github.com/ChenWu98/cycle-diffusion).*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## CycleDiffusionPipeline
|
||||
[[autodoc]] CycleDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPiplineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -1,35 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Unconditional Latent Diffusion
|
||||
|
||||
Unconditional Latent Diffusion was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.*
|
||||
|
||||
The original codebase can be found at [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## LDMPipeline
|
||||
[[autodoc]] LDMPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
@@ -1,35 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Text-to-image model editing
|
||||
|
||||
[Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://huggingface.co/papers/2303.08084) is by Hadas Orgad, Bahjat Kawar, and Yonatan Belinkov. This pipeline enables editing diffusion model weights, such that its assumptions of a given concept are changed. The resulting change is expected to take effect in all prompt generations related to the edited concept.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Text-to-image diffusion models often make implicit assumptions about the world when generating images. While some assumptions are useful (e.g., the sky is blue), they can also be outdated, incorrect, or reflective of social biases present in the training data. Thus, there is a need to control these assumptions without requiring explicit user input or costly re-training. In this work, we aim to edit a given implicit assumption in a pre-trained diffusion model. Our Text-to-Image Model Editing method, TIME for short, receives a pair of inputs: a "source" under-specified prompt for which the model makes an implicit assumption (e.g., "a pack of roses"), and a "destination" prompt that describes the same setting, but with a specified desired attribute (e.g., "a pack of blue roses"). TIME then updates the model's cross-attention layers, as these layers assign visual meaning to textual tokens. We edit the projection matrices in these layers such that the source prompt is projected close to the destination prompt. Our method is highly efficient, as it modifies a mere 2.2% of the model's parameters in under one second. To evaluate model editing approaches, we introduce TIMED (TIME Dataset), containing 147 source and destination prompt pairs from various domains. Our experiments (using Stable Diffusion) show that TIME is successful in model editing, generalizes well for related prompts unseen during editing, and imposes minimal effect on unrelated generations.*
|
||||
|
||||
You can find additional information about model editing on the [project page](https://time-diffusion.github.io/), [original codebase](https://github.com/bahjat-kawar/time-diffusion), and try it out in a [demo](https://huggingface.co/spaces/bahjat-kawar/time-diffusion).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionModelEditingPipeline
|
||||
[[autodoc]] StableDiffusionModelEditingPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -1,51 +0,0 @@
|
||||
<!--Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Parallel Sampling of Diffusion Models
|
||||
|
||||
[Parallel Sampling of Diffusion Models](https://huggingface.co/papers/2305.16317) is by Andy Shih, Suneel Belkhale, Stefano Ermon, Dorsa Sadigh, Nima Anari.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Diffusion models are powerful generative models but suffer from slow sampling, often taking 1000 sequential denoising steps for one sample. As a result, considerable efforts have been directed toward reducing the number of denoising steps, but these methods hurt sample quality. Instead of reducing the number of denoising steps (trading quality for speed), in this paper we explore an orthogonal approach: can we run the denoising steps in parallel (trading compute for speed)? In spite of the sequential nature of the denoising steps, we show that surprisingly it is possible to parallelize sampling via Picard iterations, by guessing the solution of future denoising steps and iteratively refining until convergence. With this insight, we present ParaDiGMS, a novel method to accelerate the sampling of pretrained diffusion models by denoising multiple steps in parallel. ParaDiGMS is the first diffusion sampling method that enables trading compute for speed and is even compatible with existing fast sampling techniques such as DDIM and DPMSolver. Using ParaDiGMS, we improve sampling speed by 2-4x across a range of robotics and image generation models, giving state-of-the-art sampling speeds of 0.2s on 100-step DiffusionPolicy and 14.6s on 1000-step StableDiffusion-v2 with no measurable degradation of task reward, FID score, or CLIP score.*
|
||||
|
||||
The original codebase can be found at [AndyShih12/paradigms](https://github.com/AndyShih12/paradigms), and the pipeline was contributed by [AndyShih12](https://github.com/AndyShih12). ❤️
|
||||
|
||||
## Tips
|
||||
|
||||
This pipeline improves sampling speed by running denoising steps in parallel, at the cost of increased total FLOPs.
|
||||
Therefore, it is better to call this pipeline when running on multiple GPUs. Otherwise, without enough GPU bandwidth
|
||||
sampling may be even slower than sequential sampling.
|
||||
|
||||
The two parameters to play with are `parallel` (batch size) and `tolerance`.
|
||||
- If it fits in memory, for a 1000-step DDPM you can aim for a batch size of around 100 (for example, 8 GPUs and `batch_per_device=12` to get `parallel=96`). A higher batch size may not fit in memory, and lower batch size gives less parallelism.
|
||||
- For tolerance, using a higher tolerance may get better speedups but can risk sample quality degradation. If there is quality degradation with the default tolerance, then use a lower tolerance like `0.001`.
|
||||
|
||||
For a 1000-step DDPM on 8 A100 GPUs, you can expect around a 3x speedup from [`StableDiffusionParadigmsPipeline`] compared to the [`StableDiffusionPipeline`]
|
||||
by setting `parallel=80` and `tolerance=0.1`.
|
||||
|
||||
🤗 Diffusers offers [distributed inference support](../../training/distributed_inference) for generating multiple prompts
|
||||
in parallel on multiple GPUs. But [`StableDiffusionParadigmsPipeline`] is designed for speeding up sampling of a single prompt by using multiple GPUs.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionParadigmsPipeline
|
||||
[[autodoc]] StableDiffusionParadigmsPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -1,289 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Pix2Pix Zero
|
||||
|
||||
[Zero-shot Image-to-Image Translation](https://huggingface.co/papers/2302.03027) is by Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, and Jun-Yan Zhu.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Large-scale text-to-image generative models have shown their remarkable ability to synthesize diverse and high-quality images. However, it is still challenging to directly apply these models for editing real images for two reasons. First, it is hard for users to come up with a perfect text prompt that accurately describes every visual detail in the input image. Second, while existing models can introduce desirable changes in certain regions, they often dramatically alter the input content and introduce unexpected changes in unwanted regions. In this work, we propose pix2pix-zero, an image-to-image translation method that can preserve the content of the original image without manual prompting. We first automatically discover editing directions that reflect desired edits in the text embedding space. To preserve the general content structure after editing, we further propose cross-attention guidance, which aims to retain the cross-attention maps of the input image throughout the diffusion process. In addition, our method does not need additional training for these edits and can directly use the existing pre-trained text-to-image diffusion model. We conduct extensive experiments and show that our method outperforms existing and concurrent works for both real and synthetic image editing.*
|
||||
|
||||
You can find additional information about Pix2Pix Zero on the [project page](https://pix2pixzero.github.io/), [original codebase](https://github.com/pix2pixzero/pix2pix-zero), and try it out in a [demo](https://huggingface.co/spaces/pix2pix-zero-library/pix2pix-zero-demo).
|
||||
|
||||
## Tips
|
||||
|
||||
* The pipeline can be conditioned on real input images. Check out the code examples below to know more.
|
||||
* The pipeline exposes two arguments namely `source_embeds` and `target_embeds`
|
||||
that let you control the direction of the semantic edits in the final image to be generated. Let's say,
|
||||
you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect
|
||||
this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to
|
||||
`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details.
|
||||
* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking
|
||||
the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gogh".
|
||||
* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to:
|
||||
* Swap the `source_embeds` and `target_embeds`.
|
||||
* Change the input prompt to include "dog".
|
||||
* To learn more about how the source and target embeddings are generated, refer to the [original paper](https://arxiv.org/abs/2302.03027). Below, we also provide some directions on how to generate the embeddings.
|
||||
* Note that the quality of the outputs generated with this pipeline is dependent on how good the `source_embeds` and `target_embeds` are. Please, refer to [this discussion](#generating-source-and-target-embeddings) for some suggestions on the topic.
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [StableDiffusionPix2PixZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py) | *Text-Based Image Editing* | [🤗 Space](https://huggingface.co/spaces/pix2pix-zero-library/pix2pix-zero-demo) |
|
||||
|
||||
<!-- TODO: add Colab -->
|
||||
|
||||
## Usage example
|
||||
|
||||
### Based on an image generated with the input prompt
|
||||
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
|
||||
|
||||
|
||||
def download(embedding_url, local_filepath):
|
||||
r = requests.get(embedding_url)
|
||||
with open(local_filepath, "wb") as f:
|
||||
f.write(r.content)
|
||||
|
||||
|
||||
model_ckpt = "CompVis/stable-diffusion-v1-4"
|
||||
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline.to("cuda")
|
||||
|
||||
prompt = "a high resolution painting of a cat in the style of van gogh"
|
||||
src_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt"
|
||||
target_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt"
|
||||
|
||||
for url in [src_embs_url, target_embs_url]:
|
||||
download(url, url.split("/")[-1])
|
||||
|
||||
src_embeds = torch.load(src_embs_url.split("/")[-1])
|
||||
target_embeds = torch.load(target_embs_url.split("/")[-1])
|
||||
|
||||
image = pipeline(
|
||||
prompt,
|
||||
source_embeds=src_embeds,
|
||||
target_embeds=target_embeds,
|
||||
num_inference_steps=50,
|
||||
cross_attention_guidance_amount=0.15,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### Based on an input image
|
||||
|
||||
When the pipeline is conditioned on an input image, we first obtain an inverted
|
||||
noise from it using a `DDIMInverseScheduler` with the help of a generated caption. Then the inverted noise is used to start the generation process.
|
||||
|
||||
First, let's load our pipeline:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
|
||||
|
||||
captioner_id = "Salesforce/blip-image-captioning-base"
|
||||
processor = BlipProcessor.from_pretrained(captioner_id)
|
||||
model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
|
||||
sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
|
||||
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
sd_model_ckpt,
|
||||
caption_generator=model,
|
||||
caption_processor=processor,
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None,
|
||||
)
|
||||
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
Then, we load an input image for conditioning and obtain a suitable caption for it:
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
|
||||
raw_image = load_image(url).resize((512, 512))
|
||||
caption = pipeline.generate_caption(raw_image)
|
||||
caption
|
||||
```
|
||||
|
||||
Then we employ the generated caption and the input image to get the inverted noise:
|
||||
|
||||
```py
|
||||
generator = torch.manual_seed(0)
|
||||
inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents
|
||||
```
|
||||
|
||||
Now, generate the image with edit directions:
|
||||
|
||||
```py
|
||||
# See the "Generating source and target embeddings" section below to
|
||||
# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below.
|
||||
source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
|
||||
target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
|
||||
|
||||
source_embeds = pipeline.get_embeds(source_prompts, batch_size=2)
|
||||
target_embeds = pipeline.get_embeds(target_prompts, batch_size=2)
|
||||
|
||||
|
||||
image = pipeline(
|
||||
caption,
|
||||
source_embeds=source_embeds,
|
||||
target_embeds=target_embeds,
|
||||
num_inference_steps=50,
|
||||
cross_attention_guidance_amount=0.15,
|
||||
generator=generator,
|
||||
latents=inv_latents,
|
||||
negative_prompt=caption,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
## Generating source and target embeddings
|
||||
|
||||
The authors originally used the [GPT-3 API](https://openai.com/api/) to generate the source and target captions for discovering
|
||||
edit directions. However, we can also leverage open source and public models for the same purpose.
|
||||
Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model
|
||||
for generating captions and [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for
|
||||
computing embeddings on the generated captions.
|
||||
|
||||
**1. Load the generation model**:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
**2. Construct a starting prompt**:
|
||||
|
||||
```py
|
||||
source_concept = "cat"
|
||||
target_concept = "dog"
|
||||
|
||||
source_text = f"Provide a caption for images containing a {source_concept}. "
|
||||
"The captions should be in English and should be no longer than 150 characters."
|
||||
|
||||
target_text = f"Provide a caption for images containing a {target_concept}. "
|
||||
"The captions should be in English and should be no longer than 150 characters."
|
||||
```
|
||||
|
||||
Here, we're interested in the "cat -> dog" direction.
|
||||
|
||||
**3. Generate captions**:
|
||||
|
||||
We can use a utility like so for this purpose.
|
||||
|
||||
```py
|
||||
def generate_captions(input_prompt):
|
||||
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
|
||||
)
|
||||
return tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
```
|
||||
|
||||
And then we just call it to generate our captions:
|
||||
|
||||
```py
|
||||
source_captions = generate_captions(source_text)
|
||||
target_captions = generate_captions(target_concept)
|
||||
print(source_captions, target_captions, sep='\n')
|
||||
```
|
||||
|
||||
We encourage you to play around with the different parameters supported by the
|
||||
`generate()` method ([documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.generation_tf_utils.TFGenerationMixin.generate)) for the generation quality you are looking for.
|
||||
|
||||
**4. Load the embedding model**:
|
||||
|
||||
Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPix2PixZeroPipeline
|
||||
|
||||
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
tokenizer = pipeline.tokenizer
|
||||
text_encoder = pipeline.text_encoder
|
||||
```
|
||||
|
||||
**5. Compute embeddings**:
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
def embed_captions(sentences, tokenizer, text_encoder, device="cuda"):
|
||||
with torch.no_grad():
|
||||
embeddings = []
|
||||
for sent in sentences:
|
||||
text_inputs = tokenizer(
|
||||
sent,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
|
||||
embeddings.append(prompt_embeds)
|
||||
return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)
|
||||
|
||||
source_embeddings = embed_captions(source_captions, tokenizer, text_encoder)
|
||||
target_embeddings = embed_captions(target_captions, tokenizer, text_encoder)
|
||||
```
|
||||
|
||||
And you're done! [Here](https://colab.research.google.com/drive/1tz2C1EdfZYAPlzXXbTnf-5PRBiR8_R1F?usp=sharing) is a Colab Notebook that you can use to interact with the entire process.
|
||||
|
||||
Now, you can use these embeddings directly while calling the pipeline:
|
||||
|
||||
```py
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
image = pipeline(
|
||||
prompt,
|
||||
source_embeds=source_embeddings,
|
||||
target_embeds=target_embeddings,
|
||||
num_inference_steps=50,
|
||||
cross_attention_guidance_amount=0.15,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionPix2PixZeroPipeline
|
||||
[[autodoc]] StableDiffusionPix2PixZeroPipeline
|
||||
- __call__
|
||||
- all
|
||||
@@ -1,35 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# PNDM
|
||||
|
||||
[Pseudo Numerical Methods for Diffusion Models on Manifolds](https://huggingface.co/papers/2202.09778) (PNDM) is by Luping Liu, Yi Ren, Zhijie Lin and Zhou Zhao.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Denoising Diffusion Probabilistic Models (DDPMs) can generate high-quality samples such as image and audio samples. However, DDPMs require hundreds to thousands of iterations to produce final samples. Several prior works have successfully accelerated DDPMs through adjusting the variance schedule (e.g., Improved Denoising Diffusion Probabilistic Models) or the denoising equation (e.g., Denoising Diffusion Implicit Models (DDIMs)). However, these acceleration methods cannot maintain the quality of samples and even introduce new noise at a high speedup rate, which limit their practicability. To accelerate the inference process while keeping the sample quality, we provide a fresh perspective that DDPMs should be treated as solving differential equations on manifolds. Under such a perspective, we propose pseudo numerical methods for diffusion models (PNDMs). Specifically, we figure out how to solve differential equations on manifolds and show that DDIMs are simple cases of pseudo numerical methods. We change several classical numerical methods to corresponding pseudo numerical methods and find that the pseudo linear multi-step method is the best in most situations. According to our experiments, by directly using pre-trained models on Cifar10, CelebA and LSUN, PNDMs can generate higher quality synthetic images with only 50 steps compared with 1000-step DDIMs (20x speedup), significantly outperform DDIMs with 250 steps (by around 0.4 in FID) and have good generalization on different variance schedules.*
|
||||
|
||||
The original codebase can be found at [luping-liu/PNDM](https://github.com/luping-liu/PNDM).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## PNDMPipeline
|
||||
[[autodoc]] PNDMPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
@@ -1,37 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# RePaint
|
||||
|
||||
[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2201.09865) is by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
|
||||
RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.*
|
||||
|
||||
The original codebase can be found at [andreas128/RePaint](https://github.com/andreas128/RePaint).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## RePaintPipeline
|
||||
[[autodoc]] RePaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
@@ -1,35 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Score SDE VE
|
||||
|
||||
[Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) (Score SDE) is by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon and Ben Poole. This pipeline implements the variance expanding (VE) variant of the stochastic differential equation method.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (\aka, score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*
|
||||
|
||||
The original codebase can be found at [yang-song/score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## ScoreSdeVePipeline
|
||||
[[autodoc]] ScoreSdeVePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
@@ -1,37 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Spectrogram Diffusion
|
||||
|
||||
[Spectrogram Diffusion](https://huggingface.co/papers/2206.05408) is by Curtis Hawthorne, Ian Simon, Adam Roberts, Neil Zeghidour, Josh Gardner, Ethan Manilow, and Jesse Engel.
|
||||
|
||||
*An ideal music synthesizer should be both interactive and expressive, generating high-fidelity audio in realtime for arbitrary combinations of instruments and notes. Recent neural synthesizers have exhibited a tradeoff between domain-specific models that offer detailed control of only specific instruments, or raw waveform models that can train on any music but with minimal control and slow generation. In this work, we focus on a middle ground of neural synthesizers that can generate audio from MIDI sequences with arbitrary combinations of instruments in realtime. This enables training on a wide range of transcription datasets with a single model, which in turn offers note-level control of composition and instrumentation across a wide range of instruments. We use a simple two-stage process: MIDI to spectrograms with an encoder-decoder Transformer, then spectrograms to audio with a generative adversarial network (GAN) spectrogram inverter. We compare training the decoder as an autoregressive model and as a Denoising Diffusion Probabilistic Model (DDPM) and find that the DDPM approach is superior both qualitatively and as measured by audio reconstruction and Fréchet distance metrics. Given the interactivity and generality of this approach, we find this to be a promising first step towards interactive and expressive neural synthesis for arbitrary combinations of instruments and notes.*
|
||||
|
||||
The original codebase can be found at [magenta/music-spectrogram-diffusion](https://github.com/magenta/music-spectrogram-diffusion).
|
||||
|
||||

|
||||
|
||||
As depicted above the model takes as input a MIDI file and tokenizes it into a sequence of 5 second intervals. Each tokenized interval then together with positional encodings is passed through the Note Encoder and its representation is concatenated with the previous window's generated spectrogram representation obtained via the Context Encoder. For the initial 5 second window this is set to zero. The resulting context is then used as conditioning to sample the denoised Spectrogram from the MIDI window and we concatenate this spectrogram to the final output as well as use it for the context of the next MIDI window. The process repeats till we have gone over all the MIDI inputs. Finally a MelGAN decoder converts the potentially long spectrogram to audio which is the final result of this pipeline.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## SpectrogramDiffusionPipeline
|
||||
[[autodoc]] SpectrogramDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AudioPipelineOutput
|
||||
[[autodoc]] pipelines.AudioPipelineOutput
|
||||
@@ -31,14 +31,14 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea
|
||||
|
||||
## StableDiffusionLDM3DPipeline
|
||||
|
||||
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
|
||||
[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## LDM3DPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stochastic Karras VE
|
||||
|
||||
[Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) is by Tero Karras, Miika Aittala, Timo Aila and Samuli Laine. This pipeline implements the stochastic sampling tailored to variance expanding (VE) models.
|
||||
|
||||
The abstract from the paper:
|
||||
|
||||
*We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## KarrasVePipeline
|
||||
[[autodoc]] KarrasVePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
@@ -1,54 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Versatile Diffusion
|
||||
|
||||
Versatile Diffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://huggingface.co/papers/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Recent advances in diffusion models have set an impressive milestone in many generation tasks, and trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-task multimodal network, dubbed Versatile Diffusion (VD), that handles multiple flows of text-to-image, image-to-text, and variations in one unified model. The pipeline design of VD instantiates a unified multi-flow diffusion framework, consisting of sharable and swappable layer modules that enable the crossmodal generality beyond images and text. Through extensive experiments, we demonstrate that VD successfully achieves the following: a) VD outperforms the baseline approaches and handles all its base tasks with competitive quality; b) VD enables novel extensions such as disentanglement of style and semantics, dual- and multi-context blending, etc.; c) The success of our multi-flow multimodal framework over images and text may inspire further diffusion-based universal AI research.*
|
||||
|
||||
## Tips
|
||||
|
||||
You can load the more memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that supports all the tasks or use the individual pipelines which are more memory efficient.
|
||||
|
||||
| **Pipeline** | **Supported tasks** |
|
||||
|------------------------------------------------------|-----------------------------------|
|
||||
| [`VersatileDiffusionPipeline`] | all of the below |
|
||||
| [`VersatileDiffusionTextToImagePipeline`] | text-to-image |
|
||||
| [`VersatileDiffusionImageVariationPipeline`] | image variation |
|
||||
| [`VersatileDiffusionDualGuidedPipeline`] | image-text dual guided generation |
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## VersatileDiffusionPipeline
|
||||
[[autodoc]] VersatileDiffusionPipeline
|
||||
|
||||
## VersatileDiffusionTextToImagePipeline
|
||||
[[autodoc]] VersatileDiffusionTextToImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## VersatileDiffusionImageVariationPipeline
|
||||
[[autodoc]] VersatileDiffusionImageVariationPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## VersatileDiffusionDualGuidedPipeline
|
||||
[[autodoc]] VersatileDiffusionDualGuidedPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,35 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# VQ Diffusion
|
||||
|
||||
[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://huggingface.co/papers/2111.14822) is by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.*
|
||||
|
||||
The original codebase can be found at [microsoft/VQ-Diffusion](https://github.com/microsoft/VQ-Diffusion).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## VQDiffusionPipeline
|
||||
[[autodoc]] VQDiffusionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ImagePipelineOutput
|
||||
[[autodoc]] pipelines.ImagePipelineOutput
|
||||
@@ -179,7 +179,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--dataloader_num_workers=8 \
|
||||
--resolution=512
|
||||
--resolution=512 \
|
||||
--center_crop \
|
||||
--random_flip \
|
||||
--train_batch_size=1 \
|
||||
@@ -214,4 +214,4 @@ image = pipeline("A pokemon with blue eyes").images[0]
|
||||
Congratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:
|
||||
|
||||
- Learn how to [load different LoRA formats](../using-diffusers/loading_adapters#LoRA) trained using community trainers like Kohya and TheLastBen.
|
||||
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
|
||||
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
|
||||
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# T2I-Adapter
|
||||
|
||||
[T2I-Adapter]((https://hf.co/papers/2302.08453)) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it.
|
||||
[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it.
|
||||
|
||||
The T2I-Adapter is only available for training with the Stable Diffusion XL (SDXL) model.
|
||||
|
||||
@@ -224,4 +224,4 @@ image.save("./output.png")
|
||||
|
||||
Congratulations on training a T2I-Adapter model! 🎉 To learn more:
|
||||
|
||||
- Read the [Efficient Controllable Generation for SDXL with T2I-Adapters](https://www.cs.cmu.edu/~custom-diffusion/) blog post to learn more details about the experimental results from the T2I-Adapter team.
|
||||
- Read the [Efficient Controllable Generation for SDXL with T2I-Adapters](https://huggingface.co/blog/t2i-sdxl-adapters) blog post to learn more details about the experimental results from the T2I-Adapter team.
|
||||
|
||||
@@ -186,7 +186,7 @@ accelerate launch train_unconditional.py \
|
||||
If you're training with more than one GPU, add the `--multi_gpu` parameter to the training command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
|
||||
accelerate launch --multi_gpu train_unconditional.py \
|
||||
--dataset_name="huggan/flowers-102-categories" \
|
||||
--output_dir="ddpm-ema-flowers-64" \
|
||||
--mixed_precision="fp16" \
|
||||
|
||||
318
docs/source/en/tutorials/fast_diffusion.md
Normal file
318
docs/source/en/tutorials/fast_diffusion.md
Normal file
@@ -0,0 +1,318 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Accelerate inference of text-to-image diffusion models
|
||||
|
||||
Diffusion models are known to be slower than their counter parts, GANs, because of the iterative and sequential reverse diffusion process. Recent works try to address limitation with:
|
||||
|
||||
* progressive timestep distillation (such as [LCM LoRA](../using-diffusers/inference_with_lcm_lora.md))
|
||||
* model compression (such as [SSD-1B](https://huggingface.co/segmind/SSD-1B))
|
||||
* reusing adjacent features of the denoiser (such as [DeepCache](https://github.com/horseee/DeepCache))
|
||||
|
||||
In this tutorial, we focus on leveraging the power of PyTorch 2 to accelerate the inference latency of text-to-image diffusion pipeline, instead. We will use [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl.md) as a case study, but the techniques we will discuss should extend to other text-to-image diffusion pipelines.
|
||||
|
||||
## Setup
|
||||
|
||||
Make sure you're on the latest version of `diffusers`:
|
||||
|
||||
```bash
|
||||
pip install -U diffusers
|
||||
```
|
||||
|
||||
Then upgrade the other required libraries too:
|
||||
|
||||
```bash
|
||||
pip install -U transformers accelerate peft
|
||||
```
|
||||
|
||||
To benefit from the fastest kernels, use PyTorch nightly. You can find the installation instructions [here](https://pytorch.org/).
|
||||
|
||||
To report the numbers shown below, we used an 80GB 400W A100 with its clock rate set to the maximum.
|
||||
|
||||
_This tutorial doesn't present the benchmarking code and focuses on how to perform the optimizations, instead. For the full benchmarking code, refer to: [https://github.com/huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast)._
|
||||
|
||||
## Baseline
|
||||
|
||||
Let's start with a baseline. Disable the use of a reduced precision and [`scaled_dot_product_attention`](../optimization/torch2.0.md):
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
# Load the pipeline in full-precision and place its model components on CUDA.
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0"
|
||||
).to("cuda")
|
||||
|
||||
# Run the attention ops without efficiency.
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.vae.set_default_attn_processor()
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
This takes 7.36 seconds:
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_0.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
## Running inference in bfloat16
|
||||
|
||||
Enable the first optimization: use a reduced precision to run the inference.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# Run the attention ops without efficiency.
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.vae.set_default_attn_processor()
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds:
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_1.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
**Why bfloat16?**
|
||||
|
||||
* Using a reduced numerical precision (such as float16, bfloat16) to run inference doesn’t affect the generation quality but significantly improves latency.
|
||||
* The benefits of using the bfloat16 numerical precision as compared to float16 are hardware-dependent. Modern generations of GPUs tend to favor bfloat16.
|
||||
* Furthermore, in our experiments, we bfloat16 to be much more resilient when used with quantization in comparison to float16.
|
||||
|
||||
We have a [dedicated guide](../optimization/fp16.md) for running inference in a reduced precision.
|
||||
|
||||
## Running attention efficiently
|
||||
|
||||
Attention blocks are intensive to run. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0.md), we can run them efficiently.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
`scaled_dot_product_attention` improves the latency from 4.63 seconds to 3.31 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_2.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
## Use faster kernels with torch.compile
|
||||
|
||||
Compile the UNet and the VAE to benefit from the faster kernels. First, configure a few compiler flags:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
torch._inductor.config.conv_1x1_as_mm = True
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.epilogue_fusion = False
|
||||
torch._inductor.config.coordinate_descent_check_all_directions = True
|
||||
```
|
||||
|
||||
For the full list of compiler flags, refer to [this file](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py).
|
||||
|
||||
It is also important to change the memory layout of the UNet and the VAE to “channels_last” when compiling them. This ensures maximum speed:
|
||||
|
||||
```python
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Then, compile and perform inference:
|
||||
|
||||
```python
|
||||
# Compile the UNet and VAE.
|
||||
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
|
||||
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# First call to `pipe` will be slow, subsequent ones will be faster.
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
`torch.compile` offers different backends and modes. As we’re aiming for maximum inference speed, we opt for the inductor backend using the “max-autotune”. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. Specifying fullgraph to be True ensures that there are no graph breaks in the underlying model, ensuring the fullest potential of `torch.compile`.
|
||||
|
||||
Using SDPA attention and compiling both the UNet and VAE reduces the latency from 3.31 seconds to 2.54 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_3.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
## Combine the projection matrices of attention
|
||||
|
||||
Both the UNet and the VAE used in SDXL make use of Transformer-like blocks. A Transformer block consists of attention blocks and feed-forward blocks.
|
||||
|
||||
In an attention block, the input is projected into three sub-spaces using three different projection matrices – Q, K, and V. In the naive implementation, these projections are performed separately on the input. But we can horizontally combine the projection matrices into a single matrix and perform the projection in one shot. This increases the size of the matmuls of the input projections and improves the impact of quantization (to be discussed next).
|
||||
|
||||
Enabling this kind of computation in Diffusers just takes a single line of code:
|
||||
|
||||
```python
|
||||
pipe.fuse_qkv_projections()
|
||||
```
|
||||
|
||||
It provides a minor boost from 2.54 seconds to 2.52 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_4.png" width=500>
|
||||
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Support for `fuse_qkv_projections()` is limited and experimental. As such, it's not available for many non-SD pipelines such as [Kandinsky](../using-diffusers/kandinsky.md). You can refer to [this PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to support this kind of computation.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Dynamic quantization
|
||||
|
||||
Aapply [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to both the UNet and the VAE. This is because quantization adds additional conversion overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization). If the matmuls are too small, these techniques may degrade performance.
|
||||
|
||||
<Tip>
|
||||
|
||||
Through experimentation, we found that certain linear layers in the UNet and the VAE don’t benefit from dynamic int8 quantization. You can check out the full code for filtering those layers [here](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) (referred to as `dynamic_quant_filter_fn` below).
|
||||
|
||||
</Tip>
|
||||
|
||||
You will leverage the ultra-lightweight pure PyTorch library [torchao](https://github.com/pytorch-labs/ao) to use its user-friendly APIs for quantization.
|
||||
|
||||
First, configure all the compiler tags:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
# Notice the two new flags at the end.
|
||||
torch._inductor.config.conv_1x1_as_mm = True
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.epilogue_fusion = False
|
||||
torch._inductor.config.coordinate_descent_check_all_directions = True
|
||||
torch._inductor.config.force_fuse_int_mm_with_mul = True
|
||||
torch._inductor.config.use_mixed_mm = True
|
||||
```
|
||||
|
||||
Define the filtering functions:
|
||||
|
||||
```python
|
||||
def dynamic_quant_filter_fn(mod, *args):
|
||||
return (
|
||||
isinstance(mod, torch.nn.Linear)
|
||||
and mod.in_features > 16
|
||||
and (mod.in_features, mod.out_features)
|
||||
not in [
|
||||
(1280, 640),
|
||||
(1920, 1280),
|
||||
(1920, 640),
|
||||
(2048, 1280),
|
||||
(2048, 2560),
|
||||
(2560, 1280),
|
||||
(256, 128),
|
||||
(2816, 1280),
|
||||
(320, 640),
|
||||
(512, 1536),
|
||||
(512, 256),
|
||||
(512, 512),
|
||||
(640, 1280),
|
||||
(640, 1920),
|
||||
(640, 320),
|
||||
(640, 5120),
|
||||
(640, 640),
|
||||
(960, 320),
|
||||
(960, 640),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def conv_filter_fn(mod, *args):
|
||||
return (
|
||||
isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1) and 128 in [mod.in_channels, mod.out_channels]
|
||||
)
|
||||
```
|
||||
|
||||
Then apply all the optimizations discussed so far:
|
||||
|
||||
```python
|
||||
# SDPA + bfloat16.
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# Combine attention projection matrices.
|
||||
pipe.fuse_qkv_projections()
|
||||
|
||||
# Change the memory layout.
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Since this quantization support is limited to linear layers only, we also turn suitable pointwise convolution layers into linear layers to maximize the benefit.
|
||||
|
||||
```python
|
||||
from torchao import swap_conv2d_1x1_to_linear
|
||||
|
||||
swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
|
||||
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)
|
||||
```
|
||||
|
||||
Apply dynamic quantization:
|
||||
|
||||
```python
|
||||
from torchao import apply_dynamic_quant
|
||||
|
||||
apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
|
||||
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
|
||||
```
|
||||
|
||||
Finally, compile and perform inference:
|
||||
|
||||
```python
|
||||
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
|
||||
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt, num_inference_steps=30).images[0]
|
||||
```
|
||||
|
||||
Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 seconds.
|
||||
|
||||
<div align="center">
|
||||
|
||||
<img src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/progressive-acceleration-sdxl/SDXL%2C_Batch_Size%3A_1%2C_Steps%3A_30_5.png" width=500>
|
||||
|
||||
</div>
|
||||
@@ -183,3 +183,26 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).ima
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
```
|
||||
|
||||
You can also fuse some adapters using `adapter_names` for faster generation:
|
||||
|
||||
```py
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
|
||||
pipe.set_adapters(["pixel"], adapter_weights=[0.5, 1.0])
|
||||
# Fuses the LoRAs into the Unet
|
||||
pipe.fuse_lora(adapter_names=["pixel"])
|
||||
|
||||
prompt = "a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
|
||||
# Gets the Unet back to the original state
|
||||
pipe.unfuse_lora()
|
||||
|
||||
# Fuse all adapters
|
||||
pipe.fuse_lora(adapter_names=["pixel", "toy"])
|
||||
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
@@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to
|
||||
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Using Callbacks to interrupt the Diffusion Process
|
||||
|
||||
The following Pipelines support interrupting the diffusion process via callback
|
||||
|
||||
- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
|
||||
- [StableDiffusionImg2ImgPipeline](..api/pipelines/stable_diffusion/img2img.md)
|
||||
- [StableDiffusionInpaintPipeline](..api/pipelines/stable_diffusion/inpaint.md)
|
||||
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
|
||||
|
||||
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
|
||||
|
||||
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
|
||||
|
||||
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.enable_model_cpu_offload()
|
||||
num_inference_steps = 50
|
||||
|
||||
def interrupt_callback(pipe, i, t, callback_kwargs):
|
||||
stop_idx = 10
|
||||
if i == stop_idx:
|
||||
pipe._interrupt = True
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
pipe(
|
||||
"A photo of a cat",
|
||||
num_inference_steps=num_inference_steps,
|
||||
callback_on_step_end=interrupt_callback,
|
||||
)
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ def make_inpaint_condition(image, image_mask):
|
||||
image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
|
||||
|
||||
assert image.shape[0:1] == image_mask.shape[0:1]
|
||||
image[image_mask > 0.5] = 1.0 # set as masked pixel
|
||||
image[image_mask > 0.5] = -1.0 # set as masked pixel
|
||||
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return image
|
||||
|
||||
@@ -41,6 +41,20 @@ Now, define four different `Generator`s and assign each `Generator` a seed (`0`
|
||||
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
To create a batched seed, you should use a list comprehension that iterates over the length specified in `range()`. This creates a unique `Generator` object for each image in the batch. If you only multiply the `Generator` by the batch size, this only creates one `Generator` object that is used sequentially for each image in the batch.
|
||||
|
||||
For example, if you want to use the same seed to create 4 identical images:
|
||||
|
||||
```py
|
||||
❌ [torch.Generator().manual_seed(seed)] * 4
|
||||
|
||||
✅ [torch.Generator().manual_seed(seed) for _ in range(4)]
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
Generate the images and have a look:
|
||||
|
||||
```python
|
||||
|
||||
@@ -44,7 +44,7 @@ pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Load the conditioning image
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
|
||||
image = image.resize((1024, 576))
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
@@ -58,6 +58,11 @@ export_to_video(frames, "generated.mp4", fps=7)
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
|
||||
</video>
|
||||
|
||||
| **Source Image** | **Video** |
|
||||
|:------------:|:-----:|
|
||||
|  |  |
|
||||
|
||||
|
||||
<Tip>
|
||||
Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
|
||||
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
|
||||
@@ -120,7 +125,7 @@ pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# Load the conditioning image
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
|
||||
image = image.resize((1024, 576))
|
||||
|
||||
generator = torch.manual_seed(42)
|
||||
@@ -128,7 +133,5 @@ frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=
|
||||
export_to_video(frames, "generated.mp4", fps=7)
|
||||
```
|
||||
|
||||
<video width="1024" height="576" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4" type="video/mp4">
|
||||
</video>
|
||||

|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ def save_model_card(
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = "widget:\n" if images else ""
|
||||
img_str = "widget:\n"
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"""
|
||||
@@ -121,6 +121,10 @@ def save_model_card(
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
if not images:
|
||||
img_str += f"""
|
||||
- text: '{instance_prompt}'
|
||||
"""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
diffusers_imports_pivotal = ""
|
||||
@@ -157,8 +161,6 @@ tags:
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
widget:
|
||||
- text: '{validation_prompt if validation_prompt else instance_prompt}'
|
||||
---
|
||||
"""
|
||||
|
||||
@@ -2010,43 +2012,42 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
|
||||
326
examples/amused/README.md
Normal file
326
examples/amused/README.md
Normal file
@@ -0,0 +1,326 @@
|
||||
## Amused training
|
||||
|
||||
Amused can be finetuned on simple datasets relatively cheaply and quickly. Using 8bit optimizers, lora, and gradient accumulation, amused can be finetuned with as little as 5.5 GB. Here are a set of examples for finetuning amused on some relatively simple datasets. These training recipies are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease learning rates.
|
||||
|
||||
All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8 bit adam + lora as its about the same memory use as just using lora (bitsandbytes uses full precision optimizer states for weights below a minimum size).
|
||||
|
||||
### Finetuning the 256 checkpoint
|
||||
|
||||
These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset.
|
||||
|
||||
Example results:
|
||||
|
||||
  
|
||||
|
||||
|
||||
#### Full finetuning
|
||||
|
||||
Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps
|
||||
|
||||
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|
||||
|------------|-----------------------------|------------------|-------------|
|
||||
| 8 | 1 | 8 | 19.7 GB |
|
||||
| 4 | 2 | 8 | 18.3 GB |
|
||||
| 1 | 8 | 8 | 17.9 GB |
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--train_batch_size <batch size> \
|
||||
--gradient_accumulation_steps <gradient accumulation steps> \
|
||||
--learning_rate 1e-4 \
|
||||
--pretrained_model_name_or_path amused/amused-256 \
|
||||
--instance_data_dataset 'm1guelpf/nouns' \
|
||||
--image_key image \
|
||||
--prompt_key text \
|
||||
--resolution 256 \
|
||||
--mixed_precision fp16 \
|
||||
--lr_scheduler constant \
|
||||
--validation_prompts \
|
||||
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
|
||||
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
|
||||
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
|
||||
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
|
||||
'a pixel art character with square red glasses' \
|
||||
'a pixel art character' \
|
||||
'square red glasses on a pixel art character' \
|
||||
'square red glasses on a pixel art character with a baseball-shaped head' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 250 \
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
#### Full finetuning + 8 bit adam
|
||||
|
||||
Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8 bit adam, it will diverge eventually. If you want to train for longer, you will have to up the batch size and lower the learning rate.
|
||||
|
||||
Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps
|
||||
|
||||
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|
||||
|------------|-----------------------------|------------------|-------------|
|
||||
| 16 | 1 | 16 | 20.1 GB |
|
||||
| 8 | 2 | 16 | 15.6 GB |
|
||||
| 1 | 16 | 16 | 10.7 GB |
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--train_batch_size <batch size> \
|
||||
--gradient_accumulation_steps <gradient accumulation steps> \
|
||||
--learning_rate 2e-5 \
|
||||
--use_8bit_adam \
|
||||
--pretrained_model_name_or_path amused/amused-256 \
|
||||
--instance_data_dataset 'm1guelpf/nouns' \
|
||||
--image_key image \
|
||||
--prompt_key text \
|
||||
--resolution 256 \
|
||||
--mixed_precision fp16 \
|
||||
--lr_scheduler constant \
|
||||
--validation_prompts \
|
||||
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
|
||||
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
|
||||
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
|
||||
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
|
||||
'a pixel art character with square red glasses' \
|
||||
'a pixel art character' \
|
||||
'square red glasses on a pixel art character' \
|
||||
'square red glasses on a pixel art character with a baseball-shaped head' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 250 \
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
#### Full finetuning + lora
|
||||
|
||||
Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps
|
||||
|
||||
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|
||||
|------------|-----------------------------|------------------|-------------|
|
||||
| 16 | 1 | 16 | 14.1 GB |
|
||||
| 8 | 2 | 16 | 10.1 GB |
|
||||
| 1 | 16 | 16 | 6.5 GB |
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--train_batch_size <batch size> \
|
||||
--gradient_accumulation_steps <gradient accumulation steps> \
|
||||
--learning_rate 8e-4 \
|
||||
--use_lora \
|
||||
--pretrained_model_name_or_path amused/amused-256 \
|
||||
--instance_data_dataset 'm1guelpf/nouns' \
|
||||
--image_key image \
|
||||
--prompt_key text \
|
||||
--resolution 256 \
|
||||
--mixed_precision fp16 \
|
||||
--lr_scheduler constant \
|
||||
--validation_prompts \
|
||||
'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \
|
||||
'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \
|
||||
'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \
|
||||
'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \
|
||||
'a pixel art character with square red glasses' \
|
||||
'a pixel art character' \
|
||||
'square red glasses on a pixel art character' \
|
||||
'square red glasses on a pixel art character with a baseball-shaped head' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 250 \
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
### Finetuning the 512 checkpoint
|
||||
|
||||
These examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset.
|
||||
|
||||
Example results:
|
||||
|
||||
  
|
||||
|
||||
#### Full finetuning
|
||||
|
||||
Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps
|
||||
|
||||
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|
||||
|------------|-----------------------------|------------------|-------------|
|
||||
| 8 | 1 | 8 | 24.2 GB |
|
||||
| 4 | 2 | 8 | 19.7 GB |
|
||||
| 1 | 8 | 8 | 16.99 GB |
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--train_batch_size <batch size> \
|
||||
--gradient_accumulation_steps <gradient accumulation steps> \
|
||||
--learning_rate 8e-5 \
|
||||
--pretrained_model_name_or_path amused/amused-512 \
|
||||
--instance_data_dataset 'monadical-labs/minecraft-preview' \
|
||||
--prompt_prefix 'minecraft ' \
|
||||
--image_key image \
|
||||
--prompt_key text \
|
||||
--resolution 512 \
|
||||
--mixed_precision fp16 \
|
||||
--lr_scheduler constant \
|
||||
--validation_prompts \
|
||||
'minecraft Avatar' \
|
||||
'minecraft character' \
|
||||
'minecraft' \
|
||||
'minecraft president' \
|
||||
'minecraft pig' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 250 \
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
#### Full finetuning + 8 bit adam
|
||||
|
||||
Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps
|
||||
|
||||
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|
||||
|------------|-----------------------------|------------------|-------------|
|
||||
| 8 | 1 | 8 | 21.2 GB |
|
||||
| 4 | 2 | 8 | 13.3 GB |
|
||||
| 1 | 8 | 8 | 9.9 GB |
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--train_batch_size <batch size> \
|
||||
--gradient_accumulation_steps <gradient accumulation steps> \
|
||||
--learning_rate 5e-6 \
|
||||
--pretrained_model_name_or_path amused/amused-512 \
|
||||
--instance_data_dataset 'monadical-labs/minecraft-preview' \
|
||||
--prompt_prefix 'minecraft ' \
|
||||
--image_key image \
|
||||
--prompt_key text \
|
||||
--resolution 512 \
|
||||
--mixed_precision fp16 \
|
||||
--lr_scheduler constant \
|
||||
--validation_prompts \
|
||||
'minecraft Avatar' \
|
||||
'minecraft character' \
|
||||
'minecraft' \
|
||||
'minecraft president' \
|
||||
'minecraft pig' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 250 \
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
#### Full finetuning + lora
|
||||
|
||||
Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps
|
||||
|
||||
| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |
|
||||
|------------|-----------------------------|------------------|-------------|
|
||||
| 8 | 1 | 8 | 12.7 GB |
|
||||
| 4 | 2 | 8 | 9.0 GB |
|
||||
| 1 | 8 | 8 | 5.6 GB |
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--train_batch_size <batch size> \
|
||||
--gradient_accumulation_steps <gradient accumulation steps> \
|
||||
--learning_rate 1e-4 \
|
||||
--use_lora \
|
||||
--pretrained_model_name_or_path amused/amused-512 \
|
||||
--instance_data_dataset 'monadical-labs/minecraft-preview' \
|
||||
--prompt_prefix 'minecraft ' \
|
||||
--image_key image \
|
||||
--prompt_key text \
|
||||
--resolution 512 \
|
||||
--mixed_precision fp16 \
|
||||
--lr_scheduler constant \
|
||||
--validation_prompts \
|
||||
'minecraft Avatar' \
|
||||
'minecraft character' \
|
||||
'minecraft' \
|
||||
'minecraft president' \
|
||||
'minecraft pig' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 250 \
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
### Styledrop
|
||||
|
||||
[Styledrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage to generate human picked additional training samples. The additional training samples can be used to augment the initial images. Our examples exclude the optional additional image selection stage and instead we just finetune on a single image.
|
||||
|
||||
This is our example style image:
|
||||

|
||||
|
||||
Download it to your local directory with
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png
|
||||
```
|
||||
|
||||
#### 256
|
||||
|
||||
Example results:
|
||||
|
||||
  
|
||||
|
||||
Learning rate: 4e-4, Gives decent results in 1500-2000 steps
|
||||
|
||||
Memory used: 6.5 GB
|
||||
|
||||
```sh
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--mixed_precision fp16 \
|
||||
--report_to wandb \
|
||||
--use_lora \
|
||||
--pretrained_model_name_or_path amused/amused-256 \
|
||||
--train_batch_size 1 \
|
||||
--lr_scheduler constant \
|
||||
--learning_rate 4e-4 \
|
||||
--validation_prompts \
|
||||
'A chihuahua walking on the street in [V] style' \
|
||||
'A banana on the table in [V] style' \
|
||||
'A church on the street in [V] style' \
|
||||
'A tabby cat walking in the forest in [V] style' \
|
||||
--instance_data_image 'A mushroom in [V] style.png' \
|
||||
--max_train_steps 10000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 100 \
|
||||
--resolution 256
|
||||
```
|
||||
|
||||
#### 512
|
||||
|
||||
Example results:
|
||||
|
||||
  
|
||||
|
||||
Learning rate: 1e-3, Lora alpha 1, Gives decent results in 1500-2000 steps
|
||||
|
||||
Memory used: 5.6 GB
|
||||
|
||||
```
|
||||
accelerate launch train_amused.py \
|
||||
--output_dir <output path> \
|
||||
--mixed_precision fp16 \
|
||||
--report_to wandb \
|
||||
--use_lora \
|
||||
--pretrained_model_name_or_path amused/amused-512 \
|
||||
--train_batch_size 1 \
|
||||
--lr_scheduler constant \
|
||||
--learning_rate 1e-3 \
|
||||
--validation_prompts \
|
||||
'A chihuahua walking on the street in [V] style' \
|
||||
'A banana on the table in [V] style' \
|
||||
'A church on the street in [V] style' \
|
||||
'A tabby cat walking in the forest in [V] style' \
|
||||
--instance_data_image 'A mushroom in [V] style.png' \
|
||||
--max_train_steps 100000 \
|
||||
--checkpointing_steps 500 \
|
||||
--validation_steps 100 \
|
||||
--resolution 512 \
|
||||
--lora_alpha 1
|
||||
```
|
||||
972
examples/amused/train_amused.py
Normal file
972
examples/amused/train_amused.py
Normal file
@@ -0,0 +1,972 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import DataLoader, Dataset, default_collate
|
||||
from torchvision import transforms
|
||||
from transformers import (
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import diffusers.optimization
|
||||
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.utils import is_wandb_available
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="A Hugging Face dataset containing the training images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_image", type=str, default=None, required=False, help="A single training image"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
||||
parser.add_argument("--ema_decay", type=float, default=0.9999)
|
||||
parser.add_argument("--ema_update_after_step", type=int, default=0)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="muse_training",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
||||
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
||||
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
||||
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
||||
"instructions."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more details"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.0003,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=100,
|
||||
help=(
|
||||
"Run validation every X steps. Validation consists of running the prompt"
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
||||
" and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="wandb",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--validation_prompts", type=str, nargs="*")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--split_vae_encode", type=int, required=False, default=None)
|
||||
parser.add_argument("--min_masking_rate", type=float, default=0.0)
|
||||
parser.add_argument("--cond_dropout_prob", type=float, default=0.0)
|
||||
parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.", required=False)
|
||||
parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa")
|
||||
parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa")
|
||||
parser.add_argument("--lora_r", default=16, type=int)
|
||||
parser.add_argument("--lora_alpha", default=32, type=int)
|
||||
parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
|
||||
parser.add_argument("--text_encoder_lora_r", default=16, type=int)
|
||||
parser.add_argument("--text_encoder_lora_alpha", default=32, type=int)
|
||||
parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+")
|
||||
parser.add_argument("--train_text_encoder", action="store_true")
|
||||
parser.add_argument("--image_key", type=str, required=False)
|
||||
parser.add_argument("--prompt_key", type=str, required=False)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument("--prompt_prefix", type=str, required=False, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
num_datasources = sum(
|
||||
[x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]]
|
||||
)
|
||||
|
||||
if num_datasources != 1:
|
||||
raise ValueError(
|
||||
"provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`"
|
||||
)
|
||||
|
||||
if args.instance_data_dir is not None:
|
||||
if not os.path.exists(args.instance_data_dir):
|
||||
raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}")
|
||||
|
||||
if args.instance_data_image is not None:
|
||||
if not os.path.exists(args.instance_data_image):
|
||||
raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}")
|
||||
|
||||
if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None):
|
||||
raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class InstanceDataRootDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
instance_data_root,
|
||||
tokenizer,
|
||||
size=512,
|
||||
):
|
||||
self.size = size
|
||||
self.tokenizer = tokenizer
|
||||
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.instance_images_path)
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_path = self.instance_images_path[index % len(self.instance_images_path)]
|
||||
instance_image = Image.open(image_path)
|
||||
rv = process_image(instance_image, self.size)
|
||||
|
||||
prompt = os.path.splitext(os.path.basename(image_path))[0]
|
||||
rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
|
||||
return rv
|
||||
|
||||
|
||||
class InstanceDataImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
instance_data_image,
|
||||
train_batch_size,
|
||||
size=512,
|
||||
):
|
||||
self.value = process_image(Image.open(instance_data_image), size)
|
||||
self.train_batch_size = train_batch_size
|
||||
|
||||
def __len__(self):
|
||||
# Needed so a full batch of the data can be returned. Otherwise will return
|
||||
# batches of size 1
|
||||
return self.train_batch_size
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.value
|
||||
|
||||
|
||||
class HuggingFaceDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
hf_dataset,
|
||||
tokenizer,
|
||||
image_key,
|
||||
prompt_key,
|
||||
prompt_prefix=None,
|
||||
size=512,
|
||||
):
|
||||
self.size = size
|
||||
self.image_key = image_key
|
||||
self.prompt_key = prompt_key
|
||||
self.tokenizer = tokenizer
|
||||
self.hf_dataset = hf_dataset
|
||||
self.prompt_prefix = prompt_prefix
|
||||
|
||||
def __len__(self):
|
||||
return len(self.hf_dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.hf_dataset[index]
|
||||
|
||||
rv = process_image(item[self.image_key], self.size)
|
||||
|
||||
prompt = item[self.prompt_key]
|
||||
|
||||
if self.prompt_prefix is not None:
|
||||
prompt = self.prompt_prefix + prompt
|
||||
|
||||
rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0]
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
def process_image(image, size):
|
||||
image = exif_transpose(image)
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
orig_height = image.height
|
||||
orig_width = image.width
|
||||
|
||||
image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)
|
||||
|
||||
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
|
||||
image = transforms.functional.crop(image, c_top, c_left, size, size)
|
||||
|
||||
image = transforms.ToTensor()(image)
|
||||
|
||||
micro_conds = torch.tensor(
|
||||
[orig_width, orig_height, c_top, c_left, 6.0],
|
||||
)
|
||||
|
||||
return {"image": image, "micro_conds": micro_conds}
|
||||
|
||||
|
||||
def tokenize_prompt(tokenizer, prompt):
|
||||
return tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
|
||||
def encode_prompt(text_encoder, input_ids):
|
||||
outputs = text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
||||
encoder_hidden_states = outputs.hidden_states[-2]
|
||||
cond_embeds = outputs[0]
|
||||
return encoder_hidden_states, cond_embeds
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("amused", config=vars(copy.deepcopy(args)))
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# TODO - will have to fix loading if training text encoder
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant
|
||||
)
|
||||
vq_model = VQModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
if args.text_encoder_use_lora:
|
||||
lora_config = LoraConfig(
|
||||
r=args.text_encoder_lora_r,
|
||||
lora_alpha=args.text_encoder_lora_alpha,
|
||||
target_modules=args.text_encoder_lora_target_modules,
|
||||
)
|
||||
text_encoder.add_adapter(lora_config)
|
||||
text_encoder.train()
|
||||
text_encoder.requires_grad_(True)
|
||||
else:
|
||||
text_encoder.eval()
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
vq_model.requires_grad_(False)
|
||||
|
||||
model = UVit2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
|
||||
if args.use_lora:
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
target_modules=args.lora_target_modules,
|
||||
)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
model.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
model.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if args.use_ema:
|
||||
ema = EMAModel(
|
||||
model.parameters(),
|
||||
decay=args.ema_decay,
|
||||
update_after_step=args.ema_update_after_step,
|
||||
model_cls=UVit2DModel,
|
||||
model_config=model.config,
|
||||
)
|
||||
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
text_encoder_lora_layers_to_save = None
|
||||
|
||||
for model_ in models:
|
||||
if isinstance(model_, type(accelerator.unwrap_model(model))):
|
||||
if args.use_lora:
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model_)
|
||||
else:
|
||||
model_.save_pretrained(os.path.join(output_dir, "transformer"))
|
||||
elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):
|
||||
if args.text_encoder_use_lora:
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)
|
||||
else:
|
||||
model_.save_pretrained(os.path.join(output_dir, "text_encoder"))
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model_.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
||||
)
|
||||
|
||||
if args.use_ema:
|
||||
ema.save_pretrained(os.path.join(output_dir, "ema_model"))
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer = None
|
||||
text_encoder_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model_ = models.pop()
|
||||
|
||||
if isinstance(model_, type(accelerator.unwrap_model(model))):
|
||||
if args.use_lora:
|
||||
transformer = model_
|
||||
else:
|
||||
load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, "transformer"))
|
||||
model_.load_state_dict(load_model.state_dict())
|
||||
del load_model
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
if args.text_encoder_use_lora:
|
||||
text_encoder_ = model_
|
||||
else:
|
||||
load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder"))
|
||||
model_.load_state_dict(load_model.state_dict())
|
||||
del load_model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer is not None or text_encoder_ is not None:
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
|
||||
)
|
||||
LoraLoaderMixin.load_lora_into_transformer(
|
||||
lora_state_dict, network_alphas=network_alphas, transformer=transformer
|
||||
)
|
||||
|
||||
if args.use_ema:
|
||||
load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=UVit2DModel)
|
||||
ema.load_state_dict(load_from.state_dict())
|
||||
del load_from
|
||||
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
# no decay on bias and layernorm and embedding
|
||||
no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.adam_weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
if args.train_text_encoder:
|
||||
optimizer_grouped_parameters.append(
|
||||
{"params": text_encoder.parameters(), "weight_decay": args.adam_weight_decay}
|
||||
)
|
||||
|
||||
optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
logger.info("Creating dataloaders and lr_scheduler")
|
||||
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
if args.instance_data_dir is not None:
|
||||
dataset = InstanceDataRootDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
)
|
||||
elif args.instance_data_image is not None:
|
||||
dataset = InstanceDataImageDataset(
|
||||
instance_data_image=args.instance_data_image,
|
||||
train_batch_size=args.train_batch_size,
|
||||
size=args.resolution,
|
||||
)
|
||||
elif args.instance_data_dataset is not None:
|
||||
dataset = HuggingFaceDataset(
|
||||
hf_dataset=load_dataset(args.instance_data_dataset, split="train"),
|
||||
tokenizer=tokenizer,
|
||||
image_key=args.image_key,
|
||||
prompt_key=args.prompt_key,
|
||||
prompt_prefix=args.prompt_prefix,
|
||||
size=args.resolution,
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
collate_fn=default_collate,
|
||||
)
|
||||
train_dataloader.num_batches = len(train_dataloader)
|
||||
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
logger.info("Preparing model, optimizer and dataloaders")
|
||||
|
||||
if args.train_text_encoder:
|
||||
model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(
|
||||
model, optimizer, lr_scheduler, train_dataloader, text_encoder
|
||||
)
|
||||
else:
|
||||
model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
|
||||
model, optimizer, lr_scheduler, train_dataloader
|
||||
)
|
||||
|
||||
train_dataloader.num_batches = len(train_dataloader)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.to(device=accelerator.device, dtype=weight_dtype)
|
||||
|
||||
vq_model.to(device=accelerator.device)
|
||||
|
||||
if args.use_ema:
|
||||
ema.to(accelerator.device)
|
||||
|
||||
with nullcontext() if args.train_text_encoder else torch.no_grad():
|
||||
empty_embeds, empty_clip_embeds = encode_prompt(
|
||||
text_encoder, tokenize_prompt(tokenizer, "").to(text_encoder.device, non_blocking=True)
|
||||
)
|
||||
|
||||
# There is a single image, we can just pre-encode the single prompt
|
||||
if args.instance_data_image is not None:
|
||||
prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0]
|
||||
encoder_hidden_states, cond_embeds = encode_prompt(
|
||||
text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True)
|
||||
)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1)
|
||||
cond_embeds = cond_embeds.repeat(args.train_batch_size, 1)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
# Afterwards we recalculate our number of training epochs.
|
||||
# Note: We are not doing epoch based training here, but just using this for book keeping and being able to
|
||||
# reuse the same training loop with other datasets/loaders.
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num training steps = {args.max_train_steps}")
|
||||
logger.info(f" Instantaneous batch size per device = { args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
|
||||
resume_from_checkpoint = args.resume_from_checkpoint
|
||||
if resume_from_checkpoint:
|
||||
if resume_from_checkpoint == "latest":
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
if len(dirs) > 0:
|
||||
resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])
|
||||
else:
|
||||
resume_from_checkpoint = None
|
||||
|
||||
if resume_from_checkpoint is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}")
|
||||
|
||||
if resume_from_checkpoint is None:
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
else:
|
||||
accelerator.load_state(resume_from_checkpoint)
|
||||
global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1])
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
|
||||
# As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
|
||||
# reuse the same training loop with other datasets/loaders.
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
for batch in train_dataloader:
|
||||
with torch.no_grad():
|
||||
micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True)
|
||||
pixel_values = batch["image"].to(accelerator.device, non_blocking=True)
|
||||
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size
|
||||
num_splits = math.ceil(batch_size / split_batch_size)
|
||||
image_tokens = []
|
||||
for i in range(num_splits):
|
||||
start_idx = i * split_batch_size
|
||||
end_idx = min((i + 1) * split_batch_size, batch_size)
|
||||
bs = pixel_values.shape[0]
|
||||
image_tokens.append(
|
||||
vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape(
|
||||
bs, -1
|
||||
)
|
||||
)
|
||||
image_tokens = torch.cat(image_tokens, dim=0)
|
||||
|
||||
batch_size, seq_len = image_tokens.shape
|
||||
|
||||
timesteps = torch.rand(batch_size, device=image_tokens.device)
|
||||
mask_prob = torch.cos(timesteps * math.pi * 0.5)
|
||||
mask_prob = mask_prob.clip(args.min_masking_rate)
|
||||
|
||||
num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
|
||||
batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
|
||||
mask = batch_randperm < num_token_masked.unsqueeze(-1)
|
||||
|
||||
mask_id = accelerator.unwrap_model(model).config.vocab_size - 1
|
||||
input_ids = torch.where(mask, mask_id, image_tokens)
|
||||
labels = torch.where(mask, image_tokens, -100)
|
||||
|
||||
if args.cond_dropout_prob > 0.0:
|
||||
assert encoder_hidden_states is not None
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
mask = (
|
||||
torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)
|
||||
< args.cond_dropout_prob
|
||||
)
|
||||
|
||||
empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)
|
||||
encoder_hidden_states = torch.where(
|
||||
(encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_
|
||||
)
|
||||
|
||||
empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)
|
||||
cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)
|
||||
resolution = args.resolution // vae_scale_factor
|
||||
input_ids = input_ids.reshape(bs, resolution, resolution)
|
||||
|
||||
if "prompt_input_ids" in batch:
|
||||
with nullcontext() if args.train_text_encoder else torch.no_grad():
|
||||
encoder_hidden_states, cond_embeds = encode_prompt(
|
||||
text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True)
|
||||
)
|
||||
|
||||
# Train Step
|
||||
with accelerator.accumulate(model):
|
||||
codebook_size = accelerator.unwrap_model(model).config.codebook_size
|
||||
|
||||
logits = (
|
||||
model(
|
||||
input_ids=input_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
micro_conds=micro_conds,
|
||||
pooled_text_emb=cond_embeds,
|
||||
)
|
||||
.reshape(bs, codebook_size, -1)
|
||||
.permute(0, 2, 1)
|
||||
.reshape(-1, codebook_size)
|
||||
)
|
||||
|
||||
loss = F.cross_entropy(
|
||||
logits,
|
||||
labels.view(-1),
|
||||
ignore_index=-100,
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if args.max_grad_norm is not None and accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
if args.use_ema:
|
||||
ema.step(model.parameters())
|
||||
|
||||
if (global_step + 1) % args.logging_steps == 0:
|
||||
logs = {
|
||||
"step_loss": avg_loss.item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
"avg_masking_rate": avg_masking_rate.item(),
|
||||
}
|
||||
accelerator.log(logs, step=global_step + 1)
|
||||
|
||||
logger.info(
|
||||
f"Step: {global_step + 1} "
|
||||
f"Loss: {avg_loss.item():0.4f} "
|
||||
f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}"
|
||||
)
|
||||
|
||||
if (global_step + 1) % args.checkpointing_steps == 0:
|
||||
save_checkpoint(args, accelerator, global_step + 1)
|
||||
|
||||
if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:
|
||||
if args.use_ema:
|
||||
ema.store(model.parameters())
|
||||
ema.copy_to(model.parameters())
|
||||
|
||||
with torch.no_grad():
|
||||
logger.info("Generating images...")
|
||||
|
||||
model.eval()
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder.eval()
|
||||
|
||||
scheduler = AmusedScheduler.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="scheduler",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
|
||||
pipe = AmusedPipeline(
|
||||
transformer=accelerator.unwrap_model(model),
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vqvae=vq_model,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pil_images = pipe(prompt=args.validation_prompts).images
|
||||
wandb_images = [
|
||||
wandb.Image(image, caption=args.validation_prompts[i])
|
||||
for i, image in enumerate(pil_images)
|
||||
]
|
||||
|
||||
wandb.log({"generated_images": wandb_images}, step=global_step + 1)
|
||||
|
||||
model.train()
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
|
||||
if args.use_ema:
|
||||
ema.restore(model.parameters())
|
||||
|
||||
global_step += 1
|
||||
|
||||
# Stop training if max steps is reached
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
# End for
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Evaluate and save checkpoint at the end of training
|
||||
save_checkpoint(args, accelerator, global_step)
|
||||
|
||||
# Save the final trained checkpoint
|
||||
if accelerator.is_main_process:
|
||||
model = accelerator.unwrap_model(model)
|
||||
if args.use_ema:
|
||||
ema.copy_to(model.parameters())
|
||||
model.save_pretrained(args.output_dir)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def save_checkpoint(args, accelerator, global_step):
|
||||
output_dir = args.output_dir
|
||||
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if accelerator.is_main_process and args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
save_path = Path(output_dir) / f"checkpoint-{global_step}"
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args())
|
||||
@@ -8,6 +8,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [](https://huggingface.co/spaces/toshas/marigold) [](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
|
||||
| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
|
||||
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
|
||||
| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
@@ -41,13 +42,14 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| SDE Drag Pipeline | The pipeline supports drag editing of images using stochastic differential equations | [SDE Drag Pipeline](#sde-drag-pipeline) | - | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |
|
||||
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
|
||||
@@ -60,6 +62,53 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo
|
||||
|
||||
## Example usages
|
||||
|
||||
### Marigold Depth Estimation
|
||||
|
||||
Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers).
|
||||
|
||||

|
||||
|
||||
This depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. Below is an example code snippet, including optional arguments:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"Bingxin/Marigold",
|
||||
custom_pipeline="marigold_depth_estimation"
|
||||
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
|
||||
)
|
||||
|
||||
pipe.to("cuda")
|
||||
|
||||
img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg"
|
||||
image: Image.Image = load_image(img_path_or_url)
|
||||
|
||||
pipeline_output = pipe(
|
||||
image, # Input image.
|
||||
# denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
|
||||
# ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
|
||||
# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
|
||||
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
|
||||
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
|
||||
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral".
|
||||
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
|
||||
)
|
||||
|
||||
depth: np.ndarray = pipeline_output.depth_np # Predicted depth map
|
||||
depth_colored: Image.Image = pipeline_output.depth_colored # Colorized prediction
|
||||
|
||||
# Save as uint16 PNG
|
||||
depth_uint16 = (depth * 65535.0).astype(np.uint16)
|
||||
Image.fromarray(depth_uint16).save("./depth_map.png", mode="I;16")
|
||||
|
||||
# Save colorized depth map
|
||||
depth_colored.save("./depth_colored.png")
|
||||
```
|
||||
|
||||
### LLM-grounded Diffusion
|
||||
|
||||
LMD and LMD+ greatly improves the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. It improves spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., uses SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)
|
||||
@@ -1618,10 +1667,11 @@ This approach is using (optional) CoCa model to avoid writing image description.
|
||||
|
||||
This SDXL pipeline support unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
|
||||
|
||||
You can provide both `prompt` and `prompt_2`. if only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
@@ -1632,25 +1682,52 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
, custom_pipeline = "lpw_stable_diffusion_xl",
|
||||
)
|
||||
|
||||
prompt = "photo of a cute (white) cat running on the grass"*20
|
||||
prompt2 = "chasing (birds:1.5)"*20
|
||||
prompt = "photo of a cute (white) cat running on the grass" * 20
|
||||
prompt2 = "chasing (birds:1.5)" * 20
|
||||
prompt = f"{prompt},{prompt2}"
|
||||
neg_prompt = "blur, low quality, carton, animate"
|
||||
|
||||
pipe.to("cuda")
|
||||
images = pipe(
|
||||
prompt = prompt
|
||||
, negative_prompt = neg_prompt
|
||||
).images[0]
|
||||
|
||||
# text2img
|
||||
t2i_images = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
).images # alternatively, you can call the .text2img() function
|
||||
|
||||
# img2img
|
||||
input_image = load_image("/path/to/local/image.png") # or URL to your input image
|
||||
i2i_images = pipe.img2img(
|
||||
prompt=prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
image=input_image,
|
||||
strength=0.8, # higher strength will result in more variation compared to original image
|
||||
).images
|
||||
|
||||
# inpaint
|
||||
input_mask = load_image("/path/to/local/mask.png") # or URL to your input inpainting mask
|
||||
inpaint_images = pipe.inpaint(
|
||||
prompt="photo of a cute (black) cat running on the grass" * 20,
|
||||
negative_prompt=neg_prompt,
|
||||
image=input_image,
|
||||
mask=input_mask,
|
||||
strength=0.6, # higher strength will result in more variation compared to original image
|
||||
).images
|
||||
|
||||
pipe.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
images
|
||||
|
||||
from IPython.display import display # assuming you are using this code in a notebook
|
||||
display(t2i_images[0])
|
||||
display(i2i_images[0])
|
||||
display(inpaint_images[0])
|
||||
```
|
||||
|
||||
In the above code, the `prompt2` is appended to the `prompt`, which is more than 77 tokens. "birds" are showing up in the result.
|
||||

|
||||
|
||||
For more results, checkout [PR #6114](https://github.com/huggingface/diffusers/pull/6114).
|
||||
|
||||
## Example Images Mixing (with CoCa)
|
||||
```python
|
||||
import requests
|
||||
@@ -2986,3 +3063,42 @@ def image_grid(imgs, save_path=None):
|
||||
image_grid(images, save_path="./outputs/")
|
||||
```
|
||||

|
||||
|
||||
### SDE Drag pipeline
|
||||
|
||||
This pipeline provides drag-and-drop image editing using stochastic differential equations. It enables image editing by inputting prompt, image, mask_image, source_points, and target_points.
|
||||
|
||||

|
||||
|
||||
See [paper](https://arxiv.org/abs/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more infomation.
|
||||
|
||||
```py
|
||||
import PIL
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
# Load the pipeline
|
||||
model_path = "runwayml/stable-diffusion-v1-5"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
|
||||
pipe.to('cuda')
|
||||
|
||||
# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
|
||||
# If not training LoRA, please avoid using torch.float16
|
||||
# pipe.to(torch.float16)
|
||||
|
||||
# Provide prompt, image, mask image, and the starting and target points for drag editing.
|
||||
prompt = "prompt of the image"
|
||||
image = PIL.Image.open('/path/to/image')
|
||||
mask_image = PIL.Image.open('/path/to/mask_image')
|
||||
source_points = [[123, 456]]
|
||||
target_points = [[234, 567]]
|
||||
|
||||
# train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
|
||||
pipe.train_lora(prompt, image)
|
||||
|
||||
output = pipe(prompt, image, mask_image, source_points, target_points)
|
||||
output_image = PIL.Image.fromarray(output)
|
||||
output_image.save("./output.png")
|
||||
|
||||
```
|
||||
|
||||
@@ -11,10 +11,11 @@ import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
@@ -23,7 +24,7 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
is_accelerate_available,
|
||||
@@ -461,6 +462,65 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL.
|
||||
@@ -526,6 +586,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
||||
)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
@@ -813,6 +876,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
@@ -824,6 +888,9 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
@@ -880,23 +947,263 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
|
||||
# get the original timestep using init_timestep
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_start * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
|
||||
# mean that we cut the timesteps in the middle of the denoising step
|
||||
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
|
||||
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
mask,
|
||||
width,
|
||||
height,
|
||||
num_channels_latents,
|
||||
timestep,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
dtype,
|
||||
device,
|
||||
generator=None,
|
||||
add_noise=True,
|
||||
latents=None,
|
||||
is_strength_max=True,
|
||||
return_noise=False,
|
||||
return_image_latents=False,
|
||||
):
|
||||
batch_size *= num_images_per_prompt
|
||||
|
||||
if image is None:
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
elif mask is None:
|
||||
if not isinstance(image, (torch.Tensor, Image.Image, list)):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if image.shape[1] == 4:
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
if self.vae.config.force_upcast:
|
||||
self.vae.to(dtype)
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
if add_noise:
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
|
||||
latents = init_latents
|
||||
return latents
|
||||
|
||||
else:
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if (image is None or timestep is None) and not is_strength_max:
|
||||
raise ValueError(
|
||||
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
|
||||
"However, either the image or the noise timestep has not been provided."
|
||||
)
|
||||
|
||||
if image.shape[1] == 4:
|
||||
image_latents = image.to(device=device, dtype=dtype)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
elif return_image_latents or (latents is None and not is_strength_max):
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
if latents is None and add_noise:
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# if strength is 1. then initialise the latents to noise, else initial to image + noise
|
||||
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
|
||||
# if pure noise then scale the initial latents by the Scheduler's init sigma
|
||||
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
|
||||
elif add_noise:
|
||||
noise = latents.to(device)
|
||||
latents = noise * self.scheduler.init_noise_sigma
|
||||
else:
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = image_latents.to(device)
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if return_noise:
|
||||
outputs += (noise,)
|
||||
|
||||
if return_image_latents:
|
||||
outputs += (image_latents,)
|
||||
|
||||
return outputs
|
||||
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
dtype = image.dtype
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
if self.vae.config.force_upcast:
|
||||
self.vae.to(dtype)
|
||||
|
||||
image_latents = image_latents.to(dtype)
|
||||
image_latents = self.vae.config.scaling_factor * image_latents
|
||||
|
||||
return image_latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||
):
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
|
||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
||||
|
||||
if masked_image is not None and masked_image.shape[1] == 4:
|
||||
masked_image_latents = masked_image
|
||||
else:
|
||||
masked_image_latents = None
|
||||
|
||||
if masked_image is not None:
|
||||
if masked_image_latents is None:
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
|
||||
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
|
||||
masked_image_latents = (
|
||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
)
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
|
||||
return mask, masked_image_latents
|
||||
|
||||
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
@@ -934,15 +1241,52 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def denoising_end(self):
|
||||
return self._denoising_end
|
||||
|
||||
@property
|
||||
def denoising_start(self):
|
||||
return self._denoising_start
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str = None,
|
||||
prompt_2: Optional[str] = None,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
mask_image: Optional[PipelineImageInput] = None,
|
||||
masked_image_latents: Optional[torch.FloatTensor] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
denoising_start: Optional[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[str] = None,
|
||||
@@ -975,20 +1319,46 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
prompt_2 (`str`):
|
||||
The prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in both text-encoders
|
||||
image (`PipelineImageInput`, *optional*):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
mask_image (`PipelineImageInput`, *optional*):
|
||||
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
||||
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
||||
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
||||
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
||||
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
||||
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
||||
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
denoising_start (`float`, *optional*):
|
||||
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
|
||||
it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
|
||||
strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
|
||||
is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image
|
||||
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
|
||||
denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
|
||||
final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
|
||||
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
|
||||
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
@@ -1084,6 +1454,7 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
@@ -1093,6 +1464,12 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._denoising_start = denoising_start
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -1121,28 +1498,126 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
) = get_weighted_text_embeddings_sdxl(
|
||||
pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
|
||||
)
|
||||
dtype = prompt_embeds.dtype
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
if image is not None:
|
||||
image = image.to(device=self.device, dtype=dtype)
|
||||
|
||||
if isinstance(mask_image, Image.Image):
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
else:
|
||||
mask = mask_image
|
||||
if mask_image is not None:
|
||||
mask = mask.to(device=self.device, dtype=dtype)
|
||||
|
||||
if masked_image_latents is not None:
|
||||
masked_image = masked_image_latents
|
||||
elif image.shape[1] == 4:
|
||||
# if image is in latent space, we can't mask it
|
||||
masked_image = None
|
||||
else:
|
||||
masked_image = image * (mask < 0.5)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
def denoising_value_valid(dnv):
|
||||
return isinstance(self.denoising_end, float) and 0 < dnv < 1
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
if image is not None:
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
num_inference_steps,
|
||||
strength,
|
||||
device,
|
||||
denoising_start=self.denoising_start if denoising_value_valid else None,
|
||||
)
|
||||
|
||||
# check that number of inference steps is not < 1 - as this doesn't make sense
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
is_strength_max = strength == 1.0
|
||||
add_noise = True if self.denoising_start is None else False
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
num_channels_unet = self.unet.config.in_channels
|
||||
return_image_latents = num_channels_unet == 4
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
image=image,
|
||||
mask=mask,
|
||||
width=width,
|
||||
height=height,
|
||||
num_channels_latents=num_channels_unet,
|
||||
timestep=latent_timestep,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
add_noise=add_noise,
|
||||
latents=latents,
|
||||
is_strength_max=is_strength_max,
|
||||
return_noise=True,
|
||||
return_image_latents=return_image_latents,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
if return_image_latents:
|
||||
latents, noise, image_latents = latents
|
||||
else:
|
||||
latents, noise = latents
|
||||
|
||||
# 5.1. Prepare mask latent variables
|
||||
if mask is not None:
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask=mask,
|
||||
masked_image=masked_image,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
raise ValueError(
|
||||
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
height, width = latents.shape[-2:]
|
||||
height = height * self.vae_scale_factor
|
||||
width = width * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
@@ -1158,20 +1633,41 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 7.1 Apply denoising_end
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
if (
|
||||
self.denoising_end is not None
|
||||
and self.denoising_start is not None
|
||||
and denoising_value_valid(self.denoising_end)
|
||||
and denoising_value_valid(self.denoising_start)
|
||||
and self.denoising_start >= self.denoising_end
|
||||
):
|
||||
raise ValueError(
|
||||
f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
|
||||
+ f" {self.denoising_end} when using type float."
|
||||
)
|
||||
elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
# 8. Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 9. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
@@ -1179,13 +1675,17 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
if mask is not None and num_channels_unet == 9:
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -1202,6 +1702,22 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if mask is not None and num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
init_mask, _ = mask.chunk(2)
|
||||
else:
|
||||
init_mask = mask
|
||||
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_proper, noise, torch.tensor([noise_timestep])
|
||||
)
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
@@ -1241,6 +1757,204 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
def text2img(
|
||||
self,
|
||||
prompt: str = None,
|
||||
prompt_2: Optional[str] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
denoising_start: Optional[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt_2: Optional[str] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
denoising_start=denoising_start,
|
||||
denoising_end=denoising_end,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
)
|
||||
|
||||
def img2img(
|
||||
self,
|
||||
prompt: str = None,
|
||||
prompt_2: Optional[str] = None,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
denoising_start: Optional[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt_2: Optional[str] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
denoising_start=denoising_start,
|
||||
denoising_end=denoising_end,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
)
|
||||
|
||||
def inpaint(
|
||||
self,
|
||||
prompt: str = None,
|
||||
prompt_2: Optional[str] = None,
|
||||
image: Optional[PipelineImageInput] = None,
|
||||
mask_image: Optional[PipelineImageInput] = None,
|
||||
masked_image_latents: Optional[torch.FloatTensor] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
denoising_start: Optional[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt_2: Optional[str] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
return self.__call__(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
masked_image_latents=masked_image_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
denoising_start=denoising_start,
|
||||
denoising_end=denoising_end,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
original_size=original_size,
|
||||
crops_coords_top_left=crops_coords_top_left,
|
||||
target_size=target_size,
|
||||
)
|
||||
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
|
||||
602
examples/community/marigold_depth_estimation.py
Normal file
602
examples/community/marigold_depth_estimation.py
Normal file
@@ -0,0 +1,602 @@
|
||||
# Copyright 2023 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
||||
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
||||
# More information about the method can be found at https://marigoldmonodepth.github.io
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from scipy.optimize import minimize
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.20.1.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Marigold monocular depth prediction pipeline.
|
||||
|
||||
Args:
|
||||
depth_np (`np.ndarray`):
|
||||
Predicted depth map, with depth values in the range of [0, 1].
|
||||
depth_colored (`PIL.Image.Image`):
|
||||
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
|
||||
uncertainty (`None` or `np.ndarray`):
|
||||
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
|
||||
"""
|
||||
|
||||
depth_np: np.ndarray
|
||||
depth_colored: Image.Image
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
|
||||
|
||||
class MarigoldPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
unet (`UNet2DConditionModel`):
|
||||
Conditional U-Net to denoise the depth latent, conditioned on image latent.
|
||||
vae (`AutoencoderKL`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
|
||||
to and from latent representations.
|
||||
scheduler (`DDIMScheduler`):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
Text-encoder, for empty text embedding.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
CLIP tokenizer.
|
||||
"""
|
||||
|
||||
rgb_latent_scale_factor = 0.18215
|
||||
depth_latent_scale_factor = 0.18215
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: DDIMScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
self.empty_text_embed = None
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
input_image: Image,
|
||||
denoising_steps: int = 10,
|
||||
ensemble_size: int = 10,
|
||||
processing_res: int = 768,
|
||||
match_input_res: bool = True,
|
||||
batch_size: int = 0,
|
||||
color_map: str = "Spectral",
|
||||
show_progress_bar: bool = True,
|
||||
ensemble_kwargs: Dict = None,
|
||||
) -> MarigoldDepthOutput:
|
||||
"""
|
||||
Function invoked when calling the pipeline.
|
||||
|
||||
Args:
|
||||
input_image (`Image`):
|
||||
Input RGB (or gray-scale) image.
|
||||
processing_res (`int`, *optional*, defaults to `768`):
|
||||
Maximum resolution of processing.
|
||||
If set to 0: will not resize at all.
|
||||
match_input_res (`bool`, *optional*, defaults to `True`):
|
||||
Resize depth prediction to match input resolution.
|
||||
Only valid if `limit_input_res` is not None.
|
||||
denoising_steps (`int`, *optional*, defaults to `10`):
|
||||
Number of diffusion denoising steps (DDIM) during inference.
|
||||
ensemble_size (`int`, *optional*, defaults to `10`):
|
||||
Number of predictions to be ensembled.
|
||||
batch_size (`int`, *optional*, defaults to `0`):
|
||||
Inference batch size, no bigger than `num_ensemble`.
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`):
|
||||
Colormap used to colorize the depth map.
|
||||
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
Arguments for detailed ensembling settings.
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1]
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
||||
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
|
||||
device = self.device
|
||||
input_size = input_image.size
|
||||
|
||||
if not match_input_res:
|
||||
assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
|
||||
assert processing_res >= 0
|
||||
assert denoising_steps >= 1
|
||||
assert ensemble_size >= 1
|
||||
|
||||
# ----------------- Image Preprocess -----------------
|
||||
# Resize image
|
||||
if processing_res > 0:
|
||||
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
|
||||
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
|
||||
input_image = input_image.convert("RGB")
|
||||
image = np.asarray(input_image)
|
||||
|
||||
# Normalize rgb values
|
||||
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
||||
rgb_norm = rgb / 255.0
|
||||
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
||||
rgb_norm = rgb_norm.to(device)
|
||||
assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0
|
||||
|
||||
# ----------------- Predicting depth -----------------
|
||||
# Batch repeated input image
|
||||
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
||||
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
||||
if batch_size > 0:
|
||||
_bs = batch_size
|
||||
else:
|
||||
_bs = self._find_batch_size(
|
||||
ensemble_size=ensemble_size,
|
||||
input_res=max(rgb_norm.shape[1:]),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)
|
||||
|
||||
# Predict depth maps (batched)
|
||||
depth_pred_ls = []
|
||||
if show_progress_bar:
|
||||
iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
|
||||
else:
|
||||
iterable = single_rgb_loader
|
||||
for batch in iterable:
|
||||
(batched_img,) = batch
|
||||
depth_pred_raw = self.single_infer(
|
||||
rgb_in=batched_img,
|
||||
num_inference_steps=denoising_steps,
|
||||
show_pbar=show_progress_bar,
|
||||
)
|
||||
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
||||
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
|
||||
torch.cuda.empty_cache() # clear vram cache for ensembling
|
||||
|
||||
# ----------------- Test-time ensembling -----------------
|
||||
if ensemble_size > 1:
|
||||
depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {}))
|
||||
else:
|
||||
depth_pred = depth_preds
|
||||
pred_uncert = None
|
||||
|
||||
# ----------------- Post processing -----------------
|
||||
# Scale prediction to [0, 1]
|
||||
min_d = torch.min(depth_pred)
|
||||
max_d = torch.max(depth_pred)
|
||||
depth_pred = (depth_pred - min_d) / (max_d - min_d)
|
||||
|
||||
# Convert to numpy
|
||||
depth_pred = depth_pred.cpu().numpy().astype(np.float32)
|
||||
|
||||
# Resize back to original resolution
|
||||
if match_input_res:
|
||||
pred_img = Image.fromarray(depth_pred)
|
||||
pred_img = pred_img.resize(input_size)
|
||||
depth_pred = np.asarray(pred_img)
|
||||
|
||||
# Clip output range
|
||||
depth_pred = depth_pred.clip(0, 1)
|
||||
|
||||
# Colorize
|
||||
depth_colored = self.colorize_depth_maps(
|
||||
depth_pred, 0, 1, cmap=color_map
|
||||
).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = self.chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
uncertainty=pred_uncert,
|
||||
)
|
||||
|
||||
def _encode_empty_text(self):
|
||||
"""
|
||||
Encode text embedding for empty prompt.
|
||||
"""
|
||||
prompt = ""
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="do_not_pad",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
||||
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
|
||||
"""
|
||||
Perform an individual depth prediction without ensembling.
|
||||
|
||||
Args:
|
||||
rgb_in (`torch.Tensor`):
|
||||
Input RGB image.
|
||||
num_inference_steps (`int`):
|
||||
Number of diffusion denoisign steps (DDIM) during inference.
|
||||
show_pbar (`bool`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
Returns:
|
||||
`torch.Tensor`: Predicted depth map.
|
||||
"""
|
||||
device = rgb_in.device
|
||||
|
||||
# Set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps # [T]
|
||||
|
||||
# Encode image
|
||||
rgb_latent = self._encode_rgb(rgb_in)
|
||||
|
||||
# Initial depth map (noise)
|
||||
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
|
||||
|
||||
# Batched empty text embedding
|
||||
if self.empty_text_embed is None:
|
||||
self._encode_empty_text()
|
||||
batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1)) # [B, 2, 1024]
|
||||
|
||||
# Denoising loop
|
||||
if show_pbar:
|
||||
iterable = tqdm(
|
||||
enumerate(timesteps),
|
||||
total=len(timesteps),
|
||||
leave=False,
|
||||
desc=" " * 4 + "Diffusion denoising",
|
||||
)
|
||||
else:
|
||||
iterable = enumerate(timesteps)
|
||||
|
||||
for i, t in iterable:
|
||||
unet_input = torch.cat([rgb_latent, depth_latent], dim=1) # this order is important
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
|
||||
torch.cuda.empty_cache()
|
||||
depth = self._decode_depth(depth_latent)
|
||||
|
||||
# clip prediction
|
||||
depth = torch.clip(depth, -1.0, 1.0)
|
||||
# shift to [0, 1]
|
||||
depth = (depth + 1.0) / 2.0
|
||||
|
||||
return depth
|
||||
|
||||
def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode RGB image into latent.
|
||||
|
||||
Args:
|
||||
rgb_in (`torch.Tensor`):
|
||||
Input RGB image to be encoded.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Image latent.
|
||||
"""
|
||||
# encode
|
||||
h = self.vae.encoder(rgb_in)
|
||||
moments = self.vae.quant_conv(h)
|
||||
mean, logvar = torch.chunk(moments, 2, dim=1)
|
||||
# scale latent
|
||||
rgb_latent = mean * self.rgb_latent_scale_factor
|
||||
return rgb_latent
|
||||
|
||||
def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode depth latent into depth map.
|
||||
|
||||
Args:
|
||||
depth_latent (`torch.Tensor`):
|
||||
Depth latent to be decoded.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Decoded depth map.
|
||||
"""
|
||||
# scale latent
|
||||
depth_latent = depth_latent / self.depth_latent_scale_factor
|
||||
# decode
|
||||
z = self.vae.post_quant_conv(depth_latent)
|
||||
stacked = self.vae.decoder(z)
|
||||
# mean of output channels
|
||||
depth_mean = stacked.mean(dim=1, keepdim=True)
|
||||
return depth_mean
|
||||
|
||||
@staticmethod
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
||||
"""
|
||||
Resize image to limit maximum edge length while keeping aspect ratio.
|
||||
|
||||
Args:
|
||||
img (`Image.Image`):
|
||||
Image to be resized.
|
||||
max_edge_resolution (`int`):
|
||||
Maximum edge length (pixel).
|
||||
|
||||
Returns:
|
||||
`Image.Image`: Resized image.
|
||||
"""
|
||||
original_width, original_height = img.size
|
||||
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
|
||||
|
||||
new_width = int(original_width * downscale_factor)
|
||||
new_height = int(original_height * downscale_factor)
|
||||
|
||||
resized_img = img.resize((new_width, new_height))
|
||||
return resized_img
|
||||
|
||||
@staticmethod
|
||||
def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None):
|
||||
"""
|
||||
Colorize depth maps.
|
||||
"""
|
||||
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
||||
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
depth = depth_map.detach().clone().squeeze().numpy()
|
||||
elif isinstance(depth_map, np.ndarray):
|
||||
depth = depth_map.copy().squeeze()
|
||||
# reshape to [ (B,) H, W ]
|
||||
if depth.ndim < 3:
|
||||
depth = depth[np.newaxis, :, :]
|
||||
|
||||
# colorize
|
||||
cm = matplotlib.colormaps[cmap]
|
||||
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
||||
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
||||
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
||||
|
||||
if valid_mask is not None:
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
valid_mask = valid_mask.detach().numpy()
|
||||
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
||||
if valid_mask.ndim < 3:
|
||||
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
||||
else:
|
||||
valid_mask = valid_mask[:, np.newaxis, :, :]
|
||||
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
||||
img_colored_np[~valid_mask] = 0
|
||||
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
img_colored = torch.from_numpy(img_colored_np).float()
|
||||
elif isinstance(depth_map, np.ndarray):
|
||||
img_colored = img_colored_np
|
||||
|
||||
return img_colored
|
||||
|
||||
@staticmethod
|
||||
def chw2hwc(chw):
|
||||
assert 3 == len(chw.shape)
|
||||
if isinstance(chw, torch.Tensor):
|
||||
hwc = torch.permute(chw, (1, 2, 0))
|
||||
elif isinstance(chw, np.ndarray):
|
||||
hwc = np.moveaxis(chw, 0, -1)
|
||||
return hwc
|
||||
|
||||
@staticmethod
|
||||
def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
|
||||
"""
|
||||
Automatically search for suitable operating batch size.
|
||||
|
||||
Args:
|
||||
ensemble_size (`int`):
|
||||
Number of predictions to be ensembled.
|
||||
input_res (`int`):
|
||||
Operating resolution of the input image.
|
||||
|
||||
Returns:
|
||||
`int`: Operating batch size.
|
||||
"""
|
||||
# Search table for suggested max. inference batch size
|
||||
bs_search_table = [
|
||||
# tested on A100-PCIE-80GB
|
||||
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
|
||||
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
|
||||
# tested on A100-PCIE-40GB
|
||||
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
|
||||
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
|
||||
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
|
||||
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
|
||||
# tested on RTX3090, RTX4090
|
||||
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
|
||||
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
|
||||
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
|
||||
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
|
||||
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
|
||||
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
|
||||
# tested on GTX1080Ti
|
||||
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
|
||||
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
|
||||
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
|
||||
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
|
||||
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
|
||||
]
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return 1
|
||||
|
||||
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
||||
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
|
||||
for settings in sorted(
|
||||
filtered_bs_search_table,
|
||||
key=lambda k: (k["res"], -k["total_vram"]),
|
||||
):
|
||||
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
|
||||
bs = settings["bs"]
|
||||
if bs > ensemble_size:
|
||||
bs = ensemble_size
|
||||
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
||||
bs = math.ceil(ensemble_size / 2)
|
||||
return bs
|
||||
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def ensemble_depths(
|
||||
input_images: torch.Tensor,
|
||||
regularizer_strength: float = 0.02,
|
||||
max_iter: int = 2,
|
||||
tol: float = 1e-3,
|
||||
reduction: str = "median",
|
||||
max_res: int = None,
|
||||
):
|
||||
"""
|
||||
To ensemble multiple affine-invariant depth images (up to scale and shift),
|
||||
by aligning estimating the scale and shift
|
||||
"""
|
||||
|
||||
def inter_distances(tensors: torch.Tensor):
|
||||
"""
|
||||
To calculate the distance between each two depth maps.
|
||||
"""
|
||||
distances = []
|
||||
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
|
||||
arr1 = tensors[i : i + 1]
|
||||
arr2 = tensors[j : j + 1]
|
||||
distances.append(arr1 - arr2)
|
||||
dist = torch.concatenate(distances, dim=0)
|
||||
return dist
|
||||
|
||||
device = input_images.device
|
||||
dtype = input_images.dtype
|
||||
np_dtype = np.float32
|
||||
|
||||
original_input = input_images.clone()
|
||||
n_img = input_images.shape[0]
|
||||
ori_shape = input_images.shape
|
||||
|
||||
if max_res is not None:
|
||||
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
|
||||
if scale_factor < 1:
|
||||
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
|
||||
input_images = downscaler(torch.from_numpy(input_images)).numpy()
|
||||
|
||||
# init guess
|
||||
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
||||
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
||||
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
|
||||
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
|
||||
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
|
||||
|
||||
input_images = input_images.to(device)
|
||||
|
||||
# objective function
|
||||
def closure(x):
|
||||
l = len(x)
|
||||
s = x[: int(l / 2)]
|
||||
t = x[int(l / 2) :]
|
||||
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
||||
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
||||
|
||||
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
|
||||
dists = inter_distances(transformed_arrays)
|
||||
sqrt_dist = torch.sqrt(torch.mean(dists**2))
|
||||
|
||||
if "mean" == reduction:
|
||||
pred = torch.mean(transformed_arrays, dim=0)
|
||||
elif "median" == reduction:
|
||||
pred = torch.median(transformed_arrays, dim=0).values
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
|
||||
far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
|
||||
|
||||
err = sqrt_dist + (near_err + far_err) * regularizer_strength
|
||||
err = err.detach().cpu().numpy().astype(np_dtype)
|
||||
return err
|
||||
|
||||
res = minimize(
|
||||
closure,
|
||||
x,
|
||||
method="BFGS",
|
||||
tol=tol,
|
||||
options={"maxiter": max_iter, "disp": False},
|
||||
)
|
||||
x = res.x
|
||||
l = len(x)
|
||||
s = x[: int(l / 2)]
|
||||
t = x[int(l / 2) :]
|
||||
|
||||
# Prediction
|
||||
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
||||
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
||||
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
|
||||
if "mean" == reduction:
|
||||
aligned_images = torch.mean(transformed_arrays, dim=0)
|
||||
std = torch.std(transformed_arrays, dim=0)
|
||||
uncertainty = std
|
||||
elif "median" == reduction:
|
||||
aligned_images = torch.median(transformed_arrays, dim=0).values
|
||||
# MAD (median absolute deviation) as uncertainty indicator
|
||||
abs_dev = torch.abs(transformed_arrays - aligned_images)
|
||||
mad = torch.median(abs_dev, dim=0).values
|
||||
uncertainty = mad
|
||||
else:
|
||||
raise ValueError(f"Unknown reduction method: {reduction}")
|
||||
|
||||
# Scale and shift to [0, 1]
|
||||
_min = torch.min(aligned_images)
|
||||
_max = torch.max(aligned_images)
|
||||
aligned_images = (aligned_images - _min) / (_max - _min)
|
||||
uncertainty /= _max - _min
|
||||
|
||||
return aligned_images, uncertainty
|
||||
@@ -73,7 +73,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
requires_safety_checker,
|
||||
)
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -102,22 +109,22 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
rp_args: Dict[str, str] = None,
|
||||
):
|
||||
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
|
||||
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
|
||||
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
|
||||
|
||||
device = self._execution_device
|
||||
regions = 0
|
||||
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
|
||||
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
|
||||
prompts = prompt if isinstance(prompt, list) else [prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
|
||||
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
|
||||
|
||||
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
|
||||
equal = len(all_prompts_cn) == len(all_n_prompts_cn)
|
||||
|
||||
if Compel:
|
||||
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
|
||||
@@ -129,7 +136,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
return torch.cat(embl)
|
||||
|
||||
conds = getcompelembs(all_prompts_cn)
|
||||
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
|
||||
unconds = getcompelembs(all_n_prompts_cn)
|
||||
embs = getcompelembs(prompts)
|
||||
n_embs = getcompelembs(n_prompts)
|
||||
prompt = negative_prompt = None
|
||||
@@ -137,7 +144,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
conds = self.encode_prompt(prompts, device, 1, True)[0]
|
||||
unconds = (
|
||||
self.encode_prompt(n_prompts, device, 1, True)[0]
|
||||
if cn
|
||||
if equal
|
||||
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
|
||||
)
|
||||
embs = n_embs = None
|
||||
@@ -206,8 +213,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
else:
|
||||
px, nx = hidden_states.chunk(2)
|
||||
|
||||
if cn:
|
||||
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
|
||||
if equal:
|
||||
hidden_states = torch.cat(
|
||||
[px for i in range(regions)] + [nx for i in range(regions)],
|
||||
0,
|
||||
)
|
||||
encoder_hidden_states = torch.cat([conds] + [unconds])
|
||||
else:
|
||||
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
|
||||
@@ -289,9 +299,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
if any(x in mode for x in ["COL", "ROW"]):
|
||||
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
|
||||
center = reshaped.shape[0] // 2
|
||||
px = reshaped[0:center] if cn else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if cn else reshaped[-batch:]
|
||||
outs = [px, nx] if cn else [px]
|
||||
px = reshaped[0:center] if equal else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if equal else reshaped[-batch:]
|
||||
outs = [px, nx] if equal else [px]
|
||||
for out in outs:
|
||||
c = 0
|
||||
for i, ocell in enumerate(ocells):
|
||||
@@ -321,15 +331,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
:,
|
||||
]
|
||||
c += 1
|
||||
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
|
||||
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
|
||||
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
|
||||
hidden_states = hidden_states.reshape(xshape)
|
||||
|
||||
#### Regional Prompting Prompt mode
|
||||
elif "PRO" in mode:
|
||||
center = reshaped.shape[0] // 2
|
||||
px = reshaped[0:center] if cn else reshaped[0:-batch]
|
||||
nx = reshaped[center:] if cn else reshaped[-batch:]
|
||||
px, nx = (
|
||||
torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
|
||||
hidden_states[-batch:],
|
||||
)
|
||||
|
||||
if (h, w) in self.attnmasks and self.maskready:
|
||||
|
||||
@@ -340,8 +351,8 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
out[b] = out[b] + out[r * batch + b]
|
||||
return out
|
||||
|
||||
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
|
||||
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
|
||||
px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
|
||||
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
|
||||
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
|
||||
return hidden_states
|
||||
|
||||
@@ -378,7 +389,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
save_mask = False
|
||||
|
||||
if mode == "PROMPT" and save_mask:
|
||||
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
|
||||
saveattnmaps(
|
||||
self,
|
||||
output,
|
||||
height,
|
||||
width,
|
||||
thresholds,
|
||||
num_inference_steps // 2,
|
||||
regions,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -437,7 +456,11 @@ def make_cells(ratios):
|
||||
def make_emblist(self, prompts):
|
||||
with torch.no_grad():
|
||||
tokens = self.tokenizer(
|
||||
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
|
||||
prompts,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids.to(self.device)
|
||||
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
|
||||
return embs
|
||||
@@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
getattn=False,
|
||||
) -> torch.Tensor:
|
||||
# Efficient implementation equivalent to the following:
|
||||
L, S = query.size(-2), key.size(-2)
|
||||
|
||||
594
examples/community/sde_drag.py
Normal file
594
examples/community/sde_drag.py
Normal file
@@ -0,0 +1,594 @@
|
||||
import math
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel
|
||||
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
class SdeDragPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for image drag-and-drop editing using stochastic differential equations: https://arxiv.org/abs/2311.01410.
|
||||
Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use
|
||||
[`DDIMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
image: PIL.Image.Image,
|
||||
mask_image: PIL.Image.Image,
|
||||
source_points: List[List[int]],
|
||||
target_points: List[List[int]],
|
||||
t0: Optional[float] = 0.6,
|
||||
steps: Optional[int] = 200,
|
||||
step_size: Optional[int] = 2,
|
||||
image_scale: Optional[float] = 0.3,
|
||||
adapt_radius: Optional[int] = 5,
|
||||
min_lora_scale: Optional[float] = 0.5,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for image editing.
|
||||
Args:
|
||||
prompt (`str`, *required*):
|
||||
The prompt to guide the image editing.
|
||||
image (`PIL.Image.Image`, *required*):
|
||||
Which will be edited, parts of the image will be masked out with `mask_image` and edited
|
||||
according to `prompt`.
|
||||
mask_image (`PIL.Image.Image`, *required*):
|
||||
To mask `image`. White pixels in the mask will be edited, while black pixels will be preserved.
|
||||
source_points (`List[List[int]]`, *required*):
|
||||
Used to mark the starting positions of drag editing in the image, with each pixel represented as a
|
||||
`List[int]` of length 2.
|
||||
target_points (`List[List[int]]`, *required*):
|
||||
Used to mark the target positions of drag editing in the image, with each pixel represented as a
|
||||
`List[int]` of length 2.
|
||||
t0 (`float`, *optional*, defaults to 0.6):
|
||||
The time parameter. Higher t0 improves the fidelity while lowering the faithfulness of the edited images
|
||||
and vice versa.
|
||||
steps (`int`, *optional*, defaults to 200):
|
||||
The number of sampling iterations.
|
||||
step_size (`int`, *optional*, defaults to 2):
|
||||
The drag diatance of each drag step.
|
||||
image_scale (`float`, *optional*, defaults to 0.3):
|
||||
To avoid duplicating the content, use image_scale to perturbs the source.
|
||||
adapt_radius (`int`, *optional*, defaults to 5):
|
||||
The size of the region for copy and paste operations during each step of the drag process.
|
||||
min_lora_scale (`float`, *optional*, defaults to 0.5):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
min_lora_scale specifies the minimum LoRA scale during the image drag-editing process.
|
||||
generator ('torch.Generator', *optional*, defaults to None):
|
||||
To make generation deterministic(https://pytorch.org/docs/stable/generated/torch.Generator.html).
|
||||
Examples:
|
||||
```py
|
||||
>>> import PIL
|
||||
>>> import torch
|
||||
>>> from diffusers import DDIMScheduler, DiffusionPipeline
|
||||
|
||||
>>> # Load the pipeline
|
||||
>>> model_path = "runwayml/stable-diffusion-v1-5"
|
||||
>>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
|
||||
>>> pipe.to('cuda')
|
||||
|
||||
>>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
|
||||
>>> # If not training LoRA, please avoid using torch.float16
|
||||
>>> # pipe.to(torch.float16)
|
||||
|
||||
>>> # Provide prompt, image, mask image, and the starting and target points for drag editing.
|
||||
>>> prompt = "prompt of the image"
|
||||
>>> image = PIL.Image.open('/path/to/image')
|
||||
>>> mask_image = PIL.Image.open('/path/to/mask_image')
|
||||
>>> source_points = [[123, 456]]
|
||||
>>> target_points = [[234, 567]]
|
||||
|
||||
>>> # train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.
|
||||
>>> pipe.train_lora(prompt, image)
|
||||
|
||||
>>> output = pipe(prompt, image, mask_image, source_points, target_points)
|
||||
>>> output_image = PIL.Image.fromarray(output)
|
||||
>>> output_image.save("./output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
self.scheduler.set_timesteps(steps)
|
||||
|
||||
noise_scale = (1 - image_scale**2) ** (0.5)
|
||||
|
||||
text_embeddings = self._get_text_embed(prompt)
|
||||
uncond_embeddings = self._get_text_embed([""])
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latent = self._get_img_latent(image)
|
||||
|
||||
mask = mask_image.resize((latent.shape[3], latent.shape[2]))
|
||||
mask = torch.tensor(np.array(mask))
|
||||
mask = mask.unsqueeze(0).expand_as(latent).to(self.device)
|
||||
|
||||
source_points = torch.tensor(source_points).div(torch.tensor([8]), rounding_mode="trunc")
|
||||
target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode="trunc")
|
||||
|
||||
distance = target_points - source_points
|
||||
distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max()
|
||||
|
||||
if distance_norm_max <= step_size:
|
||||
drag_num = 1
|
||||
else:
|
||||
drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode="trunc")
|
||||
if (distance_norm_max / drag_num - step_size).abs() > (
|
||||
distance_norm_max / (drag_num + 1) - step_size
|
||||
).abs():
|
||||
drag_num += 1
|
||||
|
||||
latents = []
|
||||
for i in tqdm(range(int(drag_num)), desc="SDE Drag"):
|
||||
source_new = source_points + (i / drag_num * distance).to(torch.int)
|
||||
target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int)
|
||||
|
||||
latent, noises, hook_latents, lora_scales, cfg_scales = self._forward(
|
||||
latent, steps, t0, min_lora_scale, text_embeddings, generator
|
||||
)
|
||||
latent = self._copy_and_paste(
|
||||
latent,
|
||||
source_new,
|
||||
target_new,
|
||||
adapt_radius,
|
||||
latent.shape[2] - 1,
|
||||
latent.shape[3] - 1,
|
||||
image_scale,
|
||||
noise_scale,
|
||||
generator,
|
||||
)
|
||||
latent = self._backward(
|
||||
latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
|
||||
)
|
||||
|
||||
latents.append(latent)
|
||||
|
||||
result_image = 1 / 0.18215 * latents[-1]
|
||||
|
||||
with torch.no_grad():
|
||||
result_image = self.vae.decode(result_image).sample
|
||||
|
||||
result_image = (result_image / 2 + 0.5).clamp(0, 1)
|
||||
result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
||||
result_image = (result_image * 255).astype(np.uint8)
|
||||
|
||||
return result_image
|
||||
|
||||
def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None):
|
||||
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16")
|
||||
|
||||
self.vae.requires_grad_(False)
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.unet.requires_grad_(False)
|
||||
|
||||
unet_lora_attn_procs = {}
|
||||
for name, attn_processor in self.unet.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = self.unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = self.unet.config.block_out_channels[block_id]
|
||||
else:
|
||||
raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks")
|
||||
|
||||
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
|
||||
lora_attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
||||
else LoRAAttnProcessor
|
||||
)
|
||||
unet_lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
|
||||
)
|
||||
|
||||
self.unet.set_attn_processor(unet_lora_attn_procs)
|
||||
unet_lora_layers = AttnProcsLayers(self.unet.attn_processors)
|
||||
params_to_optimize = unet_lora_layers.parameters()
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
params_to_optimize,
|
||||
lr=2e-4,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=1e-2,
|
||||
eps=1e-08,
|
||||
)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
"constant",
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=0,
|
||||
num_training_steps=lora_step,
|
||||
num_cycles=1,
|
||||
power=1.0,
|
||||
)
|
||||
|
||||
unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
|
||||
optimizer = accelerator.prepare_optimizer(optimizer)
|
||||
lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
|
||||
|
||||
with torch.no_grad():
|
||||
text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None)
|
||||
text_embedding = self._encode_prompt(
|
||||
text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False
|
||||
)
|
||||
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
image = image_transforms(image).to(self.device, dtype=self.vae.dtype)
|
||||
image = image.unsqueeze(dim=0)
|
||||
latents_dist = self.vae.encode(image).latent_dist
|
||||
|
||||
for _ in tqdm(range(lora_step), desc="Train LoRA"):
|
||||
self.unet.train()
|
||||
model_input = latents_dist.sample() * self.vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(
|
||||
model_input.size(),
|
||||
dtype=model_input.dtype,
|
||||
layout=model_input.layout,
|
||||
device=model_input.device,
|
||||
generator=generator,
|
||||
)
|
||||
bsz, channels, height, width = model_input.shape
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if self.scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif self.scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.scheduler.get_velocity(model_input, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")
|
||||
|
||||
loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
with tempfile.TemporaryDirectory() as save_lora_dir:
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=save_lora_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=None,
|
||||
)
|
||||
|
||||
self.unet.load_attn_procs(save_lora_dir)
|
||||
|
||||
def _tokenize_prompt(self, prompt, tokenizer_max_length=None):
|
||||
if tokenizer_max_length is not None:
|
||||
max_length = tokenizer_max_length
|
||||
else:
|
||||
max_length = self.tokenizer.model_max_length
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
return text_inputs
|
||||
|
||||
def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False):
|
||||
text_input_ids = input_ids.to(self.device)
|
||||
|
||||
if text_encoder_use_attention_mask:
|
||||
attention_mask = attention_mask.to(self.device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_text_embed(self, prompt):
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
return text_embeddings
|
||||
|
||||
def _copy_and_paste(
|
||||
self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator
|
||||
):
|
||||
def adaption_r(source, target, adapt_radius, max_height, max_width):
|
||||
r_x_lower = min(adapt_radius, source[0], target[0])
|
||||
r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0])
|
||||
r_y_lower = min(adapt_radius, source[1], target[1])
|
||||
r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1])
|
||||
return r_x_lower, r_x_upper, r_y_lower, r_y_upper
|
||||
|
||||
for source_, target_ in zip(source_new, target_new):
|
||||
r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r(
|
||||
source_, target_, adapt_radius, max_height, max_width
|
||||
)
|
||||
|
||||
source_feature = latent[
|
||||
:, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
|
||||
].clone()
|
||||
|
||||
latent[
|
||||
:, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
|
||||
] = image_scale * source_feature + noise_scale * torch.randn(
|
||||
latent.shape[0],
|
||||
4,
|
||||
r_y_lower + r_y_upper,
|
||||
r_x_lower + r_x_upper,
|
||||
device=self.device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
latent[
|
||||
:, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper
|
||||
] = source_feature * 1.1
|
||||
return latent
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_img_latent(self, image, height=None, weight=None):
|
||||
data = image.convert("RGB")
|
||||
if height is not None:
|
||||
data = data.resize((weight, height))
|
||||
transform = transforms.ToTensor()
|
||||
data = transform(data).unsqueeze(0)
|
||||
data = (data * 2.0) - 1.0
|
||||
data = data.to(self.device, dtype=self.vae.dtype)
|
||||
latent = self.vae.encode(data).latent_dist.sample()
|
||||
latent = 0.18215 * latent
|
||||
return latent
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None):
|
||||
latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent
|
||||
text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1]
|
||||
|
||||
cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale}
|
||||
|
||||
with torch.no_grad():
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
if guidance_scale > 1.0:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
elif guidance_scale == 1.0:
|
||||
noise_pred_text = noise_pred
|
||||
noise_pred_uncond = 0.0
|
||||
else:
|
||||
raise NotImplementedError(guidance_scale)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
return noise_pred
|
||||
|
||||
def _forward_sde(
|
||||
self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None
|
||||
):
|
||||
num_train_timesteps = len(self.scheduler)
|
||||
alphas_cumprod = self.scheduler.alphas_cumprod
|
||||
initial_alpha_cumprod = torch.tensor(1.0)
|
||||
|
||||
prev_timestep = timestep + num_train_timesteps // steps
|
||||
|
||||
alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod
|
||||
alpha_prod_t_prev = alphas_cumprod[prev_timestep]
|
||||
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (
|
||||
0.5
|
||||
) * torch.randn(
|
||||
sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
|
||||
)
|
||||
eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale)
|
||||
|
||||
sigma_t_prev = (
|
||||
eta
|
||||
* (1 - alpha_prod_t) ** (0.5)
|
||||
* (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5)
|
||||
)
|
||||
|
||||
pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)
|
||||
pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5)
|
||||
|
||||
noise = (
|
||||
sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps
|
||||
) / sigma_t_prev
|
||||
|
||||
return x_prev, noise
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
timestep,
|
||||
sample,
|
||||
guidance_scale,
|
||||
text_embeddings,
|
||||
steps,
|
||||
sde=False,
|
||||
noise=None,
|
||||
eta=1.0,
|
||||
lora_scale=None,
|
||||
generator=None,
|
||||
):
|
||||
num_train_timesteps = len(self.scheduler)
|
||||
alphas_cumprod = self.scheduler.alphas_cumprod
|
||||
final_alpha_cumprod = torch.tensor(1.0)
|
||||
|
||||
eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale)
|
||||
|
||||
prev_timestep = timestep - num_train_timesteps // steps
|
||||
|
||||
alpha_prod_t = alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
sigma_t = (
|
||||
eta
|
||||
* ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5)
|
||||
* (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5)
|
||||
if sde
|
||||
else 0
|
||||
)
|
||||
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
|
||||
pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5)
|
||||
|
||||
noise = (
|
||||
torch.randn(
|
||||
sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
|
||||
)
|
||||
if noise is None
|
||||
else noise
|
||||
)
|
||||
latent = (
|
||||
alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise
|
||||
)
|
||||
|
||||
return latent
|
||||
|
||||
def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator):
|
||||
def scale_schedule(begin, end, n, length, type="linear"):
|
||||
if type == "constant":
|
||||
return end
|
||||
elif type == "linear":
|
||||
return begin + (end - begin) * n / length
|
||||
elif type == "cos":
|
||||
factor = (1 - math.cos(n * math.pi / length)) / 2
|
||||
return (1 - factor) * begin + factor * end
|
||||
else:
|
||||
raise NotImplementedError(type)
|
||||
|
||||
noises = []
|
||||
latents = []
|
||||
lora_scales = []
|
||||
cfg_scales = []
|
||||
latents.append(latent)
|
||||
t0 = int(t0 * steps)
|
||||
t_begin = steps - t0
|
||||
|
||||
length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1
|
||||
index = 1
|
||||
for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]):
|
||||
lora_scale = scale_schedule(1, lora_scale_min, index, length, type="cos")
|
||||
cfg_scale = scale_schedule(1, 3.0, index, length, type="linear")
|
||||
latent, noise = self._forward_sde(
|
||||
t, latent, cfg_scale, text_embeddings, steps, lora_scale=lora_scale, generator=generator
|
||||
)
|
||||
|
||||
noises.append(noise)
|
||||
latents.append(latent)
|
||||
lora_scales.append(lora_scale)
|
||||
cfg_scales.append(cfg_scale)
|
||||
index += 1
|
||||
return latent, noises, latents, lora_scales, cfg_scales
|
||||
|
||||
def _backward(
|
||||
self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
|
||||
):
|
||||
t0 = int(t0 * steps)
|
||||
t_begin = steps - t0
|
||||
|
||||
hook_latent = hook_latents.pop()
|
||||
latent = torch.where(mask > 128, latent, hook_latent)
|
||||
for t in self.scheduler.timesteps[t_begin - 1 : -1]:
|
||||
latent = self._sample(
|
||||
t,
|
||||
latent,
|
||||
cfg_scales.pop(),
|
||||
text_embeddings,
|
||||
steps,
|
||||
sde=True,
|
||||
noise=noises.pop(),
|
||||
lora_scale=lora_scales.pop(),
|
||||
generator=generator,
|
||||
)
|
||||
hook_latent = hook_latents.pop()
|
||||
latent = torch.where(mask > 128, latent, hook_latent)
|
||||
return latent
|
||||
@@ -50,6 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionPipelineOutput,
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import logging
|
||||
|
||||
@@ -608,7 +609,7 @@ class TorchVAEEncoder(torch.nn.Module):
|
||||
self.vae_encoder = model
|
||||
|
||||
def forward(self, x):
|
||||
return self.vae_encoder.encode(x).latent_dist.sample()
|
||||
return retrieve_latents(self.vae_encoder.encode(x))
|
||||
|
||||
|
||||
class VAEEncoder(BaseModel):
|
||||
@@ -1004,7 +1005,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
"""
|
||||
self.generator = generator
|
||||
self.denoising_steps = num_inference_steps
|
||||
self.guidance_scale = guidance_scale
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
# Pre-compute latent input scales and linear multistep coefficients
|
||||
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
|
||||
|
||||
@@ -94,7 +94,7 @@ accelerate launch train_lcm_distill_lora_sd_wds.py \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=512 \
|
||||
--lora_rank=64 \
|
||||
--learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \
|
||||
--learning_rate=1e-4 --loss_type="huber" --adam_weight_decay=0.0 \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
|
||||
@@ -96,7 +96,7 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=1024 \
|
||||
--lora_rank=64 \
|
||||
--learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
|
||||
--learning_rate=1e-4 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
|
||||
--max_train_steps=1000 \
|
||||
--max_train_samples=4000000 \
|
||||
--dataloader_num_workers=8 \
|
||||
@@ -111,4 +111,38 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
|
||||
--report_to=wandb \
|
||||
--seed=453645634 \
|
||||
--push_to_hub \
|
||||
```
|
||||
```
|
||||
|
||||
We provide another version for LCM LoRA SDXL that follows best practices of `peft` and leverages the `datasets` library for quick experimentation. The script doesn't load two UNets unlike `train_lcm_distill_lora_sdxl_wds.py` which reduces the memory requirements quite a bit.
|
||||
|
||||
Below is an example training command that trains an LCM LoRA on the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
|
||||
|
||||
accelerate launch train_lcm_distill_lora_sdxl.py \
|
||||
--pretrained_teacher_model=${MODEL_NAME} \
|
||||
--pretrained_vae_model_name_or_path=${VAE_PATH} \
|
||||
--output_dir="pokemons-lora-lcm-sdxl" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=24 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--lora_rank=64 \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=3000 \
|
||||
--checkpointing_steps=500 \
|
||||
--validation_steps=50 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
|
||||
112
examples/consistency_distillation/test_lcm_lora.py
Normal file
112
examples/consistency_distillation/test_lcm_lora.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class TextToImageLCM(ExamplesTestsAccelerate):
|
||||
def test_text_to_image_lcm_lora_sdxl(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
|
||||
--pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--lora_rank 4
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
|
||||
--pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--lora_rank 4
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--checkpointing_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
test_args = f"""
|
||||
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
|
||||
--pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
|
||||
--dataset_name hf-internal-testing/dummy_image_text_data
|
||||
--resolution 64
|
||||
--lora_rank 4
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 9
|
||||
--checkpointing_steps 2
|
||||
--resume_from_checkpoint latest
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
@@ -38,7 +38,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from braceexpand import braceexpand
|
||||
from huggingface_hub import create_repo
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
|
||||
from torch.utils.data import default_collate
|
||||
@@ -156,7 +156,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class Text2ImageDataset:
|
||||
class SDText2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -359,19 +359,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
else:
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -823,7 +847,7 @@ def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
create_repo(
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||
exist_ok=True,
|
||||
token=args.hub_token,
|
||||
@@ -835,34 +859,35 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
ddim_timesteps=args.num_ddim_timesteps,
|
||||
)
|
||||
|
||||
# 2. Load tokenizers from SD-XL checkpoint.
|
||||
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
||||
)
|
||||
|
||||
# 3. Load text encoders from SD-1.5 checkpoint.
|
||||
# 3. Load text encoders from SD 1.X/2.X checkpoint.
|
||||
# import correct text encoder classes
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
|
||||
# 4. Load VAE from SD 1.X/2.X checkpoint
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_teacher_model,
|
||||
subfolder="vae",
|
||||
revision=args.teacher_revision,
|
||||
)
|
||||
|
||||
# 5. Load teacher U-Net from SD-XL checkpoint
|
||||
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -872,7 +897,7 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online (`unet`) student U-Nets.
|
||||
# 7. Create online student U-Net.
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -935,6 +960,7 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1011,13 +1037,14 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# 13. Dataset creation and data processing
|
||||
# Here, we compute not just the text embeddings but also the additional embeddings
|
||||
# needed for the SD XL UNet to operate.
|
||||
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
|
||||
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
dataset = Text2ImageDataset(
|
||||
dataset = SDText2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1037,6 +1064,7 @@ def main(args):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 14. LR Scheduler creation
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1051,6 +1079,7 @@ def main(args):
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
||||
|
||||
@@ -1072,7 +1101,7 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
# Train!
|
||||
# 16. Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
@@ -1123,6 +1152,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image and text conditioning
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1140,37 +1170,37 @@ def main(args):
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
||||
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1179,7 +1209,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1190,17 +1220,27 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1209,13 +1249,21 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
uncond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1224,12 +1272,17 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
@@ -1238,7 +1291,7 @@ def main(args):
|
||||
timestep_cond=None,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
).sample
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1248,7 +1301,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 20.4.13. Calculate loss
|
||||
# 10. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1256,7 +1309,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1313,6 +1366,14 @@ def main(args):
|
||||
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
|
||||
StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
1358
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
Normal file
1358
examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -39,7 +39,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from braceexpand import braceexpand
|
||||
from huggingface_hub import create_repo
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
|
||||
from torch.utils.data import default_collate
|
||||
@@ -162,7 +162,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class Text2ImageDataset:
|
||||
class SDXLText2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -346,19 +346,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
else:
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -818,7 +842,7 @@ def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
create_repo(
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||
exist_ok=True,
|
||||
token=args.hub_token,
|
||||
@@ -830,9 +854,10 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
@@ -886,7 +911,7 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 7. Create online (`unet`) student U-Nets.
|
||||
# 7. Create online student U-Net.
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -950,6 +975,7 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1057,7 +1083,7 @@ def main(args):
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
||||
|
||||
dataset = Text2ImageDataset(
|
||||
dataset = SDXLText2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1175,6 +1201,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
|
||||
image, text, orig_size, crop_coords = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1196,37 +1223,37 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max]
|
||||
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1235,7 +1262,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1246,18 +1273,28 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
|
||||
).sample
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1266,7 +1303,7 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
uncond_added_conditions = copy.deepcopy(encoded_text)
|
||||
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
|
||||
uncond_teacher_output = teacher_unet(
|
||||
@@ -1275,7 +1312,15 @@ def main(args):
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
|
||||
).sample
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1284,12 +1329,17 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
@@ -1299,7 +1349,7 @@ def main(args):
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1309,7 +1359,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 20.4.13. Calculate loss
|
||||
# 10. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1317,7 +1367,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1374,6 +1424,14 @@ def main(args):
|
||||
lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
|
||||
StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from braceexpand import braceexpand
|
||||
from huggingface_hub import create_repo
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from torch.utils.data import default_collate
|
||||
from torchvision import transforms
|
||||
@@ -138,7 +138,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class Text2ImageDataset:
|
||||
class SDText2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -336,19 +336,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
else:
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -811,7 +835,7 @@ def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
create_repo(
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||
exist_ok=True,
|
||||
token=args.hub_token,
|
||||
@@ -823,34 +847,35 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
ddim_timesteps=args.num_ddim_timesteps,
|
||||
)
|
||||
|
||||
# 2. Load tokenizers from SD-XL checkpoint.
|
||||
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
|
||||
)
|
||||
|
||||
# 3. Load text encoders from SD-1.5 checkpoint.
|
||||
# 3. Load text encoders from SD 1.X/2.X checkpoint.
|
||||
# import correct text encoder classes
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
|
||||
# 4. Load VAE from SD 1.X/2.X checkpoint
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_teacher_model,
|
||||
subfolder="vae",
|
||||
revision=args.teacher_revision,
|
||||
)
|
||||
|
||||
# 5. Load teacher U-Net from SD-XL checkpoint
|
||||
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
|
||||
teacher_unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
|
||||
)
|
||||
@@ -860,17 +885,18 @@ def main(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
# load teacher_unet weights into unet
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from unet
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
@@ -887,7 +913,7 @@ def main(args):
|
||||
f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
|
||||
)
|
||||
|
||||
# 10. Handle mixed precision and device placement
|
||||
# 9. Handle mixed precision and device placement
|
||||
# For mixed precision training we cast all non-trainable weigths to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
@@ -914,7 +940,7 @@ def main(args):
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 11. Handle saving and loading of checkpoints
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -948,7 +974,7 @@ def main(args):
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# 12. Enable optimizations
|
||||
# 11. Enable optimizations
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
@@ -994,13 +1020,14 @@ def main(args):
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# 13. Dataset creation and data processing
|
||||
# Here, we compute not just the text embeddings but also the additional embeddings
|
||||
# needed for the SD XL UNet to operate.
|
||||
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
|
||||
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
dataset = Text2ImageDataset(
|
||||
dataset = SDText2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1020,6 +1047,7 @@ def main(args):
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 14. LR Scheduler creation
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1034,6 +1062,7 @@ def main(args):
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
|
||||
|
||||
@@ -1055,7 +1084,7 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
# Train!
|
||||
# 16. Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
@@ -1106,6 +1135,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image and text conditioning
|
||||
image, text = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1123,40 +1153,39 @@ def main(args):
|
||||
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
# Move to U-Net device and dtype
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1165,7 +1194,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1176,17 +1205,27 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1195,13 +1234,21 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
uncond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
).sample
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1210,12 +1257,16 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
@@ -1224,7 +1275,7 @@ def main(args):
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
).sample
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1234,7 +1285,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 20.4.13. Calculate loss
|
||||
# 10. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1242,7 +1293,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1252,7 +1303,7 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
# 20.4.15. Make EMA update to target student model parameters
|
||||
# 12. Make EMA update to target student model parameters (`target_unet`)
|
||||
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
@@ -1303,6 +1354,14 @@ def main(args):
|
||||
target_unet = accelerator.unwrap_model(target_unet)
|
||||
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from braceexpand import braceexpand
|
||||
from huggingface_hub import create_repo
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from torch.utils.data import default_collate
|
||||
from torchvision import transforms
|
||||
@@ -144,7 +144,7 @@ class WebdatasetFilter:
|
||||
return False
|
||||
|
||||
|
||||
class Text2ImageDataset:
|
||||
class SDXLText2ImageDataset:
|
||||
def __init__(
|
||||
self,
|
||||
train_shards_path_or_url: Union[str, List[str]],
|
||||
@@ -324,19 +324,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
|
||||
|
||||
|
||||
# Compare LCMScheduler.step, Step 4
|
||||
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
pred_x_0 = (sample - sigmas * model_output) / alphas
|
||||
elif prediction_type == "sample":
|
||||
pred_x_0 = model_output
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
|
||||
pred_x_0 = alphas * sample - sigmas * model_output
|
||||
else:
|
||||
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_x_0
|
||||
|
||||
|
||||
# Based on step 4 in DDIMScheduler.step
|
||||
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
||||
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
||||
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
||||
if prediction_type == "epsilon":
|
||||
pred_epsilon = model_output
|
||||
elif prediction_type == "sample":
|
||||
pred_epsilon = (sample - alphas * model_output) / sigmas
|
||||
elif prediction_type == "v_prediction":
|
||||
pred_epsilon = alphas * model_output + sigmas * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
|
||||
f" are supported."
|
||||
)
|
||||
|
||||
return pred_epsilon
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
@@ -851,7 +875,7 @@ def main(args):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
create_repo(
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name,
|
||||
exist_ok=True,
|
||||
token=args.hub_token,
|
||||
@@ -863,9 +887,10 @@ def main(args):
|
||||
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
|
||||
)
|
||||
|
||||
# The scheduler calculates the alpha and sigma schedule for us
|
||||
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
|
||||
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
|
||||
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
|
||||
# Initialize the DDIM ODE solver for distillation.
|
||||
solver = DDIMSolver(
|
||||
noise_scheduler.alphas_cumprod.numpy(),
|
||||
timesteps=noise_scheduler.config.num_train_timesteps,
|
||||
@@ -919,17 +944,18 @@ def main(args):
|
||||
text_encoder_two.requires_grad_(False)
|
||||
teacher_unet.requires_grad_(False)
|
||||
|
||||
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
# load teacher_unet weights into unet
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from unet
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
@@ -971,6 +997,7 @@ def main(args):
|
||||
# Also move the alpha and sigma noise schedules to accelerator.device.
|
||||
alpha_schedule = alpha_schedule.to(accelerator.device)
|
||||
sigma_schedule = sigma_schedule.to(accelerator.device)
|
||||
# Move the ODE solver to accelerator.device.
|
||||
solver = solver.to(accelerator.device)
|
||||
|
||||
# 10. Handle saving and loading of checkpoints
|
||||
@@ -1084,7 +1111,7 @@ def main(args):
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
|
||||
|
||||
dataset = Text2ImageDataset(
|
||||
dataset = SDXLText2ImageDataset(
|
||||
train_shards_path_or_url=args.train_shards_path_or_url,
|
||||
num_train_examples=args.max_train_samples,
|
||||
per_gpu_batch_size=args.train_batch_size,
|
||||
@@ -1202,6 +1229,7 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates)
|
||||
image, text, orig_size, crop_coords = batch
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
@@ -1223,38 +1251,39 @@ def main(args):
|
||||
latents = latents * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
latents = latents.to(weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
|
||||
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
|
||||
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
|
||||
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
|
||||
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
|
||||
start_timesteps = solver.ddim_timesteps[index]
|
||||
timesteps = start_timesteps - topk
|
||||
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
|
||||
|
||||
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
|
||||
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
|
||||
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
|
||||
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
|
||||
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
|
||||
|
||||
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
|
||||
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
|
||||
noise = torch.randn_like(latents)
|
||||
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
|
||||
|
||||
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
# 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
|
||||
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
|
||||
w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
|
||||
w = w.reshape(bsz, 1, 1, 1)
|
||||
# Move to U-Net device and dtype
|
||||
w = w.to(device=latents.device, dtype=latents.dtype)
|
||||
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
|
||||
|
||||
# 20.4.8. Prepare prompt embeds and unet_added_conditions
|
||||
# 6. Prepare prompt embeds and unet_added_conditions
|
||||
prompt_embeds = encoded_text.pop("prompt_embeds")
|
||||
|
||||
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
|
||||
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
|
||||
noise_pred = unet(
|
||||
noisy_model_input,
|
||||
start_timesteps,
|
||||
@@ -1263,7 +1292,7 @@ def main(args):
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
noise_pred,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1274,18 +1303,28 @@ def main(args):
|
||||
|
||||
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
|
||||
|
||||
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
|
||||
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
|
||||
# Get teacher model prediction on noisy_latents and conditional embedding
|
||||
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
|
||||
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
start_timesteps,
|
||||
encoder_hidden_states=prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
|
||||
).sample
|
||||
cond_pred_x0 = predicted_origin(
|
||||
cond_pred_x0 = get_predicted_original_sample(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
cond_pred_noise = get_predicted_noise(
|
||||
cond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1294,7 +1333,7 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# Get teacher model prediction on noisy_latents and unconditional embedding
|
||||
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
|
||||
uncond_added_conditions = copy.deepcopy(encoded_text)
|
||||
uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
|
||||
uncond_teacher_output = teacher_unet(
|
||||
@@ -1303,7 +1342,15 @@ def main(args):
|
||||
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
|
||||
added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
|
||||
).sample
|
||||
uncond_pred_x0 = predicted_origin(
|
||||
uncond_pred_x0 = get_predicted_original_sample(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
noise_scheduler.config.prediction_type,
|
||||
alpha_schedule,
|
||||
sigma_schedule,
|
||||
)
|
||||
uncond_pred_noise = get_predicted_noise(
|
||||
uncond_teacher_output,
|
||||
start_timesteps,
|
||||
noisy_model_input,
|
||||
@@ -1312,12 +1359,16 @@ def main(args):
|
||||
sigma_schedule,
|
||||
)
|
||||
|
||||
# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
|
||||
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
|
||||
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
|
||||
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
|
||||
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
|
||||
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
|
||||
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
|
||||
# augmented PF-ODE trajectory (solving backward in time)
|
||||
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
|
||||
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
|
||||
|
||||
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
with torch.no_grad():
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
@@ -1327,7 +1378,7 @@ def main(args):
|
||||
encoder_hidden_states=prompt_embeds.float(),
|
||||
added_cond_kwargs=encoded_text,
|
||||
).sample
|
||||
pred_x_0 = predicted_origin(
|
||||
pred_x_0 = get_predicted_original_sample(
|
||||
target_noise_pred,
|
||||
timesteps,
|
||||
x_prev,
|
||||
@@ -1337,7 +1388,7 @@ def main(args):
|
||||
)
|
||||
target = c_skip * x_prev + c_out * pred_x_0
|
||||
|
||||
# 20.4.13. Calculate loss
|
||||
# 10. Calculate loss
|
||||
if args.loss_type == "l2":
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
elif args.loss_type == "huber":
|
||||
@@ -1345,7 +1396,7 @@ def main(args):
|
||||
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
|
||||
)
|
||||
|
||||
# 20.4.14. Backpropagate on the online student model (`unet`)
|
||||
# 11. Backpropagate on the online student model (`unet`)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
@@ -1355,7 +1406,7 @@ def main(args):
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
# 20.4.15. Make EMA update to target student model parameters
|
||||
# 12. Make EMA update to target student model parameters (`target_unet`)
|
||||
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
@@ -1406,6 +1457,14 @@ def main(args):
|
||||
target_unet = accelerator.unwrap_model(target_unet)
|
||||
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class ControlNet(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
--max_train_steps=9
|
||||
--max_train_steps=6
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
@@ -73,7 +73,7 @@ class ControlNet(ExamplesTestsAccelerate):
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -85,18 +85,15 @@ class ControlNet(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
|
||||
--max_train_steps=11
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-6
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
|
||||
)
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
|
||||
class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
@@ -111,7 +108,7 @@ class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
|
||||
--max_train_steps=9
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
|
||||
@@ -76,10 +76,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
|
||||
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -93,7 +90,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=9
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
@@ -102,7 +99,7 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -115,16 +112,13 @@ class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
--train_batch_size=1
|
||||
--modifier_token=<new1>
|
||||
--dataloader_num_workers=0
|
||||
--max_train_steps=11
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--no_safe_serialization
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
@@ -89,7 +89,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
@@ -100,7 +100,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 5
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -114,7 +114,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run the original fully trained output pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
pipe(instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
@@ -123,7 +123,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
pipe(instance_prompt, num_inference_steps=1)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
@@ -138,7 +138,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -153,7 +153,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(instance_prompt, num_inference_steps=2)
|
||||
pipe(instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check old checkpoints do not exist
|
||||
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
@@ -196,7 +196,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=9
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
@@ -204,7 +204,7 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -216,15 +216,12 @@ class DreamBooth(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=11
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
@@ -135,16 +135,13 @@ class DreamBoothLoRA(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=9
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
@@ -155,18 +152,15 @@ class DreamBoothLoRA(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=11
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
)
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_if_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -328,7 +322,7 @@ class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 6
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--learning_rate 5.0e-04
|
||||
@@ -342,14 +336,11 @@ class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe("a prompt", num_inference_steps=2)
|
||||
pipe("a prompt", num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
|
||||
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
|
||||
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
|
||||
|
||||
@@ -64,39 +64,6 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -860,6 +827,7 @@ def main(args):
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
@@ -868,7 +836,10 @@ def main(args):
|
||||
# The text encoder comes from 🤗 transformers, we will also attach adapters to it.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_encoder.add_adapter(text_lora_config)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers import (
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -64,39 +64,6 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -1011,7 +978,10 @@ def main(args):
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
@@ -1019,11 +989,25 @@ def main(args):
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1035,11 +1019,15 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1166,10 +1154,26 @@ def main(args):
|
||||
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
if args.learning_rate <= 0.1:
|
||||
logger.warn(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warn(
|
||||
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
)
|
||||
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
|
||||
# --learning_rate
|
||||
params_to_optimize[1]["lr"] = args.learning_rate
|
||||
params_to_optimize[2]["lr"] = args.learning_rate
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
decouple=args.prodigy_decouple,
|
||||
@@ -1615,13 +1619,17 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = get_peft_model_state_dict(unet)
|
||||
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
)
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
|
||||
@@ -71,7 +71,7 @@ accelerate launch train_instruct_pix2pix_sdxl.py \
|
||||
|
||||
We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.
|
||||
|
||||
[Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.
|
||||
[Here](https://wandb.ai/sayakpaul/instruct-pix2pix-sdxl-new/runs/sw53gxmc), you can find an example training run that includes some validation samples and the training hyperparameters.
|
||||
|
||||
***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=7
|
||||
--max_train_steps=6
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=2
|
||||
--output_dir {tmpdir}
|
||||
@@ -63,7 +63,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=9
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
@@ -74,7 +74,7 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -84,12 +84,12 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
--resolution=64
|
||||
--random_flip
|
||||
--train_batch_size=1
|
||||
--max_train_steps=11
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--output_dir {tmpdir}
|
||||
--seed=0
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
@@ -97,5 +97,5 @@ class InstructPix2Pix(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
{"checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
@@ -1,15 +1,3 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet-XS
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
@@ -24,16 +12,5 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
|
||||
|
||||
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionControlNetXSPipeline
|
||||
[[autodoc]] StableDiffusionControlNetXSPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
@@ -1,15 +1,3 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet-XS with Stable Diffusion XL
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
@@ -24,22 +12,4 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
|
||||
|
||||
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionXLControlNetXSPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetXSPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
@@ -21,15 +21,12 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.modules.normalization import GroupNorm
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
AttentionProcessor,
|
||||
)
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .lora import LoRACompatibleConv
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
@@ -39,7 +36,8 @@ from .unet_2d_blocks import (
|
||||
UpBlock2D,
|
||||
Upsample2D,
|
||||
)
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -817,11 +815,23 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
|
||||
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
|
||||
norm_kwargs["num_channels"] += by # surgery done here
|
||||
# conv1
|
||||
conv1_args = (
|
||||
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
|
||||
)
|
||||
conv1_args = [
|
||||
"in_channels",
|
||||
"out_channels",
|
||||
"kernel_size",
|
||||
"stride",
|
||||
"padding",
|
||||
"dilation",
|
||||
"groups",
|
||||
"bias",
|
||||
"padding_mode",
|
||||
]
|
||||
if not USE_PEFT_BACKEND:
|
||||
conv1_args.append("lora_layer")
|
||||
|
||||
for a in conv1_args:
|
||||
assert hasattr(old_conv1, a)
|
||||
|
||||
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
|
||||
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
|
||||
conv1_kwargs["in_channels"] += by # surgery done here
|
||||
@@ -839,25 +849,42 @@ def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no,
|
||||
}
|
||||
# swap old with new modules
|
||||
unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
|
||||
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = LoRACompatibleConv(**conv1_kwargs)
|
||||
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
||||
unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
|
||||
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
|
||||
)
|
||||
unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
|
||||
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
||||
)
|
||||
unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
|
||||
|
||||
|
||||
def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
|
||||
"""Increase channels sizes to allow for additional concatted information from base model"""
|
||||
old_down = unet.down_blocks[block_no].downsamplers[0].conv
|
||||
# conv1
|
||||
args = "in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(
|
||||
" "
|
||||
)
|
||||
|
||||
args = [
|
||||
"in_channels",
|
||||
"out_channels",
|
||||
"kernel_size",
|
||||
"stride",
|
||||
"padding",
|
||||
"dilation",
|
||||
"groups",
|
||||
"bias",
|
||||
"padding_mode",
|
||||
]
|
||||
if not USE_PEFT_BACKEND:
|
||||
args.append("lora_layer")
|
||||
|
||||
for a in args:
|
||||
assert hasattr(old_down, a)
|
||||
kwargs = {a: getattr(old_down, a) for a in args}
|
||||
kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
|
||||
kwargs["in_channels"] += by # surgery done here
|
||||
# swap old with new modules
|
||||
unet.down_blocks[block_no].downsamplers[0].conv = LoRACompatibleConv(**kwargs)
|
||||
unet.down_blocks[block_no].downsamplers[0].conv = (
|
||||
nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
|
||||
)
|
||||
unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
|
||||
|
||||
|
||||
@@ -871,12 +898,20 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
|
||||
assert hasattr(old_norm1, a)
|
||||
norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
|
||||
norm_kwargs["num_channels"] += by # surgery done here
|
||||
# conv1
|
||||
conv1_args = (
|
||||
"in_channels out_channels kernel_size stride padding dilation groups bias padding_mode lora_layer".split(" ")
|
||||
)
|
||||
for a in conv1_args:
|
||||
assert hasattr(old_conv1, a)
|
||||
conv1_args = [
|
||||
"in_channels",
|
||||
"out_channels",
|
||||
"kernel_size",
|
||||
"stride",
|
||||
"padding",
|
||||
"dilation",
|
||||
"groups",
|
||||
"bias",
|
||||
"padding_mode",
|
||||
]
|
||||
if not USE_PEFT_BACKEND:
|
||||
conv1_args.append("lora_layer")
|
||||
|
||||
conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
|
||||
conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
|
||||
conv1_kwargs["in_channels"] += by # surgery done here
|
||||
@@ -894,8 +929,12 @@ def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
|
||||
}
|
||||
# swap old with new modules
|
||||
unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
|
||||
unet.mid_block.resnets[0].conv1 = LoRACompatibleConv(**conv1_kwargs)
|
||||
unet.mid_block.resnets[0].conv_shortcut = LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
||||
unet.mid_block.resnets[0].conv1 = (
|
||||
nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
|
||||
)
|
||||
unet.mid_block.resnets[0].conv_shortcut = (
|
||||
nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
|
||||
)
|
||||
unet.mid_block.resnets[0].in_channels += by # surgery done here
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
# !pip install opencv-python transformers accelerate
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from PIL import Image
|
||||
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
)
|
||||
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
|
||||
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = args.prompt
|
||||
negative_prompt = args.negative_prompt
|
||||
# download an image
|
||||
image = load_image(args.image_path)
|
||||
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = args.controlnet_conditioning_scale
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# get canny image
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
num_inference_steps = args.num_inference_steps
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
image.save("cnxs_sd.canny.png")
|
||||
@@ -0,0 +1,57 @@
|
||||
# !pip install opencv-python transformers accelerate
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from PIL import Image
|
||||
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
)
|
||||
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
|
||||
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = args.prompt
|
||||
negative_prompt = args.negative_prompt
|
||||
# download an image
|
||||
image = load_image(args.image_path)
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = args.controlnet_conditioning_scale
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# get canny image
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
num_inference_steps = args.num_inference_steps
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
image.save("cnxs_sdxl.canny.png")
|
||||
@@ -19,74 +19,30 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSModel
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5
|
||||
>>> controlnet = ControlNetXSModel.from_pretrained(
|
||||
... "UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
@@ -98,7 +54,9 @@ class StableDiffusionControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -667,7 +625,6 @@ class StableDiffusionControlNetXSPipeline(
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -21,76 +21,36 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSModel, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
|
||||
>>> controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
|
||||
>>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
@@ -102,8 +62,9 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -729,7 +690,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -1,6 +1,6 @@
|
||||
diffusers==0.20.1
|
||||
accelerate==0.23.0
|
||||
transformers==4.34.0
|
||||
transformers==4.36.0
|
||||
peft==0.5.0
|
||||
torch==2.0.1
|
||||
torchvision>=0.16
|
||||
|
||||
@@ -101,8 +101,8 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
|
||||
Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
@@ -114,12 +114,13 @@ image.save("yoda-pokemon.png")
|
||||
```
|
||||
|
||||
Checkpoints only save the unet, so to run inference from a checkpoint, just load the unet
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_path + "/checkpoint-<N>/unet", torch_dtype=torch.float16)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("<initial model>", unet=unet, torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
@@ -64,7 +64,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
@@ -76,7 +76,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 5
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -89,7 +89,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
@@ -100,12 +100,12 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
# Run training script for 2 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image.py
|
||||
@@ -116,13 +116,13 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--seed=0
|
||||
""".split()
|
||||
@@ -131,16 +131,13 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
"checkpoint-4",
|
||||
"checkpoint-6",
|
||||
},
|
||||
{"checkpoint-4", "checkpoint-5"},
|
||||
)
|
||||
|
||||
def test_text_to_image_checkpointing_use_ema(self):
|
||||
@@ -149,7 +146,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 5, checkpointing_steps == 2
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
@@ -161,7 +158,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 5
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -186,12 +183,12 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
# check can run an intermediate checkpoint
|
||||
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
# Run training script for 2 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image.py
|
||||
@@ -202,13 +199,13 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--use_ema
|
||||
--seed=0
|
||||
@@ -218,16 +215,13 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{
|
||||
# no checkpoint-2 -> check old checkpoints do not exist
|
||||
# check new checkpoints exist
|
||||
"checkpoint-4",
|
||||
"checkpoint-6",
|
||||
},
|
||||
{"checkpoint-4", "checkpoint-5"},
|
||||
)
|
||||
|
||||
def test_text_to_image_checkpointing_checkpoints_total_limit(self):
|
||||
@@ -236,7 +230,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -249,7 +243,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -263,14 +257,11 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
|
||||
def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
@@ -278,8 +269,8 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 9, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4, 6, 8
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image.py
|
||||
@@ -290,7 +281,7 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 9
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -303,15 +294,15 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
# resume and we should try to checkpoint at 10, where we'll have to remove
|
||||
# resume and we should try to checkpoint at 6, where we'll have to remove
|
||||
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -323,27 +314,27 @@ class TextToImage(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 11
|
||||
--max_train_steps 8
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
{"checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -52,7 +52,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -66,14 +66,11 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
|
||||
def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
@@ -81,7 +78,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -94,7 +91,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -112,14 +109,11 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||
)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
|
||||
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
|
||||
@@ -127,8 +121,8 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 9, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4, 6, 8
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
examples/text_to_image/train_text_to_image_lora.py
|
||||
@@ -139,7 +133,7 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 9
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -156,15 +150,15 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||
)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
# resume and we should try to checkpoint at 10, where we'll have to remove
|
||||
# resume and we should try to checkpoint at 6, where we'll have to remove
|
||||
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -176,15 +170,15 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
--random_flip
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 11
|
||||
--max_train_steps 8
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-8
|
||||
--checkpoints_total_limit=3
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--seed=0
|
||||
--num_validation_images=0
|
||||
""".split()
|
||||
@@ -195,12 +189,12 @@ class TextToImageLoRA(ExamplesTestsAccelerate):
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
|
||||
)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
|
||||
{"checkpoint-6", "checkpoint-8"},
|
||||
)
|
||||
|
||||
|
||||
@@ -272,7 +266,7 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
|
||||
# Should create checkpoints at steps 2, 4, 6
|
||||
# with checkpoint at step 2 deleted
|
||||
|
||||
@@ -283,7 +277,7 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 7
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -298,11 +292,8 @@ class TextToImageLoRASDXL(ExamplesTestsAccelerate):
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
|
||||
pipe.load_lora_weights(tmpdir)
|
||||
pipe(prompt, num_inference_steps=2)
|
||||
pipe(prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
# checkpoint-2 should have been deleted
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
# checkpoint-2 should have been deleted
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
|
||||
|
||||
@@ -54,39 +54,6 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
@@ -485,7 +452,10 @@ def main():
|
||||
param.requires_grad_(False)
|
||||
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
@@ -493,7 +463,13 @@ def main():
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Add adapter and make sure the trainable params are in float32.
|
||||
unet.add_adapter(unet_lora_config)
|
||||
if args.mixed_precision == "fp16":
|
||||
for param in unet.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
@@ -832,7 +808,8 @@ def main():
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
@@ -870,10 +847,11 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
with torch.cuda.amp.autocast():
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
@@ -897,7 +875,8 @@ def main():
|
||||
if accelerator.is_main_process:
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
@@ -919,40 +898,46 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
if args.validation_prompt is not None:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
|
||||
# load attention processors
|
||||
pipeline.unet.load_attn_procs(args.output_dir)
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device)
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
|
||||
|
||||
if accelerator.is_main_process:
|
||||
for tracker in accelerator.trackers:
|
||||
if len(images) != 0:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"test": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device)
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
with torch.cuda.amp.autocast():
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if len(images) != 0:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"test": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -63,39 +62,6 @@ check_min_version("0.25.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
@@ -469,22 +435,6 @@ DATASET_NAME_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
|
||||
"""
|
||||
Returns:
|
||||
a state dict containing just the attention processor parameters.
|
||||
"""
|
||||
attn_processors = unet.attn_processors
|
||||
|
||||
attn_processors_state_dict = {}
|
||||
|
||||
for attn_processor_key, attn_processor in attn_processors.items():
|
||||
for parameter_key, parameter in attn_processor.state_dict().items():
|
||||
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
|
||||
|
||||
return attn_processors_state_dict
|
||||
|
||||
|
||||
def tokenize_prompt(tokenizer, prompt):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
@@ -659,7 +609,10 @@ def main(args):
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# Set correct lora layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
|
||||
unet.add_adapter(unet_lora_config)
|
||||
@@ -668,11 +621,25 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1220,6 +1187,9 @@ def main(args):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Final inference
|
||||
# Make sure vae.dtype is consistent with the unet.dtype
|
||||
if args.mixed_precision == "fp16":
|
||||
vae.to(weight_dtype)
|
||||
# Load previous pipeline
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
|
||||
@@ -40,8 +40,6 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
@@ -68,8 +66,6 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
@@ -102,14 +98,12 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 3
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
@@ -123,7 +117,7 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-1", "checkpoint-2", "checkpoint-3"},
|
||||
{"checkpoint-1", "checkpoint-2"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -133,21 +127,19 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=1
|
||||
--resume_from_checkpoint=checkpoint-3
|
||||
--resume_from_checkpoint=checkpoint-2
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
@@ -156,5 +148,5 @@ class TextualInversion(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-3", "checkpoint-4"},
|
||||
{"checkpoint-2", "checkpoint-3"},
|
||||
)
|
||||
|
||||
@@ -90,10 +90,10 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
--train_batch_size 1
|
||||
--num_epochs 1
|
||||
--gradient_accumulation_steps 1
|
||||
--ddpm_num_inference_steps 2
|
||||
--ddpm_num_inference_steps 1
|
||||
--learning_rate 1e-3
|
||||
--lr_warmup_steps 5
|
||||
--checkpointing_steps=1
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
@@ -101,7 +101,7 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
|
||||
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
@@ -113,12 +113,12 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
--train_batch_size 1
|
||||
--num_epochs 2
|
||||
--gradient_accumulation_steps 1
|
||||
--ddpm_num_inference_steps 2
|
||||
--ddpm_num_inference_steps 1
|
||||
--learning_rate 1e-3
|
||||
--lr_warmup_steps 5
|
||||
--resume_from_checkpoint=checkpoint-6
|
||||
--checkpointing_steps=2
|
||||
--checkpoints_total_limit=3
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
@@ -126,5 +126,5 @@ class Unconditional(ExamplesTestsAccelerate):
|
||||
# check checkpoint directories exist
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
|
||||
{"checkpoint-10", "checkpoint-12"},
|
||||
)
|
||||
|
||||
@@ -77,7 +77,7 @@ First, you need to set up your development environment as explained in the [inst
|
||||
```bash
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch train_text_to_image_prior_lora.py \
|
||||
accelerate launch train_text_to_image_lora_prior.py \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=$DATASET_NAME --caption_column="text" \
|
||||
--resolution=768 \
|
||||
|
||||
@@ -527,9 +527,17 @@ def main():
|
||||
|
||||
# lora attn processor
|
||||
prior_lora_config = LoraConfig(
|
||||
r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
# Add adapter and make sure the trainable params are in float32.
|
||||
prior.add_adapter(prior_lora_config)
|
||||
if args.mixed_precision == "fp16":
|
||||
for param in prior.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
|
||||
523
scripts/convert_amused.py
Normal file
523
scripts/convert_amused.py
Normal file
@@ -0,0 +1,523 @@
|
||||
import inspect
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from muse import MaskGiTUViT, VQGANModel
|
||||
from muse import PipelineMuse as OldPipelineMuse
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import VQModel
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.uvit_2d import UVit2DModel
|
||||
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
|
||||
from diffusers.schedulers import AmusedScheduler
|
||||
|
||||
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
# Enable CUDNN deterministic mode
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def main():
|
||||
args = ArgumentParser()
|
||||
args.add_argument("--model_256", action="store_true")
|
||||
args.add_argument("--write_to", type=str, required=False, default=None)
|
||||
args.add_argument("--transformer_path", type=str, required=False, default=None)
|
||||
args = args.parse_args()
|
||||
|
||||
transformer_path = args.transformer_path
|
||||
subfolder = "transformer"
|
||||
|
||||
if transformer_path is None:
|
||||
if args.model_256:
|
||||
transformer_path = "openMUSE/muse-256"
|
||||
else:
|
||||
transformer_path = (
|
||||
"../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
|
||||
)
|
||||
subfolder = None
|
||||
|
||||
old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
|
||||
|
||||
old_transformer.to(device)
|
||||
|
||||
old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
|
||||
old_vae.to(device)
|
||||
|
||||
vqvae = make_vqvae(old_vae)
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
|
||||
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
|
||||
text_encoder.to(device)
|
||||
|
||||
transformer = make_transformer(old_transformer, args.model_256)
|
||||
|
||||
scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
|
||||
|
||||
new_pipe = AmusedPipeline(
|
||||
vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
old_pipe = OldPipelineMuse(
|
||||
vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
|
||||
)
|
||||
old_pipe.to(device)
|
||||
|
||||
if args.model_256:
|
||||
transformer_seq_len = 256
|
||||
orig_size = (256, 256)
|
||||
else:
|
||||
transformer_seq_len = 1024
|
||||
orig_size = (512, 512)
|
||||
|
||||
old_out = old_pipe(
|
||||
"dog",
|
||||
generator=torch.Generator(device).manual_seed(0),
|
||||
transformer_seq_len=transformer_seq_len,
|
||||
orig_size=orig_size,
|
||||
timesteps=12,
|
||||
)[0]
|
||||
|
||||
new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
|
||||
|
||||
old_out = np.array(old_out)
|
||||
new_out = np.array(new_out)
|
||||
|
||||
diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
|
||||
|
||||
# assert diff diff.sum() == 0
|
||||
print("skipping pipeline full equivalence check")
|
||||
|
||||
print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
|
||||
|
||||
if args.model_256:
|
||||
assert diff.max() <= 3
|
||||
assert diff.sum() / diff.size < 0.7
|
||||
else:
|
||||
assert diff.max() <= 1
|
||||
assert diff.sum() / diff.size < 0.4
|
||||
|
||||
if args.write_to is not None:
|
||||
new_pipe.save_pretrained(args.write_to)
|
||||
|
||||
|
||||
def make_transformer(old_transformer, model_256):
|
||||
args = dict(old_transformer.config)
|
||||
force_down_up_sample = args["force_down_up_sample"]
|
||||
|
||||
signature = inspect.signature(UVit2DModel.__init__)
|
||||
|
||||
args_ = {
|
||||
"downsample": force_down_up_sample,
|
||||
"upsample": force_down_up_sample,
|
||||
"block_out_channels": args["block_out_channels"][0],
|
||||
"sample_size": 16 if model_256 else 32,
|
||||
}
|
||||
|
||||
for s in list(signature.parameters.keys()):
|
||||
if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
|
||||
continue
|
||||
|
||||
args_[s] = args[s]
|
||||
|
||||
new_transformer = UVit2DModel(**args_)
|
||||
new_transformer.to(device)
|
||||
|
||||
new_transformer.set_attn_processor(AttnProcessor())
|
||||
|
||||
state_dict = old_transformer.state_dict()
|
||||
|
||||
state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
|
||||
state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
|
||||
|
||||
for i in range(22):
|
||||
state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.query.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.key.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.value.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.attention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.query.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.key.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.value.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.crossattention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
|
||||
f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
|
||||
wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
|
||||
proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
|
||||
state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
|
||||
|
||||
state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
|
||||
|
||||
if force_down_up_sample:
|
||||
state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
|
||||
state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
|
||||
|
||||
state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
|
||||
state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
|
||||
|
||||
state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
|
||||
|
||||
for i in range(3):
|
||||
state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
|
||||
)
|
||||
state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
|
||||
)
|
||||
state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
|
||||
)
|
||||
state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
|
||||
)
|
||||
|
||||
state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
|
||||
)
|
||||
|
||||
state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
|
||||
)
|
||||
state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
|
||||
)
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
if key.startswith("up_blocks.0"):
|
||||
key_ = "up_block." + ".".join(key.split(".")[2:])
|
||||
state_dict[key_] = state_dict.pop(key)
|
||||
|
||||
if key.startswith("down_blocks.0"):
|
||||
key_ = "down_block." + ".".join(key.split(".")[2:])
|
||||
state_dict[key_] = state_dict.pop(key)
|
||||
|
||||
new_transformer.load_state_dict(state_dict)
|
||||
|
||||
input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
|
||||
encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
|
||||
cond_embeds = torch.randn((1, 768), device=old_transformer.device)
|
||||
micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
|
||||
|
||||
old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
|
||||
old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
|
||||
|
||||
new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
|
||||
|
||||
# NOTE: these differences are solely due to using the geglu block that has a single linear layer of
|
||||
# double output dimension instead of two different linear layers
|
||||
max_diff = (old_out - new_out).abs().max()
|
||||
total_diff = (old_out - new_out).abs().sum()
|
||||
print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
|
||||
assert max_diff < 0.01
|
||||
assert total_diff < 1500
|
||||
|
||||
return new_transformer
|
||||
|
||||
|
||||
def make_vqvae(old_vae):
|
||||
new_vae = VQModel(
|
||||
act_fn="silu",
|
||||
block_out_channels=[128, 256, 256, 512, 768],
|
||||
down_block_types=[
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
],
|
||||
in_channels=3,
|
||||
latent_channels=64,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
num_vq_embeddings=8192,
|
||||
out_channels=3,
|
||||
sample_size=32,
|
||||
up_block_types=[
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
],
|
||||
mid_block_add_attention=False,
|
||||
lookup_from_codebook=True,
|
||||
)
|
||||
new_vae.to(device)
|
||||
|
||||
# fmt: off
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
old_state_dict = old_vae.state_dict()
|
||||
|
||||
new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
|
||||
new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
|
||||
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
|
||||
convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
|
||||
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
|
||||
new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
|
||||
new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
|
||||
new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
|
||||
new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
|
||||
new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
|
||||
new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
|
||||
new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
|
||||
new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
|
||||
new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
|
||||
new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
|
||||
new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
|
||||
new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
|
||||
new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
|
||||
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
|
||||
convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
|
||||
|
||||
new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
|
||||
new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
|
||||
new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
|
||||
new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
|
||||
|
||||
# fmt: on
|
||||
|
||||
assert len(old_state_dict.keys()) == 0
|
||||
|
||||
new_vae.load_state_dict(new_state_dict)
|
||||
|
||||
input = torch.randn((1, 3, 512, 512), device=device)
|
||||
input = input.clamp(-1, 1)
|
||||
|
||||
old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
|
||||
new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
|
||||
assert (old_encoder_output == new_encoder_output).all()
|
||||
|
||||
old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
|
||||
new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
|
||||
|
||||
# assert (old_decoder_output == new_decoder_output).all()
|
||||
print("kipping vae decoder equivalence check")
|
||||
print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
|
||||
|
||||
old_output = old_vae(input)[0]
|
||||
new_output = new_vae(input)[0]
|
||||
|
||||
# assert (old_output == new_output).all()
|
||||
print("skipping full vae equivalence check")
|
||||
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
|
||||
|
||||
return new_vae
|
||||
|
||||
|
||||
def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
|
||||
# fmt: off
|
||||
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
|
||||
|
||||
if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
|
||||
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
|
||||
|
||||
if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
|
||||
new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
|
||||
|
||||
if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
|
||||
new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
|
||||
|
||||
if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
|
||||
new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
|
||||
|
||||
# fmt: on
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
51
scripts/convert_animatediff_motion_lora_to_diffusers.py
Normal file
51
scripts/convert_animatediff_motion_lora_to_diffusers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
|
||||
# convert to new format
|
||||
output_dict = {}
|
||||
for module_name, params in conv_state_dict.items():
|
||||
if type(params) is not torch.Tensor:
|
||||
continue
|
||||
output_dict.update({f"unet.{module_name}": params})
|
||||
|
||||
save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
|
||||
51
scripts/convert_animatediff_motion_module_to_diffusers.py
Normal file
51
scripts/convert_animatediff_motion_module_to_diffusers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
for k, v in original_state_dict.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--use_motion_mid_block", action="store_true")
|
||||
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
conv_state_dict = convert_motion_module(state_dict)
|
||||
adapter = MotionAdapter(
|
||||
use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
|
||||
)
|
||||
# skip loading position embeddings
|
||||
adapter.load_state_dict(conv_state_dict, strict=False)
|
||||
adapter.save_pretrained(args.output_path)
|
||||
adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
|
||||
@@ -12,9 +12,9 @@ from safetensors.torch import load_file as stl
|
||||
from tqdm import tqdm
|
||||
|
||||
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
|
||||
from diffusers.models.autoencoders.vae import Encoder
|
||||
from diffusers.models.embeddings import TimestepEmbedding
|
||||
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
from diffusers.models.vae import Encoder
|
||||
|
||||
|
||||
args = ArgumentParser()
|
||||
|
||||
@@ -159,6 +159,14 @@ vae_conversion_map_attn = [
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
|
||||
# This is probably not the most ideal solution, but it does work.
|
||||
vae_extra_conversion_map = [
|
||||
("to_q", "q"),
|
||||
("to_k", "k"),
|
||||
("to_v", "v"),
|
||||
("to_out.0", "proj_out"),
|
||||
]
|
||||
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
@@ -178,11 +186,20 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
||||
weights_to_convert = ["q", "k", "v", "proj_out"]
|
||||
keys_to_rename = {}
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
print(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
for weight_name, real_weight_name in vae_extra_conversion_map:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
|
||||
keys_to_rename[k] = k.replace(weight_name, real_weight_name)
|
||||
for k, v in keys_to_rename.items():
|
||||
if k in new_state_dict:
|
||||
print(f"Renaming {k} to {v}")
|
||||
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
|
||||
del new_state_dict[k]
|
||||
return new_state_dict
|
||||
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -204,7 +204,7 @@ class DepsTableUpdateCommand(Command):
|
||||
extras = {}
|
||||
extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
|
||||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft")
|
||||
extras["test"] = deps_list(
|
||||
"compel",
|
||||
"GitPython",
|
||||
|
||||
@@ -80,7 +80,6 @@ else:
|
||||
"AutoencoderTiny",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ControlNetXSModel",
|
||||
"Kandinsky3UNet",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
@@ -95,6 +94,7 @@ else:
|
||||
"UNet3DConditionModel",
|
||||
"UNetMotionModel",
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
]
|
||||
)
|
||||
@@ -131,6 +131,7 @@ else:
|
||||
)
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"AmusedScheduler",
|
||||
"CMStochasticIterativeScheduler",
|
||||
"DDIMInverseScheduler",
|
||||
"DDIMParallelScheduler",
|
||||
@@ -202,6 +203,9 @@ else:
|
||||
[
|
||||
"AltDiffusionImg2ImgPipeline",
|
||||
"AltDiffusionPipeline",
|
||||
"AmusedImg2ImgPipeline",
|
||||
"AmusedInpaintPipeline",
|
||||
"AmusedPipeline",
|
||||
"AnimateDiffPipeline",
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
@@ -251,7 +255,6 @@ else:
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
@@ -275,7 +278,6 @@ else:
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
@@ -457,7 +459,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ControlNetXSModel,
|
||||
Kandinsky3UNet,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
@@ -472,6 +473,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNet3DConditionModel,
|
||||
UNetMotionModel,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
)
|
||||
from .optimization import (
|
||||
@@ -506,6 +508,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ScoreSdeVePipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
AmusedScheduler,
|
||||
CMStochasticIterativeScheduler,
|
||||
DDIMInverseScheduler,
|
||||
DDIMParallelScheduler,
|
||||
@@ -560,6 +563,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AmusedImg2ImgPipeline,
|
||||
AmusedInpaintPipeline,
|
||||
AmusedPipeline,
|
||||
AnimateDiffPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
@@ -607,7 +613,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
@@ -631,7 +636,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
from .configuration_utils import ConfigMixin, register_to_config
|
||||
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
||||
@@ -166,6 +166,244 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
||||
"""
|
||||
Blurs an image.
|
||||
"""
|
||||
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
||||
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
||||
"""
|
||||
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
|
||||
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
|
||||
|
||||
Args:
|
||||
mask_image (PIL.Image.Image): Mask image.
|
||||
width (int): Width of the image to be processed.
|
||||
height (int): Height of the image to be processed.
|
||||
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
|
||||
"""
|
||||
|
||||
mask_image = mask_image.convert("L")
|
||||
mask = np.array(mask_image)
|
||||
|
||||
# 1. find a rectangular region that contains all masked ares in an image
|
||||
h, w = mask.shape
|
||||
crop_left = 0
|
||||
for i in range(w):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_left += 1
|
||||
|
||||
crop_right = 0
|
||||
for i in reversed(range(w)):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_right += 1
|
||||
|
||||
crop_top = 0
|
||||
for i in range(h):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_top += 1
|
||||
|
||||
crop_bottom = 0
|
||||
for i in reversed(range(h)):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_bottom += 1
|
||||
|
||||
# 2. add padding to the crop region
|
||||
x1, y1, x2, y2 = (
|
||||
int(max(crop_left - pad, 0)),
|
||||
int(max(crop_top - pad, 0)),
|
||||
int(min(w - crop_right + pad, w)),
|
||||
int(min(h - crop_bottom + pad, h)),
|
||||
)
|
||||
|
||||
# 3. expands crop region to match the aspect ratio of the image to be processed
|
||||
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
||||
ratio_processing = width / height
|
||||
|
||||
if ratio_crop_region > ratio_processing:
|
||||
desired_height = (x2 - x1) / ratio_processing
|
||||
desired_height_diff = int(desired_height - (y2 - y1))
|
||||
y1 -= desired_height_diff // 2
|
||||
y2 += desired_height_diff - desired_height_diff // 2
|
||||
if y2 >= mask_image.height:
|
||||
diff = y2 - mask_image.height
|
||||
y2 -= diff
|
||||
y1 -= diff
|
||||
if y1 < 0:
|
||||
y2 -= y1
|
||||
y1 -= y1
|
||||
if y2 >= mask_image.height:
|
||||
y2 = mask_image.height
|
||||
else:
|
||||
desired_width = (y2 - y1) * ratio_processing
|
||||
desired_width_diff = int(desired_width - (x2 - x1))
|
||||
x1 -= desired_width_diff // 2
|
||||
x2 += desired_width_diff - desired_width_diff // 2
|
||||
if x2 >= mask_image.width:
|
||||
diff = x2 - mask_image.width
|
||||
x2 -= diff
|
||||
x1 -= diff
|
||||
if x1 < 0:
|
||||
x2 -= x1
|
||||
x1 -= x1
|
||||
if x2 >= mask_image.width:
|
||||
x2 = mask_image.width
|
||||
|
||||
return x1, y1, x2, y2
|
||||
|
||||
def _resize_and_fill(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
||||
|
||||
Args:
|
||||
image: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
"""
|
||||
|
||||
ratio = width / height
|
||||
src_ratio = image.width / image.height
|
||||
|
||||
src_w = width if ratio < src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
if fill_height > 0:
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(
|
||||
resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
||||
box=(0, fill_height + src_h),
|
||||
)
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
if fill_width > 0:
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(
|
||||
resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
||||
box=(fill_width + src_w, 0),
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
def _resize_and_crop(
|
||||
self,
|
||||
image: PIL.Image.Image,
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
|
||||
Args:
|
||||
image: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
"""
|
||||
ratio = width / height
|
||||
src_ratio = image.width / image.height
|
||||
|
||||
src_w = width if ratio > src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
return res
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
resize_mode: str = "default", # "defalt", "fill", "crop"
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Resize image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
||||
The image input, can be a PIL image, numpy array or pytorch tensor.
|
||||
height (`int`):
|
||||
The height to resize to.
|
||||
width (`int`):
|
||||
The width to resize to.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
|
||||
within the specified width and height, and it may not maintaining the original aspect ratio.
|
||||
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, filling empty with data from image.
|
||||
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, cropping the excess.
|
||||
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
||||
The resized image.
|
||||
"""
|
||||
if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
if resize_mode == "default":
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
elif resize_mode == "fill":
|
||||
image = self._resize_and_fill(image, width, height)
|
||||
elif resize_mode == "crop":
|
||||
image = self._resize_and_crop(image, width, height)
|
||||
else:
|
||||
raise ValueError(f"resize_mode {resize_mode} is not supported")
|
||||
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = self.numpy_to_pt(image)
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
image = self.pt_to_numpy(image)
|
||||
return image
|
||||
|
||||
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
Create a mask.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image input, should be a PIL image.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
||||
"""
|
||||
image[image < 0.5] = 0
|
||||
image[image >= 0.5] = 1
|
||||
return image
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
@@ -209,67 +447,34 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
return height, width
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
Resize image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
||||
The image input, can be a PIL image, numpy array or pytorch tensor.
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height to resize to.
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width to resize to.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
||||
The resized image.
|
||||
"""
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
elif isinstance(image, np.ndarray):
|
||||
image = self.numpy_to_pt(image)
|
||||
image = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(height, width),
|
||||
)
|
||||
image = self.pt_to_numpy(image)
|
||||
return image
|
||||
|
||||
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
"""
|
||||
Create a mask.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image input, should be a PIL image.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
||||
"""
|
||||
image[image < 0.5] = 0
|
||||
image[image >= 0.5] = 1
|
||||
return image
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
||||
image: PipelineImageInput,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
resize_mode: str = "default", # "defalt", "fill", "crop"
|
||||
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
||||
Preprocess the image input.
|
||||
|
||||
Args:
|
||||
image (`pipeline_image_input`):
|
||||
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
|
||||
within the specified width and height, and it may not maintaining the original aspect ratio.
|
||||
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, filling empty with data from image.
|
||||
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, cropping the excess.
|
||||
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
@@ -299,13 +504,15 @@ class VaeImageProcessor(ConfigMixin):
|
||||
)
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
if crops_coords is not None:
|
||||
image = [i.crop(crops_coords) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_default_height_width(image[0], height, width)
|
||||
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
||||
if self.config.do_convert_rgb:
|
||||
image = [self.convert_to_rgb(i) for i in image]
|
||||
elif self.config.do_convert_grayscale:
|
||||
image = [self.convert_to_grayscale(i) for i in image]
|
||||
if self.config.do_resize:
|
||||
height, width = self.get_default_height_width(image[0], height, width)
|
||||
image = [self.resize(i, height, width) for i in image]
|
||||
image = self.pil_to_numpy(image) # to np
|
||||
image = self.numpy_to_pt(image) # to pt
|
||||
|
||||
@@ -406,6 +613,39 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if output_type == "pil":
|
||||
return self.numpy_to_pil(image)
|
||||
|
||||
def apply_overlay(
|
||||
self,
|
||||
mask: PIL.Image.Image,
|
||||
init_image: PIL.Image.Image,
|
||||
image: PIL.Image.Image,
|
||||
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
overlay the inpaint output to the original image
|
||||
"""
|
||||
|
||||
width, height = image.width, image.height
|
||||
|
||||
init_image = self.resize(init_image, width=width, height=height)
|
||||
mask = self.resize(mask, width=width, height=height)
|
||||
|
||||
init_image_masked = PIL.Image.new("RGBa", (width, height))
|
||||
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
|
||||
init_image_masked = init_image_masked.convert("RGBA")
|
||||
|
||||
if crop_coords is not None:
|
||||
x, y, w, h = crop_coords
|
||||
base_image = PIL.Image.new("RGBA", (width, height))
|
||||
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image.convert("RGB")
|
||||
|
||||
image = image.convert("RGBA")
|
||||
image.alpha_composite(init_image_masked)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
"""
|
||||
|
||||
@@ -149,9 +149,11 @@ class IPAdapterMixin:
|
||||
self.feature_extractor = CLIPImageProcessor()
|
||||
|
||||
# load ip-adapter into unet
|
||||
self.unet._load_ip_adapter_weights(state_dict)
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dict)
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
for attn_processor in self.unet.attn_processors.values():
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for attn_processor in unet.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = scale
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
@@ -18,6 +19,7 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
@@ -58,6 +60,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
TEXT_ENCODER_NAME = "text_encoder"
|
||||
UNET_NAME = "unet"
|
||||
TRANSFORMER_NAME = "transformer"
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
@@ -73,6 +76,7 @@ class LoraLoaderMixin:
|
||||
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
unet_name = UNET_NAME
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(
|
||||
@@ -229,7 +233,9 @@ class LoraLoaderMixin:
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
@@ -255,7 +261,7 @@ class LoraLoaderMixin:
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin"
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
@@ -294,7 +300,12 @@ class LoraLoaderMixin:
|
||||
return state_dict, network_alphas
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
|
||||
def _best_guess_weight_name(
|
||||
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
@@ -653,6 +664,89 @@ class LoraLoaderMixin:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
|
||||
network_alphas = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
@@ -778,6 +872,7 @@ class LoraLoaderMixin:
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -812,14 +907,19 @@ class LoraLoaderMixin:
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`."
|
||||
)
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(pack_weights(unet_lora_layers, "unet"))
|
||||
state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
@@ -876,6 +976,8 @@ class LoraLoaderMixin:
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
logger.warn(
|
||||
@@ -883,13 +985,13 @@ class LoraLoaderMixin:
|
||||
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
|
||||
)
|
||||
|
||||
for _, module in self.unet.named_modules():
|
||||
for _, module in unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
else:
|
||||
recurse_remove_peft_layers(self.unet)
|
||||
if hasattr(self.unet, "peft_config"):
|
||||
del self.unet.peft_config
|
||||
recurse_remove_peft_layers(unet)
|
||||
if hasattr(unet, "peft_config"):
|
||||
del unet.peft_config
|
||||
|
||||
# Safe to call the following regardless of LoRA.
|
||||
self._remove_text_encoder_monkey_patch()
|
||||
@@ -900,6 +1002,7 @@ class LoraLoaderMixin:
|
||||
fuse_text_encoder: bool = True,
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
@@ -919,6 +1022,21 @@ class LoraLoaderMixin:
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
@@ -928,24 +1046,44 @@ class LoraLoaderMixin:
|
||||
)
|
||||
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
# TODO(Patrick, Younes): enable "safe" fusing
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
module.merge()
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
|
||||
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
|
||||
"backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
@@ -960,9 +1098,9 @@ class LoraLoaderMixin:
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
@@ -981,13 +1119,14 @@ class LoraLoaderMixin:
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if unfuse_unet:
|
||||
if not USE_PEFT_BACKEND:
|
||||
self.unet.unfuse_lora()
|
||||
unet.unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in self.unet.modules():
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
@@ -1103,8 +1242,9 @@ class LoraLoaderMixin:
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[List[float]] = None,
|
||||
):
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
# Handle the UNET
|
||||
self.unet.set_adapters(adapter_names, adapter_weights)
|
||||
unet.set_adapters(adapter_names, adapter_weights)
|
||||
|
||||
# Handle the Text Encoder
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1117,7 +1257,8 @@ class LoraLoaderMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Disable unet adapters
|
||||
self.unet.disable_lora()
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.disable_lora()
|
||||
|
||||
# Disable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1130,7 +1271,8 @@ class LoraLoaderMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Enable unet adapters
|
||||
self.unet.enable_lora()
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.enable_lora()
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1152,7 +1294,8 @@ class LoraLoaderMixin:
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
# Delete unet adapters
|
||||
self.unet.delete_adapters(adapter_names)
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.delete_adapters(adapter_names)
|
||||
|
||||
for adapter_name in adapter_names:
|
||||
# Delete text encoder adapters
|
||||
@@ -1185,8 +1328,8 @@ class LoraLoaderMixin:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for module in self.unet.modules():
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
@@ -1210,8 +1353,9 @@ class LoraLoaderMixin:
|
||||
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
|
||||
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
|
||||
|
||||
if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
|
||||
set_adapters["unet"] = list(self.unet.peft_config.keys())
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
|
||||
set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
|
||||
|
||||
return set_adapters
|
||||
|
||||
@@ -1232,7 +1376,8 @@ class LoraLoaderMixin:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
# Handle the UNET
|
||||
for unet_module in self.unet.modules():
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for unet_module in unet.modules():
|
||||
if isinstance(unet_module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
unet_module.lora_A[adapter_name].to(device)
|
||||
|
||||
@@ -11,9 +11,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
@@ -504,22 +506,43 @@ class UNet2DConditionLoadersMixin:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
self.lora_scale = lora_scale
|
||||
self._safe_fusing = safe_fusing
|
||||
self.apply(self._fuse_lora_apply)
|
||||
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
|
||||
|
||||
def _fuse_lora_apply(self, module):
|
||||
def _fuse_lora_apply(self, module, adapter_names=None):
|
||||
if not USE_PEFT_BACKEND:
|
||||
if hasattr(module, "_fuse_lora"):
|
||||
module._fuse_lora(self.lora_scale, self._safe_fusing)
|
||||
|
||||
if adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch"
|
||||
" to PEFT backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
merge_kwargs = {"safe_merge": self._safe_fusing}
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if self.lora_scale != 1.0:
|
||||
module.scale_layer(self.lora_scale)
|
||||
module.merge(safe_merge=self._safe_fusing)
|
||||
|
||||
# For BC with prevous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
|
||||
" to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
def unfuse_lora(self):
|
||||
self.apply(self._unfuse_lora_apply)
|
||||
@@ -664,6 +687,80 @@ class UNet2DConditionLoadersMixin:
|
||||
if hasattr(self, "peft_config"):
|
||||
self.peft_config.pop(adapter_name, None)
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
elif "proj.3.weight" in state_dict:
|
||||
# IP-Adapter Full
|
||||
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||
|
||||
image_projection = MLPProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
|
||||
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
|
||||
diffusers_name = diffusers_name.replace("proj.3", "norm")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds = state_dict["latents"].shape[1]
|
||||
embed_dims = state_dict["proj_in.weight"].shape[1]
|
||||
output_dims = state_dict["proj_out.weight"].shape[0]
|
||||
hidden_dims = state_dict["latents"].shape[2]
|
||||
heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
image_projection = Resampler(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
heads=heads,
|
||||
num_queries=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("0.to", "2.to")
|
||||
diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
|
||||
diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
|
||||
diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
|
||||
diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
|
||||
|
||||
if "norm1" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
|
||||
elif "norm2" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
|
||||
elif "to_kv" in diffusers_name:
|
||||
v_chunk = value.chunk(2, dim=0)
|
||||
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
|
||||
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
|
||||
elif "to_out" in diffusers_name:
|
||||
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
|
||||
else:
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
image_projection.load_state_dict(updated_state_dict)
|
||||
return image_projection
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict):
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
@@ -724,103 +821,8 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
# create image projection layers.
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter
|
||||
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
image_projection.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# load image projection layer weights
|
||||
image_proj_state_dict = {}
|
||||
image_proj_state_dict.update(
|
||||
{
|
||||
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
|
||||
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
|
||||
"norm.weight": state_dict["image_proj"]["norm.weight"],
|
||||
"norm.bias": state_dict["image_proj"]["norm.bias"],
|
||||
}
|
||||
)
|
||||
image_projection.load_state_dict(image_proj_state_dict)
|
||||
del image_proj_state_dict
|
||||
|
||||
elif "proj.3.weight" in state_dict["image_proj"]:
|
||||
clip_embeddings_dim = state_dict["image_proj"]["proj.0.weight"].shape[0]
|
||||
cross_attention_dim = state_dict["image_proj"]["proj.3.weight"].shape[0]
|
||||
|
||||
image_projection = MLPProjection(
|
||||
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
|
||||
)
|
||||
image_projection.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# load image projection layer weights
|
||||
image_proj_state_dict = {}
|
||||
image_proj_state_dict.update(
|
||||
{
|
||||
"ff.net.0.proj.weight": state_dict["image_proj"]["proj.0.weight"],
|
||||
"ff.net.0.proj.bias": state_dict["image_proj"]["proj.0.bias"],
|
||||
"ff.net.2.weight": state_dict["image_proj"]["proj.2.weight"],
|
||||
"ff.net.2.bias": state_dict["image_proj"]["proj.2.bias"],
|
||||
"norm.weight": state_dict["image_proj"]["proj.3.weight"],
|
||||
"norm.bias": state_dict["image_proj"]["proj.3.bias"],
|
||||
}
|
||||
)
|
||||
image_projection.load_state_dict(image_proj_state_dict)
|
||||
del image_proj_state_dict
|
||||
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
|
||||
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
|
||||
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
|
||||
heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
|
||||
|
||||
image_projection = Resampler(
|
||||
embed_dims=embed_dims,
|
||||
output_dims=output_dims,
|
||||
hidden_dims=hidden_dims,
|
||||
heads=heads,
|
||||
num_queries=num_image_text_embeds,
|
||||
)
|
||||
|
||||
image_proj_state_dict = state_dict["image_proj"]
|
||||
|
||||
new_sd = OrderedDict()
|
||||
for k, v in image_proj_state_dict.items():
|
||||
if "0.to" in k:
|
||||
k = k.replace("0.to", "2.to")
|
||||
elif "1.0.weight" in k:
|
||||
k = k.replace("1.0.weight", "3.0.weight")
|
||||
elif "1.0.bias" in k:
|
||||
k = k.replace("1.0.bias", "3.0.bias")
|
||||
elif "1.1.weight" in k:
|
||||
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
|
||||
elif "1.3.weight" in k:
|
||||
k = k.replace("1.3.weight", "3.1.net.2.weight")
|
||||
|
||||
if "norm1" in k:
|
||||
new_sd[k.replace("0.norm1", "0")] = v
|
||||
elif "norm2" in k:
|
||||
new_sd[k.replace("0.norm2", "1")] = v
|
||||
elif "to_kv" in k:
|
||||
v_chunk = v.chunk(2, dim=0)
|
||||
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
|
||||
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
|
||||
elif "to_out" in k:
|
||||
new_sd[k.replace("to_out", "to_out.0")] = v
|
||||
else:
|
||||
new_sd[k] = v
|
||||
|
||||
image_projection.load_state_dict(new_sd)
|
||||
del image_proj_state_dict
|
||||
# convert IP-Adapter Image Projection layers to diffusers
|
||||
image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
|
||||
|
||||
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
delete_adapter_layers
|
||||
|
||||
@@ -26,13 +26,12 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
@@ -47,6 +46,7 @@ if is_torch_available():
|
||||
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
@@ -58,13 +58,14 @@ if is_flax_available():
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnetxs import ControlNetXSModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
@@ -79,6 +80,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
@@ -22,7 +23,7 @@ from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .lora import LoRACompatibleLinear
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormZero
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
|
||||
|
||||
def _chunked_feed_forward(
|
||||
@@ -148,6 +149,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
||||
ada_norm_bias: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
@@ -156,6 +162,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
||||
self.use_layer_norm = norm_type == "layer_norm"
|
||||
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
@@ -179,6 +186,15 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
self.norm1 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
@@ -190,6 +206,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
@@ -197,11 +214,20 @@ class BasicTransformerBlock(nn.Module):
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = (
|
||||
AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
if self.use_ada_layer_norm
|
||||
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
)
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
self.norm2 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"rms_norm",
|
||||
)
|
||||
else:
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
||||
@@ -210,20 +236,32 @@ class BasicTransformerBlock(nn.Module):
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
if not self.use_ada_layer_norm_single:
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
if self.use_ada_layer_norm_continuous:
|
||||
self.norm3 = AdaLayerNormContinuous(
|
||||
dim,
|
||||
ada_norm_continous_conditioning_embedding_dim,
|
||||
norm_elementwise_affine,
|
||||
norm_eps,
|
||||
ada_norm_bias,
|
||||
"layer_norm",
|
||||
)
|
||||
elif not self.use_ada_layer_norm_single:
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
# 4. Fuser
|
||||
@@ -252,6 +290,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
@@ -265,6 +304,8 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
elif self.use_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif self.use_ada_layer_norm_single:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
@@ -314,6 +355,8 @@ class BasicTransformerBlock(nn.Module):
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
elif self.use_ada_layer_norm_continuous:
|
||||
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
@@ -329,7 +372,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
if not self.use_ada_layer_norm_single:
|
||||
if self.use_ada_layer_norm_continuous:
|
||||
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
elif not self.use_ada_layer_norm_single:
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
@@ -490,6 +535,78 @@ class TemporalBasicTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SkipFFTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
kv_input_dim: int,
|
||||
kv_input_dim_proj_use_bias: bool,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if kv_input_dim != dim:
|
||||
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
|
||||
else:
|
||||
self.kv_mapper = None
|
||||
|
||||
self.norm1 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, 1e-06)
|
||||
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
|
||||
if self.kv_mapper is not None:
|
||||
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
@@ -512,10 +629,12 @@ class FeedForward(nn.Module):
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
|
||||
5
src/diffusers/models/autoencoders/__init__.py
Normal file
5
src/diffusers/models/autoencoders/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .modeling_outputs import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
Attention,
|
||||
@@ -27,8 +27,8 @@ from .attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .modeling_outputs import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -16,14 +16,14 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils import is_torch_version
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from .modeling_outputs import AutoencoderKLOutput
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
|
||||
|
||||
|
||||
@@ -18,20 +18,20 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..schedulers import ConsistencyDecoderScheduler
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...schedulers import ConsistencyDecoderScheduler
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d import UNet2DModel
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_2d import UNet2DModel
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
@@ -162,7 +162,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.use_tiling = use_tiling
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
|
||||
def disable_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
@@ -170,7 +170,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.enable_tiling(False)
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
@@ -178,7 +178,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
@@ -333,14 +333,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
|
||||
return DecoderOutput(sample=x_0)
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
@@ -18,11 +18,11 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import BaseOutput, is_torch_version
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .unet_2d_blocks import (
|
||||
from ...utils import BaseOutput, is_torch_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SpatialNorm
|
||||
from ..unet_2d_blocks import (
|
||||
AutoencoderTinyBlock,
|
||||
UNetMidBlock2D,
|
||||
get_down_block,
|
||||
@@ -77,6 +77,7 @@ class Encoder(nn.Module):
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
double_z: bool = True,
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
@@ -124,6 +125,7 @@ class Encoder(nn.Module):
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# out
|
||||
@@ -213,6 +215,7 @@ class Decoder(nn.Module):
|
||||
norm_num_groups: int = 32,
|
||||
act_fn: str = "silu",
|
||||
norm_type: str = "group", # group, spatial
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
@@ -240,6 +243,7 @@ class Decoder(nn.Module):
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=temb_channels,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
338
src/diffusers/models/downsampling.py
Normal file
338
src/diffusers/models/downsampling.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .normalization import RMSNorm
|
||||
from .upsampling import upfirdn2d_native
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""A 1D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
kernel_size=3,
|
||||
norm_type=None,
|
||||
eps=None,
|
||||
elementwise_affine=None,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type is None:
|
||||
self.norm = None
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type: {norm_type}")
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(
|
||||
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirDownsample2D(nn.Module):
|
||||
"""A 2D FIR downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
|
||||
if self.use_conv:
|
||||
_, _, convH, convW = weight.shape
|
||||
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
||||
stride_value = [factor, factor]
|
||||
upfirdn_input = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
def downsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`)
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
@@ -197,11 +197,12 @@ class TimestepEmbedding(nn.Module):
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim)
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
@@ -214,7 +215,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
|
||||
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
@@ -729,7 +730,7 @@ class PositionNet(nn.Module):
|
||||
return objs
|
||||
|
||||
|
||||
class CombinedTimestepSizeEmbeddings(nn.Module):
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
"""
|
||||
For PixArt-Alpha.
|
||||
|
||||
@@ -746,45 +747,27 @@ class CombinedTimestepSizeEmbeddings(nn.Module):
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
if use_additional_conditions:
|
||||
self.use_additional_conditions = True
|
||||
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
||||
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
||||
|
||||
def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
|
||||
if size.ndim == 1:
|
||||
size = size[:, None]
|
||||
|
||||
if size.shape[0] != batch_size:
|
||||
size = size.repeat(batch_size // size.shape[0], 1)
|
||||
if size.shape[0] != batch_size:
|
||||
raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
|
||||
|
||||
current_batch_size, dims = size.shape[0], size.shape[1]
|
||||
size = size.reshape(-1)
|
||||
size_freq = self.additional_condition_proj(size).to(size.dtype)
|
||||
|
||||
size_emb = embedder(size_freq)
|
||||
size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
|
||||
return size_emb
|
||||
|
||||
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
if self.use_additional_conditions:
|
||||
resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
|
||||
aspect_ratio = self.apply_condition(
|
||||
aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
|
||||
)
|
||||
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
|
||||
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
||||
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
||||
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
|
||||
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
|
||||
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class CaptionProjection(nn.Module):
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
@@ -796,9 +779,8 @@ class CaptionProjection(nn.Module):
|
||||
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
|
||||
self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
|
||||
|
||||
def forward(self, caption, force_drop_ids=None):
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
|
||||
@@ -13,14 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from .activations import get_activation
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
@@ -91,7 +93,7 @@ class AdaLayerNormSingle(nn.Module):
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.emb = CombinedTimestepSizeEmbeddings(
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
|
||||
)
|
||||
|
||||
@@ -146,3 +148,107 @@ class AdaGroupNorm(nn.Module):
|
||||
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AdaLayerNormContinuous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(conditioning_embedding))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
if is_torch_version(">=", "2.1.0"):
|
||||
LayerNorm = nn.LayerNorm
|
||||
else:
|
||||
# Has optional bias parameter compared to torch layer norm
|
||||
# TODO: replace with torch layernorm once min required torch version >= 2.1
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
|
||||
if isinstance(dim, numbers.Integral):
|
||||
dim = (dim,)
|
||||
|
||||
self.dim = torch.Size(dim)
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
|
||||
else:
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
|
||||
if isinstance(dim, numbers.Integral):
|
||||
dim = (dim,)
|
||||
|
||||
self.dim = torch.Size(dim)
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * nx) + self.beta + x
|
||||
|
||||
@@ -23,562 +23,23 @@ import torch.nn.functional as F
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .downsampling import ( # noqa
|
||||
Downsample1D,
|
||||
Downsample2D,
|
||||
FirDownsample2D,
|
||||
KDownsample2D,
|
||||
downsample_2d,
|
||||
)
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .normalization import AdaGroupNorm
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
|
||||
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
outputs = self.conv(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""A 1D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""A 2D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
output_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirUpsample2D(nn.Module):
|
||||
"""A 2D FIR upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`, optional):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _upsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*): Integer upsampling factor (default: 2).
|
||||
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
||||
datatype as `hidden_states`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
|
||||
# Setup filter kernel.
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
|
||||
if self.use_conv:
|
||||
convH = weight.shape[2]
|
||||
convW = weight.shape[3]
|
||||
inC = weight.shape[1]
|
||||
|
||||
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
output_shape = (
|
||||
(hidden_states.shape[2] - 1) * factor + convH,
|
||||
(hidden_states.shape[3] - 1) * factor + convW,
|
||||
)
|
||||
output_padding = (
|
||||
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = hidden_states.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
inverse_conv = F.conv_transpose2d(
|
||||
hidden_states,
|
||||
weight,
|
||||
stride=stride,
|
||||
output_padding=output_padding,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
output = upfirdn2d_native(
|
||||
inverse_conv,
|
||||
torch.tensor(kernel, device=inverse_conv.device),
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
||||
)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return height
|
||||
|
||||
|
||||
class FirDownsample2D(nn.Module):
|
||||
"""A 2D FIR downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
|
||||
if self.use_conv:
|
||||
_, _, convH, convW = weight.shape
|
||||
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
||||
stride_value = [factor, factor]
|
||||
upfirdn_input = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
r"""A 2D K-upsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
from .upsampling import ( # noqa
|
||||
FirUpsample2D,
|
||||
KUpsample2D,
|
||||
Upsample1D,
|
||||
Upsample2D,
|
||||
upfirdn2d_native,
|
||||
upsample_2d,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
@@ -894,151 +355,6 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
return out + self.residual_conv(inputs)
|
||||
|
||||
|
||||
def upsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
||||
a: multiple of the upsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer upsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def downsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`)
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
tensor: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
up: int = 1,
|
||||
down: int = 1,
|
||||
pad: Tuple[int, int] = (0, 0),
|
||||
) -> torch.Tensor:
|
||||
up_x = up_y = up
|
||||
down_x = down_y = down
|
||||
pad_x0 = pad_y0 = pad[0]
|
||||
pad_x1 = pad_y1 = pad[1]
|
||||
|
||||
_, channel, in_h, in_w = tensor.shape
|
||||
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = tensor.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out.to(tensor.device) # Move back to mps if necessary
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
class TemporalConvLayer(nn.Module):
|
||||
"""
|
||||
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import CaptionProjection, PatchEmbed
|
||||
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import AdaLayerNormSingle
|
||||
@@ -235,7 +235,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user