mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
72 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b14ce397e | ||
|
|
23159f4adb | ||
|
|
4c476e99b5 | ||
|
|
9c876a5915 | ||
|
|
6ba0efb9a1 | ||
|
|
46ceba5b35 | ||
|
|
977162c02b | ||
|
|
744663f8dc | ||
|
|
abbf3c1adf | ||
|
|
da2ce1a6b9 | ||
|
|
e51f19aee8 | ||
|
|
1ffcc924bc | ||
|
|
730e01ec93 | ||
|
|
0d196f9f45 | ||
|
|
131312caba | ||
|
|
e9edbfc251 | ||
|
|
0ddc5bf7b9 | ||
|
|
c5933c9c89 | ||
|
|
91a2a80eb2 | ||
|
|
425192fe15 | ||
|
|
9965cb50ea | ||
|
|
20e426cb5d | ||
|
|
90eac14f72 | ||
|
|
11f527ac0f | ||
|
|
2c04e5855c | ||
|
|
391cfcd7d7 | ||
|
|
bc0392a0cb | ||
|
|
05d9baeacd | ||
|
|
e573ae06e2 | ||
|
|
2f6351b001 | ||
|
|
9c856118c7 | ||
|
|
9bce375f77 | ||
|
|
3045fb2763 | ||
|
|
7b0ba4820a | ||
|
|
8d5906a331 | ||
|
|
17470057d2 | ||
|
|
a5b242d30d | ||
|
|
a121e05feb | ||
|
|
3979aac996 | ||
|
|
7e6886f5e9 | ||
|
|
a4c91be73b | ||
|
|
3becd368b1 | ||
|
|
c8fdfe4572 | ||
|
|
bba1c1de15 | ||
|
|
86ecd4b795 | ||
|
|
bdeff4d64a | ||
|
|
fc1883918f | ||
|
|
f0c74e9a75 | ||
|
|
4bc157ffa9 | ||
|
|
f2df39fa0e | ||
|
|
8ecdd3ef65 | ||
|
|
cd8b7507c2 | ||
|
|
3b641eabe9 | ||
|
|
703307efcc | ||
|
|
ed8fd38337 | ||
|
|
ca783a0f1f | ||
|
|
beb848e2b6 | ||
|
|
cfc99adf0f | ||
|
|
807f69b328 | ||
|
|
b811964a7b | ||
|
|
1bd4c9e93d | ||
|
|
eb2ef31606 | ||
|
|
5c9dd0af95 | ||
|
|
d0f258206d | ||
|
|
3eaead0c4a | ||
|
|
3bf5ce21ad | ||
|
|
3a9d7d9758 | ||
|
|
e748b3c6e1 | ||
|
|
46c52f9b96 | ||
|
|
d06e06940b | ||
|
|
0a73b4d3cd | ||
|
|
e126a82cc5 |
33
.github/workflows/pr_tests.yml
vendored
33
.github/workflows/pr_tests.yml
vendored
@@ -21,22 +21,27 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch CPU tests on Ubuntu
|
||||
framework: pytorch
|
||||
- name: Fast PyTorch Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
report: torch_cpu_pipelines
|
||||
- name: Fast PyTorch Models & Schedulers CPU tests
|
||||
framework: pytorch_models
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
- name: Fast ONNXRuntime CPU tests
|
||||
framework: onnxruntime
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
- name: PyTorch Example CPU tests
|
||||
framework: pytorch_examples
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
@@ -71,13 +76,21 @@ jobs:
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
tests/pipelines
|
||||
|
||||
- name: Run fast PyTorch Model Scheduler CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_models' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
@@ -85,7 +98,7 @@ jobs:
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
tests
|
||||
|
||||
- name: Run fast ONNXRuntime CPU tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
- local: using-diffusers/schedulers
|
||||
title: Load and compare different schedulers
|
||||
- local: using-diffusers/custom_pipeline_overview
|
||||
title: Load and add custom pipelines
|
||||
title: Load community pipelines
|
||||
- local: using-diffusers/kerascv
|
||||
title: Load KerasCV Stable Diffusion checkpoints
|
||||
title: Loading & Hub
|
||||
@@ -47,9 +47,9 @@
|
||||
- local: using-diffusers/reproducibility
|
||||
title: Create reproducible pipelines
|
||||
- local: using-diffusers/custom_pipeline_examples
|
||||
title: Community Pipelines
|
||||
title: Community pipelines
|
||||
- local: using-diffusers/contribute_pipeline
|
||||
title: How to contribute a Pipeline
|
||||
title: How to contribute a community pipeline
|
||||
- local: using-diffusers/using_safetensors
|
||||
title: Using safetensors
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
@@ -74,6 +74,8 @@
|
||||
title: ControlNet
|
||||
- local: training/instructpix2pix
|
||||
title: InstructPix2Pix Training
|
||||
- local: training/custom_diffusion
|
||||
title: Custom Diffusion
|
||||
title: Training
|
||||
- sections:
|
||||
- local: using-diffusers/rl
|
||||
@@ -103,6 +105,8 @@
|
||||
title: MPS
|
||||
- local: optimization/habana
|
||||
title: Habana Gaudi
|
||||
- local: optimization/tome
|
||||
title: Token Merging
|
||||
title: Optimization/Special Hardware
|
||||
- sections:
|
||||
- local: conceptual/philosophy
|
||||
@@ -150,6 +154,8 @@
|
||||
title: DDPM
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/if
|
||||
title: IF
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/paint_by_example
|
||||
|
||||
@@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
|
||||
### LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.LoraLoaderMixin
|
||||
|
||||
### FromCkptMixin
|
||||
|
||||
[[autodoc]] loaders.FromCkptMixin
|
||||
|
||||
@@ -25,14 +25,14 @@ This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit
|
||||
|
||||
## Text-to-Audio
|
||||
|
||||
The [`AudioLDMPipeline`] can be used to load pre-trained weights from [cvssp/audioldm](https://huggingface.co/cvssp/audioldm) and generate text-conditional audio outputs:
|
||||
The [`AudioLDMPipeline`] can be used to load pre-trained weights from [cvssp/audioldm-s-full-v2](https://huggingface.co/cvssp/audioldm-s-full-v2) and generate text-conditional audio outputs:
|
||||
|
||||
```python
|
||||
from diffusers import AudioLDMPipeline
|
||||
import torch
|
||||
import scipy
|
||||
|
||||
repo_id = "cvssp/audioldm"
|
||||
repo_id = "cvssp/audioldm-s-full-v2"
|
||||
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -56,7 +56,7 @@ Inference:
|
||||
### How to load and use different schedulers
|
||||
|
||||
The AudioLDM pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers
|
||||
that can be used with the AudioLDM pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
|
||||
that can be used with the AudioLDM pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
|
||||
[`EulerAncestralDiscreteScheduler`] etc. We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest
|
||||
scheduler there is.
|
||||
|
||||
@@ -68,12 +68,14 @@ method, or pass the `scheduler` argument to the `from_pretrained` method of the
|
||||
>>> from diffusers import AudioLDMPipeline, DPMSolverMultistepScheduler
|
||||
>>> import torch
|
||||
|
||||
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
|
||||
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm-s-full-v2", torch_dtype=torch.float16)
|
||||
>>> pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
>>> # or
|
||||
>>> dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained("cvssp/audioldm", subfolder="scheduler")
|
||||
>>> pipeline = AudioLDMPipeline.from_pretrained("cvssp/audioldm", scheduler=dpm_scheduler, torch_dtype=torch.float16)
|
||||
>>> dpm_scheduler = DPMSolverMultistepScheduler.from_pretrained("cvssp/audioldm-s-full-v2", subfolder="scheduler")
|
||||
>>> pipeline = AudioLDMPipeline.from_pretrained(
|
||||
... "cvssp/audioldm-s-full-v2", scheduler=dpm_scheduler, torch_dtype=torch.float16
|
||||
... )
|
||||
```
|
||||
|
||||
## AudioLDMPipeline
|
||||
|
||||
523
docs/source/en/api/pipelines/if.mdx
Normal file
523
docs/source/en/api/pipelines/if.mdx
Normal file
@@ -0,0 +1,523 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# IF
|
||||
|
||||
## Overview
|
||||
|
||||
DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.
|
||||
The model is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules:
|
||||
- Stage 1: a base model that generates 64x64 px image based on text prompt,
|
||||
- Stage 2: a 64x64 px => 256x256 px super-resolution model, and a
|
||||
- Stage 3: a 256x256 px => 1024x1024 px super-resolution model
|
||||
Stage 1 and Stage 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings,
|
||||
which are then fed into a UNet architecture enhanced with cross-attention and attention pooling.
|
||||
Stage 3 is [Stability's x4 Upscaling model](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler).
|
||||
The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset.
|
||||
Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
|
||||
|
||||
## Usage
|
||||
|
||||
Before you can use IF, you need to accept its usage conditions. To do so:
|
||||
1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in
|
||||
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). Accepting the license on the stage I model card will auto accept for the other IF models.
|
||||
3. Make sure to login locally. Install `huggingface_hub`
|
||||
```sh
|
||||
pip install huggingface_hub --upgrade
|
||||
```
|
||||
|
||||
run the login function in a Python shell
|
||||
|
||||
```py
|
||||
from huggingface_hub import login
|
||||
|
||||
login()
|
||||
```
|
||||
|
||||
and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
|
||||
|
||||
Next we install `diffusers` and dependencies:
|
||||
|
||||
```sh
|
||||
pip install diffusers accelerate transformers safetensors
|
||||
```
|
||||
|
||||
The following sections give more in-detail examples of how to use IF. Specifically:
|
||||
|
||||
- [Text-to-Image Generation](#text-to-image-generation)
|
||||
- [Image-to-Image Generation](#text-guided-image-to-image-generation)
|
||||
- [Inpainting](#text-guided-inpainting-generation)
|
||||
- [Reusing model weights](#converting-between-different-pipelines)
|
||||
- [Speed optimization](#optimizing-for-speed)
|
||||
- [Memory optimization](#optimizing-for-memory)
|
||||
|
||||
**Available checkpoints**
|
||||
- *Stage-1*
|
||||
- [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
|
||||
- [DeepFloyd/IF-I-L-v1.0](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)
|
||||
- [DeepFloyd/IF-I-M-v1.0](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)
|
||||
|
||||
- *Stage-2*
|
||||
- [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
|
||||
- [DeepFloyd/IF-II-M-v1.0](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)
|
||||
|
||||
- *Stage-3*
|
||||
- [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)
|
||||
|
||||
**Demo**
|
||||
[](https://huggingface.co/spaces/DeepFloyd/IF)
|
||||
|
||||
**Google Colab**
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
|
||||
|
||||
### Text-to-Image Generation
|
||||
|
||||
By default diffusers makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings)
|
||||
to run the whole IF pipeline with as little as 14 GB of VRAM.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import pt_to_pil
|
||||
import torch
|
||||
|
||||
# stage 1
|
||||
stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
stage_1.enable_model_cpu_offload()
|
||||
|
||||
# stage 2
|
||||
stage_2 = DiffusionPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
stage_2.enable_model_cpu_offload()
|
||||
|
||||
# stage 3
|
||||
safety_modules = {
|
||||
"feature_extractor": stage_1.feature_extractor,
|
||||
"safety_checker": stage_1.safety_checker,
|
||||
"watermarker": stage_1.watermarker,
|
||||
}
|
||||
stage_3 = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
|
||||
)
|
||||
stage_3.enable_model_cpu_offload()
|
||||
|
||||
prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
|
||||
generator = torch.manual_seed(1)
|
||||
|
||||
# text embeds
|
||||
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
|
||||
|
||||
# stage 1
|
||||
image = stage_1(
|
||||
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
|
||||
).images
|
||||
pt_to_pil(image)[0].save("./if_stage_I.png")
|
||||
|
||||
# stage 2
|
||||
image = stage_2(
|
||||
image=image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
generator=generator,
|
||||
output_type="pt",
|
||||
).images
|
||||
pt_to_pil(image)[0].save("./if_stage_II.png")
|
||||
|
||||
# stage 3
|
||||
image = stage_3(prompt=prompt, image=image, noise_level=100, generator=generator).images
|
||||
image[0].save("./if_stage_III.png")
|
||||
```
|
||||
|
||||
### Text Guided Image-to-Image Generation
|
||||
|
||||
The same IF model weights can be used for text-guided image-to-image translation or image variation.
|
||||
In this case just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.
|
||||
|
||||
**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
|
||||
without loading them twice by making use of the [`~DiffusionPipeline.components()`] function as explained [here](#converting-between-different-pipelines).
|
||||
|
||||
```python
|
||||
from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
|
||||
from diffusers.utils import pt_to_pil
|
||||
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
# download image
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
response = requests.get(url)
|
||||
original_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
original_image = original_image.resize((768, 512))
|
||||
|
||||
# stage 1
|
||||
stage_1 = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
stage_1.enable_model_cpu_offload()
|
||||
|
||||
# stage 2
|
||||
stage_2 = IFImg2ImgSuperResolutionPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
stage_2.enable_model_cpu_offload()
|
||||
|
||||
# stage 3
|
||||
safety_modules = {
|
||||
"feature_extractor": stage_1.feature_extractor,
|
||||
"safety_checker": stage_1.safety_checker,
|
||||
"watermarker": stage_1.watermarker,
|
||||
}
|
||||
stage_3 = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
|
||||
)
|
||||
stage_3.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A fantasy landscape in style minecraft"
|
||||
generator = torch.manual_seed(1)
|
||||
|
||||
# text embeds
|
||||
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
|
||||
|
||||
# stage 1
|
||||
image = stage_1(
|
||||
image=original_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
generator=generator,
|
||||
output_type="pt",
|
||||
).images
|
||||
pt_to_pil(image)[0].save("./if_stage_I.png")
|
||||
|
||||
# stage 2
|
||||
image = stage_2(
|
||||
image=image,
|
||||
original_image=original_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
generator=generator,
|
||||
output_type="pt",
|
||||
).images
|
||||
pt_to_pil(image)[0].save("./if_stage_II.png")
|
||||
|
||||
# stage 3
|
||||
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
|
||||
image[0].save("./if_stage_III.png")
|
||||
```
|
||||
|
||||
### Text Guided Inpainting Generation
|
||||
|
||||
The same IF model weights can be used for text-guided image-to-image translation or image variation.
|
||||
In this case just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.
|
||||
|
||||
**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines
|
||||
without loading them twice by making use of the [`~DiffusionPipeline.components()`] function as explained [here](#converting-between-different-pipelines).
|
||||
|
||||
```python
|
||||
from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline
|
||||
from diffusers.utils import pt_to_pil
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import requests
|
||||
from io import BytesIO
|
||||
|
||||
# download image
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png"
|
||||
response = requests.get(url)
|
||||
original_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
original_image = original_image
|
||||
|
||||
# download mask
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png"
|
||||
response = requests.get(url)
|
||||
mask_image = Image.open(BytesIO(response.content))
|
||||
mask_image = mask_image
|
||||
|
||||
# stage 1
|
||||
stage_1 = IFInpaintingPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
stage_1.enable_model_cpu_offload()
|
||||
|
||||
# stage 2
|
||||
stage_2 = IFInpaintingSuperResolutionPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
stage_2.enable_model_cpu_offload()
|
||||
|
||||
# stage 3
|
||||
safety_modules = {
|
||||
"feature_extractor": stage_1.feature_extractor,
|
||||
"safety_checker": stage_1.safety_checker,
|
||||
"watermarker": stage_1.watermarker,
|
||||
}
|
||||
stage_3 = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
|
||||
)
|
||||
stage_3.enable_model_cpu_offload()
|
||||
|
||||
prompt = "blue sunglasses"
|
||||
generator = torch.manual_seed(1)
|
||||
|
||||
# text embeds
|
||||
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
|
||||
|
||||
# stage 1
|
||||
image = stage_1(
|
||||
image=original_image,
|
||||
mask_image=mask_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
generator=generator,
|
||||
output_type="pt",
|
||||
).images
|
||||
pt_to_pil(image)[0].save("./if_stage_I.png")
|
||||
|
||||
# stage 2
|
||||
image = stage_2(
|
||||
image=image,
|
||||
original_image=original_image,
|
||||
mask_image=mask_image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
generator=generator,
|
||||
output_type="pt",
|
||||
).images
|
||||
pt_to_pil(image)[0].save("./if_stage_II.png")
|
||||
|
||||
# stage 3
|
||||
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
|
||||
image[0].save("./if_stage_III.png")
|
||||
```
|
||||
|
||||
### Converting between different pipelines
|
||||
|
||||
In addition to being loaded with `from_pretrained`, Pipelines can also be loaded directly from each other.
|
||||
|
||||
```python
|
||||
from diffusers import IFPipeline, IFSuperResolutionPipeline
|
||||
|
||||
pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0")
|
||||
pipe_2 = IFSuperResolutionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0")
|
||||
|
||||
|
||||
from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline
|
||||
|
||||
pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
|
||||
pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)
|
||||
|
||||
|
||||
from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline
|
||||
|
||||
pipe_1 = IFInpaintingPipeline(**pipe_1.components)
|
||||
pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)
|
||||
```
|
||||
|
||||
### Optimizing for speed
|
||||
|
||||
The simplest optimization to run IF faster is to move all model components to the GPU.
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
You can also run the diffusion process for a shorter number of timesteps.
|
||||
|
||||
This can either be done with the `num_inference_steps` argument
|
||||
|
||||
```py
|
||||
pipe("<prompt>", num_inference_steps=30)
|
||||
```
|
||||
|
||||
Or with the `timesteps` argument
|
||||
|
||||
```py
|
||||
from diffusers.pipelines.deepfloyd_if import fast27_timesteps
|
||||
|
||||
pipe("<prompt>", timesteps=fast27_timesteps)
|
||||
```
|
||||
|
||||
When doing image variation or inpainting, you can also decrease the number of timesteps
|
||||
with the strength argument. The strength argument is the amount of noise to add to
|
||||
the input image which also determines how many steps to run in the denoising process.
|
||||
A smaller number will vary the image less but run faster.
|
||||
|
||||
```py
|
||||
pipe = IFImg2ImgPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = pipe(image=image, prompt="<prompt>", strength=0.3).images
|
||||
```
|
||||
|
||||
You can also use [`torch.compile`](../../optimization/torch2.0). Note that we have not exhaustively tested `torch.compile`
|
||||
with IF and it might not give expected results.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
pipe.text_encoder = torch.compile(pipe.text_encoder)
|
||||
pipe.unet = torch.compile(pipe.unet)
|
||||
```
|
||||
|
||||
### Optimizing for memory
|
||||
|
||||
When optimizing for GPU memory, we can use the standard diffusers cpu offloading APIs.
|
||||
|
||||
Either the model based CPU offloading,
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
or the more aggressive layer based CPU offloading.
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
```
|
||||
|
||||
Additionally, T5 can be loaded in 8bit precision
|
||||
|
||||
```py
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
"DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
|
||||
)
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-I-XL-v1.0",
|
||||
text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
|
||||
unet=None,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
prompt_embeds, negative_embeds = pipe.encode_prompt("<prompt>")
|
||||
```
|
||||
|
||||
For CPU RAM constrained machines like google colab free tier where we can't load all
|
||||
model components to the CPU at once, we can manually only load the pipeline with
|
||||
the text encoder or unet when the respective model components are needed.
|
||||
|
||||
```py
|
||||
from diffusers import IFPipeline, IFSuperResolutionPipeline
|
||||
import torch
|
||||
import gc
|
||||
from transformers import T5EncoderModel
|
||||
from diffusers.utils import pt_to_pil
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
"DeepFloyd/IF-I-XL-v1.0", subfolder="text_encoder", device_map="auto", load_in_8bit=True, variant="8bit"
|
||||
)
|
||||
|
||||
# text to image
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-I-XL-v1.0",
|
||||
text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
|
||||
unet=None,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
|
||||
prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
|
||||
|
||||
# Remove the pipeline so we can re-load the pipeline with the unet
|
||||
del text_encoder
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
pipe = IFPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-I-XL-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
image = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
output_type="pt",
|
||||
generator=generator,
|
||||
).images
|
||||
|
||||
pt_to_pil(image)[0].save("./if_stage_I.png")
|
||||
|
||||
# Remove the pipeline so we can load the super-resolution pipeline
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# First super resolution
|
||||
|
||||
pipe = IFSuperResolutionPipeline.from_pretrained(
|
||||
"DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
image = pipe(
|
||||
image=image,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_embeds,
|
||||
output_type="pt",
|
||||
generator=generator,
|
||||
).images
|
||||
|
||||
pt_to_pil(image)[0].save("./if_stage_II.png")
|
||||
```
|
||||
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Colab
|
||||
|---|---|:---:|
|
||||
| [pipeline_if.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - |
|
||||
| [pipeline_if_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - |
|
||||
| [pipeline_if_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py) | *Image-to-Image Generation* | - |
|
||||
| [pipeline_if_img2img_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py) | *Image-to-Image Generation* | - |
|
||||
| [pipeline_if_inpainting.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py) | *Image-to-Image Generation* | - |
|
||||
| [pipeline_if_inpainting_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py) | *Image-to-Image Generation* | - |
|
||||
|
||||
## IFPipeline
|
||||
[[autodoc]] IFPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## IFSuperResolutionPipeline
|
||||
[[autodoc]] IFSuperResolutionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## IFImg2ImgPipeline
|
||||
[[autodoc]] IFImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## IFImg2ImgSuperResolutionPipeline
|
||||
[[autodoc]] IFImg2ImgSuperResolutionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## IFInpaintingPipeline
|
||||
[[autodoc]] IFInpaintingPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## IFInpaintingSuperResolutionPipeline
|
||||
[[autodoc]] IFInpaintingSuperResolutionPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -51,6 +51,9 @@ available a colab notebook to directly try them out.
|
||||
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [if](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
|
||||
| [if_img2img](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
|
||||
| [if_inpainting](./if) | [**IF**](https://github.com/deep-floyd/IF) | Image-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
|
||||
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
|
||||
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
|
||||
@@ -242,6 +242,41 @@ image.save("./multi_controlnet_output.png")
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/multi_controlnet_output.png" width=600/>
|
||||
|
||||
### Guess Mode
|
||||
|
||||
Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:
|
||||
|
||||
>In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.
|
||||
|
||||
#### The core implementation:
|
||||
|
||||
It adjusts the scale of the output residuals from ControlNet by a fixed ratio depending on the block depth. The shallowest DownBlock corresponds to `0.1`. As the blocks get deeper, the scale increases exponentially, and the scale for the output of the MidBlock becomes `1.0`.
|
||||
|
||||
Since the core implementation is just this, **it does not have any impact on prompt conditioning**. While it is common to use it without specifying any prompts, it is also possible to provide prompts if desired.
|
||||
|
||||
#### Usage:
|
||||
|
||||
Just specify `guess_mode=True` in the pipe() function. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
import torch
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to(
|
||||
"cuda"
|
||||
)
|
||||
image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
|
||||
image.save("guess_mode_generated.png")
|
||||
```
|
||||
|
||||
#### Output image comparison:
|
||||
Canny Control Example
|
||||
|
||||
|no guess_mode with prompt|guess_mode without prompt|
|
||||
|---|---|
|
||||
|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"/></a>|
|
||||
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
ControlNet requires a *control image* in addition to the text-to-image *prompt*.
|
||||
@@ -249,7 +284,9 @@ Each pretrained model is trained using a different conditioning method that requ
|
||||
|
||||
All checkpoints can be found under the authors' namespace [lllyasviel](https://huggingface.co/lllyasviel).
|
||||
|
||||
### ControlNet with Stable Diffusion 1.5
|
||||
**13.04.2024 Update**: The author has released improved controlnet checkpoints v1.1 - see [here](#controlnet-v1.1).
|
||||
|
||||
### ControlNet v1.0
|
||||
|
||||
| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|
||||
|---|---|---|---|
|
||||
@@ -262,6 +299,24 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
|
||||
|[lllyasviel/sd-controlnet-scribble](https://huggingface.co/lllyasviel/sd-controlnet_scribble)<br/> *Trained with human scribbles* |A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_vermeer_scribble.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_scribble.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"/></a> |
|
||||
|[lllyasviel/sd-controlnet-seg](https://huggingface.co/lllyasviel/sd-controlnet_seg)<br/>*Trained with semantic segmentation* |An [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/)'s segmentation protocol image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_room_seg.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_seg.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"/></a> |
|
||||
|
||||
### ControlNet v1.1
|
||||
|
||||
| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|
||||
|---|---|---|---|
|
||||
|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)<br/> *Trained with pixel to pixel instruction* | No condition .|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)<br/> Trained with image inpainting | No condition.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/output.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/main/images/output.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)<br/> Trained with multi-level line segment detection | An image with annotated line segments.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)<br/> Trained with depth estimation | An image with depth information, usually represented as a grayscale image.|<a href="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)<br/> Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)<br/> Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_seg/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)<br/> Trained with line art generation | An image with line art, usually black lines on a white background.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)<br/> Trained with anime line art generation | An image with anime-style line art.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)<br/> Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)<br/> Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)<br/> Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/main/images/image_out.png"/></a>|
|
||||
|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)<br/> Trained with image shuffling | An image with shuffled patches or regions.|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png"/></a>|<a href="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/image_out.png"><img width="64" src="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/image_out.png"/></a>|
|
||||
|
||||
## StableDiffusionControlNetPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPipeline
|
||||
- all
|
||||
@@ -272,6 +327,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
|
||||
## FlaxStableDiffusionControlNetPipeline
|
||||
[[autodoc]] FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
@@ -30,4 +30,7 @@ Available Checkpoints are:
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
@@ -30,7 +30,11 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- from_ckpt
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
|
||||
@@ -31,7 +31,10 @@ Available checkpoints are:
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
[[autodoc]] FlaxStableDiffusionInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
|
||||
@@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png")
|
||||
[[autodoc]] StableDiffusionInstructPix2PixPipeline
|
||||
- __call__
|
||||
- all
|
||||
- load_textual_inversion
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
@@ -39,6 +39,10 @@ Available Checkpoints are:
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- enable_vae_tiling
|
||||
- disable_vae_tiling
|
||||
- load_textual_inversion
|
||||
- from_ckpt
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
[[autodoc]] FlaxStableDiffusionPipeline
|
||||
- all
|
||||
|
||||
@@ -58,6 +58,9 @@ The library has three main components:
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [if](./if) | [**IF**](./api/pipelines/if) | Image Generation |
|
||||
| [if_img2img](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |
|
||||
| [if_inpainting](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
|
||||
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
|
||||
@@ -16,8 +16,8 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Optimum Habana 1.4 or later, [here](https://huggingface.co/docs/optimum/habana/installation) is how to install it.
|
||||
- SynapseAI 1.8.
|
||||
- Optimum Habana 1.5 or later, [here](https://huggingface.co/docs/optimum/habana/installation) is how to install it.
|
||||
- SynapseAI 1.9.
|
||||
|
||||
|
||||
## Inference Pipeline
|
||||
@@ -64,7 +64,16 @@ For more information, check out Optimum Habana's [documentation](https://hugging
|
||||
|
||||
Here are the latencies for Habana first-generation Gaudi and Gaudi2 with the [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) Gaudi configuration (mixed precision bf16/fp32):
|
||||
|
||||
- [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) (512x512 resolution):
|
||||
|
||||
| | Latency (batch size = 1) | Throughput (batch size = 8) |
|
||||
| ---------------------- |:------------------------:|:---------------------------:|
|
||||
| first-generation Gaudi | 4.29s | 0.283 images/s |
|
||||
| Gaudi2 | 1.54s | 0.904 images/s |
|
||||
| first-generation Gaudi | 4.22s | 0.29 images/s |
|
||||
| Gaudi2 | 1.70s | 0.925 images/s |
|
||||
|
||||
- [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) (768x768 resolution):
|
||||
|
||||
| | Latency (batch size = 1) | Throughput |
|
||||
| ---------------------- |:------------------------:|:-------------------------------:|
|
||||
| first-generation Gaudi | 23.3s | 0.045 images/s (batch size = 2) |
|
||||
| Gaudi2 | 7.75s | 0.14 images/s (batch size = 5) |
|
||||
|
||||
116
docs/source/en/optimization/tome.mdx
Normal file
116
docs/source/en/optimization/tome.mdx
Normal file
@@ -0,0 +1,116 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Token Merging
|
||||
|
||||
Token Merging (introduced in [Token Merging: Your ViT But Faster](https://arxiv.org/abs/2210.09461)) works by merging the redundant tokens / patches progressively in the forward pass of a Transformer-based network. It can speed up the inference latency of the underlying network.
|
||||
|
||||
After Token Merging (ToMe) was released, the authors released [Token Merging for Fast Stable Diffusion](https://arxiv.org/abs/2303.17604), which introduced a version of ToMe which is more compatible with Stable Diffusion. We can use ToMe to gracefully speed up the inference latency of a [`DiffusionPipeline`]. This doc discusses how to apply ToMe to the [`StableDiffusionPipeline`], the expected speedups, and the qualitative aspects of using ToMe on the [`StableDiffusionPipeline`].
|
||||
|
||||
## Using ToMe
|
||||
|
||||
The authors of ToMe released a convenient Python library called [`tomesd`](https://github.com/dbolya/tomesd) that lets us apply ToMe to a [`DiffusionPipeline`] like so:
|
||||
|
||||
```diff
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import tomesd
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
+ tomesd.apply_patch(pipeline, ratio=0.5)
|
||||
|
||||
image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
|
||||
```
|
||||
|
||||
And that’s it!
|
||||
|
||||
`tomesd.apply_patch()` exposes [a number of arguments](https://github.com/dbolya/tomesd#usage) to let us strike a balance between the pipeline inference speed and the quality of the generated tokens. Amongst those arguments, the most important one is `ratio`. `ratio` controls the number of tokens that will be merged during the forward pass. For more details on `tomesd`, please refer to the original repository https://github.com/dbolya/tomesd and [the paper](https://arxiv.org/abs/2303.17604).
|
||||
|
||||
## Benchmarking `tomesd` with `StableDiffusionPipeline`
|
||||
|
||||
We benchmarked the impact of using `tomesd` on [`StableDiffusionPipeline`] along with [xformers](https://huggingface.co/docs/diffusers/optimization/xformers) across different image resolutions. We used A100 and V100 as our test GPU devices with the following development environment (with Python 3.8.5):
|
||||
|
||||
```bash
|
||||
- `diffusers` version: 0.15.1
|
||||
- Python version: 3.8.16
|
||||
- PyTorch version (GPU?): 1.13.1+cu116 (True)
|
||||
- Huggingface_hub version: 0.13.2
|
||||
- Transformers version: 4.27.2
|
||||
- Accelerate version: 0.18.0
|
||||
- xFormers version: 0.0.16
|
||||
- tomesd version: 0.1.2
|
||||
```
|
||||
|
||||
We used this script for benchmarking: [https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335). Following are our findings:
|
||||
|
||||
### A100
|
||||
|
||||
| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
| 512 | 10 | 6.88 | 5.26 | 4.69 | 23.54651163 | 31.83139535 |
|
||||
| | | | | | | |
|
||||
| 768 | 10 | OOM | 14.71 | 11 | | |
|
||||
| | 8 | OOM | 11.56 | 8.84 | | |
|
||||
| | 4 | OOM | 5.98 | 4.66 | | |
|
||||
| | 2 | 4.99 | 3.24 | 3.1 | 35.07014028 | 37.8757515 |
|
||||
| | 1 | 3.29 | 2.24 | 2.03 | 31.91489362 | 38.29787234 |
|
||||
| | | | | | | |
|
||||
| 1024 | 10 | OOM | OOM | OOM | | |
|
||||
| | 8 | OOM | OOM | OOM | | |
|
||||
| | 4 | OOM | 12.51 | 9.09 | | |
|
||||
| | 2 | OOM | 6.52 | 4.96 | | |
|
||||
| | 1 | 6.4 | 3.61 | 2.81 | 43.59375 | 56.09375 |
|
||||
|
||||
***The timings reported here are in seconds. Speedups are calculated over the `Vanilla` timings.***
|
||||
|
||||
### V100
|
||||
|
||||
| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
| 512 | 10 | OOM | 10.03 | 9.29 | | |
|
||||
| | 8 | OOM | 8.05 | 7.47 | | |
|
||||
| | 4 | 5.7 | 4.3 | 3.98 | 24.56140351 | 30.1754386 |
|
||||
| | 2 | 3.14 | 2.43 | 2.27 | 22.61146497 | 27.70700637 |
|
||||
| | 1 | 1.88 | 1.57 | 1.57 | 16.4893617 | 16.4893617 |
|
||||
| | | | | | | |
|
||||
| 768 | 10 | OOM | OOM | 23.67 | | |
|
||||
| | 8 | OOM | OOM | 18.81 | | |
|
||||
| | 4 | OOM | 11.81 | 9.7 | | |
|
||||
| | 2 | OOM | 6.27 | 5.2 | | |
|
||||
| | 1 | 5.43 | 3.38 | 2.82 | 37.75322284 | 48.06629834 |
|
||||
| | | | | | | |
|
||||
| 1024 | 10 | OOM | OOM | OOM | | |
|
||||
| | 8 | OOM | OOM | OOM | | |
|
||||
| | 4 | OOM | OOM | 19.35 | | |
|
||||
| | 2 | OOM | 13 | 10.78 | | |
|
||||
| | 1 | OOM | 6.66 | 5.54 | | |
|
||||
|
||||
As seen in the tables above, the speedup with `tomesd` becomes more pronounced for larger image resolutions. It is also interesting to note that with `tomesd`, it becomes possible to run the pipeline on a higher resolution, like 1024x1024.
|
||||
|
||||
It might be possible to speed up inference even further with [`torch.compile()`](https://huggingface.co/docs/diffusers/optimization/torch2.0).
|
||||
|
||||
## Quality
|
||||
|
||||
As reported in [the paper](https://arxiv.org/abs/2303.17604), ToMe can preserve the quality of the generated images to a great extent while speeding up inference. By increasing the `ratio`, it is possible to further speed up inference, but that might come at the cost of a deterioration in the image quality.
|
||||
|
||||
To test the quality of the generated samples using our setup, we sampled a few prompts from the “Parti Prompts” (introduced in [Parti](https://parti.research.google/)) and performed inference with the [`StableDiffusionPipeline`] in the following settings:
|
||||
|
||||
- Vanilla [`StableDiffusionPipeline`]
|
||||
- [`StableDiffusionPipeline`] + ToMe
|
||||
- [`StableDiffusionPipeline`] + ToMe + xformers
|
||||
|
||||
We didn’t notice any significant decrease in the quality of the generated samples. Here are samples:
|
||||
|
||||

|
||||
|
||||
You can check out the generated samples [here](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=). We used [this script](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd) for conducting this experiment.
|
||||
@@ -74,6 +74,7 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
@@ -113,6 +114,29 @@ accelerate launch train_controlnet.py \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=4 \
|
||||
--mixed_precision="fp16" \
|
||||
--tracker_project_name="controlnet-demo" \
|
||||
--report_to=wandb
|
||||
```
|
||||
|
||||
## Example results
|
||||
|
||||
#### After 300 steps with batch size 8
|
||||
|
||||
291
docs/source/en/training/custom_diffusion.mdx
Normal file
291
docs/source/en/training/custom_diffusion.mdx
Normal file
@@ -0,0 +1,291 @@
|
||||
<!--Copyright 2023 Custom Diffusion authors The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Custom Diffusion training example
|
||||
|
||||
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
|
||||
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
This training example was contributed by [Nupur Kumari](https://nupurkmr9.github.io/) (one of the authors of Custom Diffusion).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install clip-retrieval
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
```
|
||||
### Cat example 😺
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
|
||||
```
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="./data/cat"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>"
|
||||
```
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
|
||||
* `num_validation_images`
|
||||
* `validation_steps`
|
||||
|
||||
Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>" \
|
||||
--validation_prompt="<new1> cat sitting in a bucket" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
|
||||
|
||||
If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
|
||||
|
||||
### Training on multiple concepts 🐱🪵
|
||||
|
||||
Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
|
||||
|
||||
To collect the real images run this command for each concept in the json file.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
|
||||
```
|
||||
|
||||
And then we're ready to start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--concepts_list=./concept_list.json \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--num_class_images=200 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>+<new2>"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
|
||||
|
||||
### Training on human faces
|
||||
|
||||
For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
|
||||
|
||||
To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
|
||||
```
|
||||
|
||||
Then start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="path-to-images"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_person/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="person" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> person" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--scale_lr --hflip --noaug \
|
||||
--freeze_model crossattn \
|
||||
--modifier_token "<new1>" \
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \<new1\> in above example) in your prompt.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
It's possible to directly load these parameters from a Hub repository:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
Here is an example of performing inference with multiple concepts:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
|
||||
|
||||
image = pipe(
|
||||
"the <new1> cat sculpture in the style of a <new2> wooden pot",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("multi-subject.png")
|
||||
```
|
||||
|
||||
Here, `cat` and `wooden pot` refer to the multiple concepts.
|
||||
|
||||
### Inference from a training checkpoint
|
||||
|
||||
You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument.
|
||||
|
||||
TODO.
|
||||
|
||||
## Set grads to none
|
||||
|
||||
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
|
||||
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
@@ -50,6 +50,20 @@ from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
Finally, download a [few images of a dog](https://huggingface.co/datasets/diffusers/dog-example) to DreamBooth with:
|
||||
|
||||
```py
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir,
|
||||
repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
## Finetuning
|
||||
|
||||
<Tip warning={true}>
|
||||
@@ -60,11 +74,13 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
|
||||
Set the `INSTANCE_DIR` environment variable to the path of the directory containing the dog images.
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
```
|
||||
|
||||
@@ -94,11 +110,13 @@ Before running the script, make sure you have the requirements installed:
|
||||
pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument.
|
||||
|
||||
Now you can launch the training script with the following command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth_flax.py \
|
||||
@@ -124,7 +142,7 @@ The authors recommend generating `num_epochs * num_samples` images for prior pre
|
||||
<pt>
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path_to_class_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
|
||||
@@ -149,7 +167,7 @@ accelerate launch train_dreambooth.py \
|
||||
<jax>
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -186,7 +204,7 @@ Pass the `--train_text_encoder` argument to the training script to enable finetu
|
||||
<pt>
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path_to_class_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
|
||||
@@ -213,7 +231,7 @@ accelerate launch train_dreambooth.py \
|
||||
<jax>
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -349,7 +367,7 @@ Then pass the `--use_8bit_adam` option to the training script:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path_to_class_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
|
||||
@@ -378,7 +396,7 @@ To run DreamBooth on a 12GB GPU, you'll need to enable gradient checkpointing, t
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -425,7 +443,7 @@ Launch training with the following command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export INSTANCE_DIR="./dog"
|
||||
export CLASS_DIR="path_to_class_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
|
||||
|
||||
@@ -74,8 +74,7 @@ write_basic_config()
|
||||
As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset
|
||||
is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.
|
||||
|
||||
Configure environment variables such as the dataset identifier and the Stable Diffusion
|
||||
checkpoint:
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. You'll also need to specify the dataset name in `DATASET_ID`:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
@@ -126,6 +125,27 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
|
||||
|
||||
***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
|
||||
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
|
||||
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
|
||||
--use_ema \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--resolution=512 --random_flip \
|
||||
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
|
||||
--learning_rate=5e-05 --lr_warmup_steps=0 \
|
||||
--conditioning_dropout_prob=0.05 \
|
||||
--mixed_precision=fp16 \
|
||||
--seed=42
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once training is complete, we can perform inference:
|
||||
|
||||
@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
|
||||
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also
|
||||
support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how we support
|
||||
LoRA fine-tuning of the text encoder, refer to the discussion on [this PR](https://github.com/huggingface/diffusers/pull/2918).
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -50,7 +52,9 @@ Finetuning a model like Stable Diffusion, which has billions of parameters, can
|
||||
|
||||
Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon.
|
||||
|
||||
To start, make sure you have the `MODEL_NAME` and `DATASET_NAME` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables are optional and specify where to save the model to on the Hub:
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. You'll also need to set the `DATASET_NAME` environment variable to the name of the dataset you want to train on.
|
||||
|
||||
The `OUTPUT_DIR` and `HUB_MODEL_ID` variables are optional and specify where to save the model to on the Hub:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
@@ -138,7 +142,9 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
|
||||
|
||||
Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) with DreamBooth and LoRA with some 🐶 [dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ). Download and save these images to a directory.
|
||||
|
||||
To start, make sure you have the `MODEL_NAME` and `INSTANCE_DIR` (path to directory containing images) environment variables set. The `OUTPUT_DIR` variables is optional and specifies where to save the model to on the Hub:
|
||||
To start, specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. You'll also need to set `INSTANCE_DIR` to the path of the directory containing the images.
|
||||
|
||||
The `OUTPUT_DIR` variables is optional and specifies where to save the model to on the Hub:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
@@ -175,6 +181,11 @@ accelerate launch train_dreambooth_lora.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
It's also possible to additionally fine-tune the text encoder with LoRA. This, in most cases, leads
|
||||
to better results with a slight increase in the compute. To allow fine-tuning the text encoder with LoRA,
|
||||
specify the `--train_text_encoder` while launching the `train_dreambooth_lora.py` script.
|
||||
|
||||
|
||||
### Inference[[dreambooth-inference]]
|
||||
|
||||
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
|
||||
|
||||
@@ -39,6 +39,8 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
- [Dreambooth](./dreambooth)
|
||||
- [LoRA Support](./lora)
|
||||
- [ControlNet](./controlnet)
|
||||
- [InstructPix2Pix](./instructpix2pix)
|
||||
- [Custom Diffusion](./custom_diffusion)
|
||||
|
||||
If possible, please [install xFormers](../optimization/xformers) for memory efficient attention. This could help make your training faster and less memory intensive.
|
||||
|
||||
@@ -50,6 +52,8 @@ If possible, please [install xFormers](../optimization/xformers) for memory effi
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
| [**Training with LoRA**](./lora) | ✅ | - | - |
|
||||
| [**ControlNet**](./controlnet) | ✅ | ✅ | - |
|
||||
| [**InstructPix2Pix**](./instructpix2pix) | ✅ | ✅ | - |
|
||||
| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - |
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
@@ -72,7 +72,9 @@ To load a checkpoint to resume training, pass the argument `--resume_from_checkp
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset like this:
|
||||
Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset like this.
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument.
|
||||
|
||||
<literalinclude>
|
||||
{"path": "../../../../examples/text_to_image/README.md",
|
||||
@@ -106,6 +108,31 @@ accelerate launch train_text_to_image.py \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir=${OUTPUT_DIR}
|
||||
```
|
||||
|
||||
#### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$dataset_name \
|
||||
--use_ema \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
</pt>
|
||||
<jax>
|
||||
With Flax, it's possible to train a Stable Diffusion model faster on TPUs and GPUs thanks to [@duongna211](https://github.com/duongna21). This is very efficient on TPU hardware but works great on GPUs too. The Flax training script doesn't support features like gradient checkpointing or gradient accumulation yet, so you'll need a GPU with at least 30GB of memory or a TPU v3.
|
||||
@@ -116,6 +143,8 @@ Before running the script, make sure you have the requirements installed:
|
||||
pip install -U -r requirements_flax.txt
|
||||
```
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument.
|
||||
|
||||
Now you can launch the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -81,9 +81,20 @@ To resume training from a saved checkpoint, pass the following argument to the t
|
||||
|
||||
## Finetuning
|
||||
|
||||
For your training dataset, download these [images of a cat statue](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and store them in a directory.
|
||||
For your training dataset, download these [images of a cat toy](https://huggingface.co/datasets/diffusers/cat_toy_example) and store them in a directory:
|
||||
|
||||
Set the `MODEL_NAME` environment variable to the model repository id, and the `DATA_DIR` environment variable to the path of the directory containing the images. Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py):
|
||||
```py
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./cat"
|
||||
snapshot_download(
|
||||
"diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
|
||||
)
|
||||
```
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument, and the `DATA_DIR` environment variable to the path of the directory containing the images.
|
||||
|
||||
Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py):
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -95,7 +106,7 @@ Set the `MODEL_NAME` environment variable to the model repository id, and the `D
|
||||
<pt>
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
export DATA_DIR="./cat"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
@@ -111,6 +122,18 @@ accelerate launch textual_inversion.py \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 If you want to increase the trainable capacity, you can associate your placeholder token, *e.g.* `<cat-toy>` to
|
||||
multiple embedding vectors. This can help the model to better capture the style of more (complex) images.
|
||||
To enable training multiple embedding vectors, simply pass:
|
||||
|
||||
```bash
|
||||
--num_vectors=5
|
||||
```
|
||||
|
||||
</Tip>
|
||||
</pt>
|
||||
<jax>
|
||||
If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️
|
||||
@@ -121,11 +144,13 @@ Before you begin, make sure you install the Flax specific dependencies:
|
||||
pip install -U -r requirements_flax.txt
|
||||
```
|
||||
|
||||
Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument.
|
||||
|
||||
Then you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py):
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
export DATA_DIR="./cat"
|
||||
|
||||
python textual_inversion_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
|
||||
@@ -122,6 +122,26 @@ accelerate launch train_unconditional.py \
|
||||
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png"/>
|
||||
</div>
|
||||
|
||||
### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
|
||||
--dataset_name="huggan/pokemon" \
|
||||
--resolution=64 --center_crop --random_flip \
|
||||
--output_dir="ddpm-ema-pokemon-64" \
|
||||
--train_batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--use_ema \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_warmup_steps=500 \
|
||||
--mixed_precision="fp16" \
|
||||
--logger="wandb"
|
||||
```
|
||||
|
||||
## Finetuning with your own data
|
||||
|
||||
There are two ways to finetune a model on your own dataset:
|
||||
|
||||
@@ -10,17 +10,21 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# How to build a community pipeline
|
||||
# How to contribute a community pipeline
|
||||
|
||||
*Note*: this page was built from the GitHub Issue on Community Pipelines [#841](https://github.com/huggingface/diffusers/issues/841).
|
||||
<Tip>
|
||||
|
||||
Let's make an example!
|
||||
Say you want to define a pipeline that just does a single forward pass to a U-Net and then calls a scheduler only once (Note, this doesn't make any sense from a scientific point of view, but only represents an example of how things work under the hood).
|
||||
💡 Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.
|
||||
|
||||
Cool! So you open your favorite IDE and start creating your pipeline 💻.
|
||||
First, what model weights and configurations do we need?
|
||||
We have a U-Net and a scheduler, so our pipeline should take a U-Net and a scheduler as an argument.
|
||||
Also, as stated above, you'd like to be able to load weights and the scheduler config for Hub and share your code with others, so we'll inherit from `DiffusionPipeline`:
|
||||
</Tip>
|
||||
|
||||
Community pipelines allow you to add any additional features you'd like on top of the [`DiffusionPipeline`]. The main benefit of building on top of the `DiffusionPipeline` is anyone can load and use your pipeline by only adding one more argument, making it super easy for the community to access.
|
||||
|
||||
This guide will show you how to create a community pipeline and explain how they work. To keep things simple, you'll create a "one-step" pipeline where the `UNet` does a single forward pass and calls the scheduler once.
|
||||
|
||||
## Initialize the pipeline
|
||||
|
||||
You should start by creating a `one_step_unet.py` file for your community pipeline. In this file, create a pipeline class that inherits from the [`DiffusionPipeline`] to be able to load model weights and the scheduler configuration from the Hub. The one-step pipeline needs a `UNet` and a scheduler, so you'll need to add these as arguments to the `__init__` function:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -32,50 +36,52 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
super().__init__()
|
||||
```
|
||||
|
||||
Now, we must save the `unet` and `scheduler` in a config file so that you can save your pipeline with `save_pretrained`.
|
||||
Therefore, make sure you add every component that is save-able to the `register_modules` function:
|
||||
To ensure your pipeline and its components (`unet` and `scheduler`) can be saved with [`~DiffusionPipeline.save_pretrained`], add them to the `register_modules` function:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
+ self.register_modules(unet=unet, scheduler=scheduler)
|
||||
```
|
||||
|
||||
Cool, the init is done! 🔥 Now, let's go into the forward pass, which we recommend defining as `__call__` . Here you're given all the creative freedom there is. For our amazing "one-step" pipeline, we simply create a random image and call the unet once and the scheduler once:
|
||||
Cool, the `__init__` step is done and you can move to the forward pass now! 🔥
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
## Define the forward pass
|
||||
|
||||
In the forward pass, which we recommend defining as `__call__`, you have complete creative freedom to add whatever feature you'd like. For our amazing one-step pipeline, create a random image and only call the `unet` and `scheduler` once by setting `timestep=1`:
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self):
|
||||
image = torch.randn(
|
||||
(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
|
||||
)
|
||||
timestep = 1
|
||||
+ def __call__(self):
|
||||
+ image = torch.randn(
|
||||
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
|
||||
+ )
|
||||
+ timestep = 1
|
||||
|
||||
model_output = self.unet(image, timestep).sample
|
||||
scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
|
||||
+ model_output = self.unet(image, timestep).sample
|
||||
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
|
||||
|
||||
return scheduler_output
|
||||
+ return scheduler_output
|
||||
```
|
||||
|
||||
Cool, that's it! 🚀 You can now run this pipeline by passing a `unet` and a `scheduler` to the init:
|
||||
That's it! 🚀 You can now run this pipeline by passing a `unet` and `scheduler` to it:
|
||||
|
||||
```python
|
||||
from diffusers import DDPMScheduler, Unet2DModel
|
||||
from diffusers import DDPMScheduler, UNet2DModel
|
||||
|
||||
scheduler = DDPMScheduler()
|
||||
unet = UNet2DModel()
|
||||
@@ -85,7 +91,7 @@ pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
|
||||
output = pipeline()
|
||||
```
|
||||
|
||||
But what's even better is that you can load pre-existing weights into the pipeline if they match exactly your pipeline structure. This is e.g. the case for [https://huggingface.co/google/ddpm-cifar10-32](https://huggingface.co/google/ddpm-cifar10-32) so that we can do the following:
|
||||
But what's even better is you can load pre-existing weights into the pipeline if the pipeline structure is identical. For example, you can load the [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32) weights into the one-step pipeline:
|
||||
|
||||
```python
|
||||
pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32")
|
||||
@@ -93,33 +99,11 @@ pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-
|
||||
output = pipeline()
|
||||
```
|
||||
|
||||
We want to share this amazing pipeline with the community, so we would open a PR request to add the following code under `one_step_unet.py` to [https://github.com/huggingface/diffusers/tree/main/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) .
|
||||
## Share your pipeline
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
Open a Pull Request on the 🧨 Diffusers [repository](https://github.com/huggingface/diffusers) to add your awesome pipeline in `one_step_unet.py` to the [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) subfolder.
|
||||
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self):
|
||||
image = torch.randn(
|
||||
(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
|
||||
)
|
||||
timestep = 1
|
||||
|
||||
model_output = self.unet(image, timestep).sample
|
||||
scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
|
||||
|
||||
return scheduler_output
|
||||
```
|
||||
|
||||
Our amazing pipeline got merged here: [#840](https://github.com/huggingface/diffusers/pull/840).
|
||||
Now everybody that has `diffusers >= 0.4.0` installed can use our pipeline magically 🪄 as follows:
|
||||
Once it is merged, anyone with `diffusers >= 0.4.0` installed can use this pipeline magically 🪄 by specifying it in the `custom_pipeline` argument:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -128,28 +112,59 @@ pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeli
|
||||
pipe()
|
||||
```
|
||||
|
||||
Another way to upload your custom_pipeline, besides sending a PR, is uploading the code that contains it to the Hugging Face Hub, [as exemplified here](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview#loading-custom-pipelines-from-the-hub).
|
||||
Another way to share your community pipeline is to upload the `one_step_unet.py` file directly to your preferred [model repository](https://huggingface.co/docs/hub/models-uploading) on the Hub. Instead of specifying the `one_step_unet.py` file, pass the model repository id to the `custom_pipeline` argument:
|
||||
|
||||
**Try it out now - it works!**
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
In general, you will want to create much more sophisticated pipelines, so we recommend looking at existing pipelines here: [https://github.com/huggingface/diffusers/tree/main/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet")
|
||||
```
|
||||
|
||||
IMPORTANT:
|
||||
You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` as this will be automatically detected.
|
||||
Take a look at the following table to compare the two sharing workflows to help you decide the best option for you:
|
||||
|
||||
| | GitHub community pipeline | HF Hub community pipeline |
|
||||
|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
|
||||
| usage | same | same |
|
||||
| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow |
|
||||
| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility |
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` because this is automatically detected.
|
||||
|
||||
</Tip>
|
||||
|
||||
## How do community pipelines work?
|
||||
A community pipeline is a class that has to inherit from ['DiffusionPipeline']:
|
||||
and that has been added to `examples/community` [files](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
The community can load the pipeline code via the custom_pipeline argument from DiffusionPipeline. See docs [here](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.custom_pipeline):
|
||||
|
||||
This means:
|
||||
The model weights and configs of the pipeline should be loaded from the `pretrained_model_name_or_path` [argument](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path):
|
||||
whereas the code that powers the community pipeline is defined in a file added in [`examples/community`](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
A community pipeline is a class that inherits from [`DiffusionPipeline`] which means:
|
||||
|
||||
Now, it might very well be that only some of your pipeline components weights can be downloaded from an official repo.
|
||||
The other components should then be passed directly to init as is the case for the ClIP guidance notebook [here](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb#scrollTo=z9Kglma6hjki).
|
||||
- It can be loaded with the [`custom_pipeline`] argument.
|
||||
- The model weights and scheduler configuration are loaded from [`pretrained_model_name_or_path`].
|
||||
- The code that implements a feature in the community pipeline is defined in a `pipeline.py` file.
|
||||
|
||||
The magic behind all of this is that we load the code directly from GitHub. You can check it out in more detail if you follow the functionality defined here:
|
||||
Sometimes you can't load all the pipeline components weights from an official repository. In this case, the other components should be passed directly to the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
||||
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
custom_pipeline="clip_guided_stable_diffusion",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
scheduler=scheduler,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
The magic behind community pipelines is contained in the following code. It allows the community pipeline to be loaded from GitHub or the Hub, and it'll be available to all 🧨 Diffusers packages.
|
||||
|
||||
```python
|
||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
||||
@@ -164,6 +179,3 @@ else:
|
||||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
```
|
||||
|
||||
This is why a community pipeline merged to GitHub will be directly available to all `diffusers` packages.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Custom Pipelines
|
||||
# Community pipelines
|
||||
|
||||
> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
|
||||
|
||||
|
||||
@@ -10,19 +10,21 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Loading and Adding Custom Pipelines
|
||||
# Load community pipelines
|
||||
|
||||
Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community)
|
||||
via the [`DiffusionPipeline`] class.
|
||||
Community pipelines are any [`DiffusionPipeline`] class that are different from the original implementation as specified in their paper (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
|
||||
|
||||
## Loading custom pipelines from the Hub
|
||||
There are many cool community pipelines like [Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) or [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
|
||||
Custom pipelines can be easily loaded from any model repository on the Hub that defines a diffusion pipeline in a `pipeline.py` file.
|
||||
Let's load a dummy pipeline from [hf-internal-testing/diffusers-dummy-pipeline](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline).
|
||||
To load any community pipeline on the Hub, pass the repository id of the community pipeline to the `custom_pipeline` argument and the model repository where you'd like to load the pipeline weights and components from. For example, the example below loads a dummy pipeline from [`hf-internal-testing/diffusers-dummy-pipeline`](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py) and the pipeline weights and components from [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32):
|
||||
|
||||
All you need to do is pass the custom pipeline repo id with the `custom_pipeline` argument alongside the repo from where you wish to load the pipeline modules.
|
||||
<Tip warning={true}>
|
||||
|
||||
```python
|
||||
🔒 By loading a community pipeline from the Hugging Face Hub, you are trusting that the code you are loading is safe. Make sure to inspect the code online before loading and running it automatically!
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
@@ -30,25 +32,9 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
This will load the custom pipeline as defined in the [model repository](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py).
|
||||
Loading an official community pipeline is similar, but you can mix loading weights from an official repository id and pass pipeline components directly. The example below loads the community [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) pipeline, and you can pass the CLIP model components directly to it:
|
||||
|
||||
<Tip warning={true} >
|
||||
|
||||
By loading a custom pipeline from the Hugging Face Hub, you are trusting that the code you are loading
|
||||
is safe 🔒. Make sure to check out the code online before loading & running it automatically.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Loading official community pipelines
|
||||
|
||||
Community pipelines are summarized in the [community examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
|
||||
Similarly, you need to pass both the *repo id* from where you wish to load the weights as well as the `custom_pipeline` argument. Here the `custom_pipeline` argument should consist simply of the filename of the community pipeline excluding the `.py` suffix, *e.g.* `clip_guided_stable_diffusion`.
|
||||
|
||||
Since community pipelines are often more complex, one can mix loading weights from an official *repo id*
|
||||
and passing pipeline modules directly.
|
||||
|
||||
```python
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
|
||||
@@ -65,59 +51,4 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
## Adding custom pipelines to the Hub
|
||||
|
||||
To add a custom pipeline to the Hub, all you need to do is to define a pipeline class that inherits
|
||||
from [`DiffusionPipeline`] in a `pipeline.py` file.
|
||||
Make sure that the whole pipeline is encapsulated within a single class and that the `pipeline.py` file
|
||||
has only one such class.
|
||||
|
||||
Let's quickly define an example pipeline.
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
class MyPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
|
||||
)
|
||||
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image, eta).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
return image
|
||||
```
|
||||
|
||||
Now you can upload this short file under the name `pipeline.py` in your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
|
||||
Finally, we can load the custom pipeline by passing the model repository name, *e.g.* `sd-diffusers-pipelines-library/my_custom_pipeline` alongside the model repository from where we want to load the `unet` and `scheduler` components.
|
||||
|
||||
```python
|
||||
my_pipeline = DiffusionPipeline.from_pretrained(
|
||||
"google/ddpm-cifar10-32", custom_pipeline="patrickvonplaten/my_custom_pipeline"
|
||||
)
|
||||
```
|
||||
For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Reproducibility is important for testing, replicating results, and can even be used to [improve image quality](reusing_seeds). However, the randomness in diffusion models is a desired property because it allows the pipeline to generate different images every time it is run. While you can't expect to get the exact same results across platforms, you can expect results to be reproducible across releases and platforms within a certain tolerance range. Even then, tolerance varies depending on the diffusion pipeline and checkpoint.
|
||||
|
||||
This is why it's important to understand how to control sources of randomness in diffusion models.
|
||||
This is why it's important to understand how to control sources of randomness in diffusion models or use deterministic algorithms.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -24,7 +24,7 @@ This is why it's important to understand how to control sources of randomness in
|
||||
|
||||
</Tip>
|
||||
|
||||
## Inference
|
||||
## Control randomness
|
||||
|
||||
During inference, pipelines rely heavily on random sampling operations which include creating the
|
||||
Gaussian noise tensors to denoise and adding noise to the scheduling step.
|
||||
@@ -147,5 +147,46 @@ susceptible to precision error propagation. Don't expect similar results across
|
||||
different GPU hardware or PyTorch versions. In this case, you'll need to run
|
||||
exactly the same hardware and PyTorch version for full reproducibility.
|
||||
|
||||
## randn_tensor
|
||||
### randn_tensor
|
||||
[[autodoc]] diffusers.utils.randn_tensor
|
||||
|
||||
## Deterministic algorithms
|
||||
|
||||
You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. However, you should be aware that deterministic algorithms may be slower than nondeterministic ones and you may observe a decrease in performance. But if reproducibility is important to you, then this is the way to go!
|
||||
|
||||
Nondeterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment varibale [`CUBLAS_WORKSPACE_CONFIG`](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
|
||||
|
||||
PyTorch typically benchmarks multiple algorithms to select the fastest one, but if you want reproducibility, you should disable this feature because the benchmark may select different algorithms each time. Lastly, pass `True` to [`torch.use_deterministic_algorithms`](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html) to enable deterministic algorithms.
|
||||
|
||||
```py
|
||||
import os
|
||||
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
```
|
||||
|
||||
Now when you run the same pipeline twice, you'll get identical results.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, StableDiffusionPipeline
|
||||
import numpy as np
|
||||
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
g = torch.Generator(device="cuda")
|
||||
|
||||
prompt = "A bear is playing a guitar on Times Square"
|
||||
|
||||
g.manual_seed(0)
|
||||
result1 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
|
||||
|
||||
g.manual_seed(0)
|
||||
result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
|
||||
|
||||
print("L_inf dist = ", abs(result1 - result2).max())
|
||||
"L_inf dist = tensor(0., device='cuda:0')"
|
||||
```
|
||||
@@ -31,7 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
|
||||
| TensorRT Stable Diffusion Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - |[Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
@@ -1130,3 +1130,34 @@ Init Image
|
||||
Output Image
|
||||
|
||||

|
||||
|
||||
### TensorRT Text2Image Stable Diffusion Pipeline
|
||||
|
||||
The TensorRT Pipeline can be used to accelerate the Text2Image Stable Diffusion Inference run.
|
||||
|
||||
NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DDIMScheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||
|
||||
# Use the DDIMScheduler scheduler here instead
|
||||
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
|
||||
subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
|
||||
custom_pipeline="stable_diffusion_tensorrt_txt2img",
|
||||
revision='fp16',
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=scheduler,)
|
||||
|
||||
# re-use cached folder to save ONNX models and TensorRT Engines
|
||||
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',)
|
||||
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
|
||||
image = pipe(prompt).images[0]
|
||||
image.save('tensorrt_mt_fuji.png')
|
||||
```
|
||||
|
||||
926
examples/community/stable_diffusion_tensorrt_txt2img.py
Normal file
926
examples/community/stable_diffusion_tensorrt_txt2img.py
Normal file
@@ -0,0 +1,926 @@
|
||||
#
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from copy import copy
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx import shape_inference
|
||||
from polygraphy import cuda
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
from polygraphy.backend.onnx.loader import fold_constants
|
||||
from polygraphy.backend.trt import (
|
||||
CreateConfig,
|
||||
Profile,
|
||||
engine_from_bytes,
|
||||
engine_from_network,
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineOutput,
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
|
||||
|
||||
"""
|
||||
Installation instructions
|
||||
python3 -m pip install --upgrade tensorrt
|
||||
python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
|
||||
python3 -m pip install onnxruntime
|
||||
"""
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Map of numpy dtype -> torch dtype
|
||||
numpy_to_torch_dtype_dict = {
|
||||
np.uint8: torch.uint8,
|
||||
np.int8: torch.int8,
|
||||
np.int16: torch.int16,
|
||||
np.int32: torch.int32,
|
||||
np.int64: torch.int64,
|
||||
np.float16: torch.float16,
|
||||
np.float32: torch.float32,
|
||||
np.float64: torch.float64,
|
||||
np.complex64: torch.complex64,
|
||||
np.complex128: torch.complex128,
|
||||
}
|
||||
if np.version.full_version >= "1.24.0":
|
||||
numpy_to_torch_dtype_dict[np.bool_] = torch.bool
|
||||
else:
|
||||
numpy_to_torch_dtype_dict[np.bool] = torch.bool
|
||||
|
||||
# Map of torch dtype -> numpy dtype
|
||||
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
|
||||
|
||||
|
||||
def device_view(t):
|
||||
return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
|
||||
|
||||
|
||||
class Engine:
|
||||
def __init__(self, engine_path):
|
||||
self.engine_path = engine_path
|
||||
self.engine = None
|
||||
self.context = None
|
||||
self.buffers = OrderedDict()
|
||||
self.tensors = OrderedDict()
|
||||
|
||||
def __del__(self):
|
||||
[buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
|
||||
del self.engine
|
||||
del self.context
|
||||
del self.buffers
|
||||
del self.tensors
|
||||
|
||||
def build(
|
||||
self,
|
||||
onnx_path,
|
||||
fp16,
|
||||
input_profile=None,
|
||||
enable_preview=False,
|
||||
enable_all_tactics=False,
|
||||
timing_cache=None,
|
||||
workspace_size=0,
|
||||
):
|
||||
logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
|
||||
p = Profile()
|
||||
if input_profile:
|
||||
for name, dims in input_profile.items():
|
||||
assert len(dims) == 3
|
||||
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
|
||||
|
||||
config_kwargs = {}
|
||||
|
||||
config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
|
||||
if enable_preview:
|
||||
# Faster dynamic shapes made optional since it increases engine build time.
|
||||
config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
|
||||
if workspace_size > 0:
|
||||
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
|
||||
if not enable_all_tactics:
|
||||
config_kwargs["tactic_sources"] = []
|
||||
|
||||
engine = engine_from_network(
|
||||
network_from_onnx_path(onnx_path),
|
||||
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
|
||||
save_timing_cache=timing_cache,
|
||||
)
|
||||
save_engine(engine, path=self.engine_path)
|
||||
|
||||
def load(self):
|
||||
logger.warning(f"Loading TensorRT engine: {self.engine_path}")
|
||||
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
|
||||
|
||||
def activate(self):
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
def allocate_buffers(self, shape_dict=None, device="cuda"):
|
||||
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
|
||||
binding = self.engine[idx]
|
||||
if shape_dict and binding in shape_dict:
|
||||
shape = shape_dict[binding]
|
||||
else:
|
||||
shape = self.engine.get_binding_shape(binding)
|
||||
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
|
||||
if self.engine.binding_is_input(binding):
|
||||
self.context.set_binding_shape(idx, shape)
|
||||
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
|
||||
self.tensors[binding] = tensor
|
||||
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
|
||||
|
||||
def infer(self, feed_dict, stream):
|
||||
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
|
||||
# shallow copy of ordered dict
|
||||
device_buffers = copy(self.buffers)
|
||||
for name, buf in feed_dict.items():
|
||||
assert isinstance(buf, cuda.DeviceView)
|
||||
device_buffers[name] = buf
|
||||
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
|
||||
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
|
||||
if not noerror:
|
||||
raise ValueError("ERROR: inference failed.")
|
||||
|
||||
return self.tensors
|
||||
|
||||
|
||||
class Optimizer:
|
||||
def __init__(self, onnx_graph):
|
||||
self.graph = gs.import_onnx(onnx_graph)
|
||||
|
||||
def cleanup(self, return_onnx=False):
|
||||
self.graph.cleanup().toposort()
|
||||
if return_onnx:
|
||||
return gs.export_onnx(self.graph)
|
||||
|
||||
def select_outputs(self, keep, names=None):
|
||||
self.graph.outputs = [self.graph.outputs[o] for o in keep]
|
||||
if names:
|
||||
for i, name in enumerate(names):
|
||||
self.graph.outputs[i].name = name
|
||||
|
||||
def fold_constants(self, return_onnx=False):
|
||||
onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
|
||||
self.graph = gs.import_onnx(onnx_graph)
|
||||
if return_onnx:
|
||||
return onnx_graph
|
||||
|
||||
def infer_shapes(self, return_onnx=False):
|
||||
onnx_graph = gs.export_onnx(self.graph)
|
||||
if onnx_graph.ByteSize() > 2147483648:
|
||||
raise TypeError("ERROR: model size exceeds supported 2GB limit")
|
||||
else:
|
||||
onnx_graph = shape_inference.infer_shapes(onnx_graph)
|
||||
|
||||
self.graph = gs.import_onnx(onnx_graph)
|
||||
if return_onnx:
|
||||
return onnx_graph
|
||||
|
||||
|
||||
class BaseModel:
|
||||
def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
|
||||
self.model = model
|
||||
self.name = "SD Model"
|
||||
self.fp16 = fp16
|
||||
self.device = device
|
||||
|
||||
self.min_batch = 1
|
||||
self.max_batch = max_batch_size
|
||||
self.min_image_shape = 256 # min image resolution: 256x256
|
||||
self.max_image_shape = 1024 # max image resolution: 1024x1024
|
||||
self.min_latent_shape = self.min_image_shape // 8
|
||||
self.max_latent_shape = self.max_image_shape // 8
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.text_maxlen = text_maxlen
|
||||
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
def get_input_names(self):
|
||||
pass
|
||||
|
||||
def get_output_names(self):
|
||||
pass
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return None
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
pass
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
return None
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
return None
|
||||
|
||||
def optimize(self, onnx_graph):
|
||||
opt = Optimizer(onnx_graph)
|
||||
opt.cleanup()
|
||||
opt.fold_constants()
|
||||
opt.infer_shapes()
|
||||
onnx_opt_graph = opt.cleanup(return_onnx=True)
|
||||
return onnx_opt_graph
|
||||
|
||||
def check_dims(self, batch_size, image_height, image_width):
|
||||
assert batch_size >= self.min_batch and batch_size <= self.max_batch
|
||||
assert image_height % 8 == 0 or image_width % 8 == 0
|
||||
latent_height = image_height // 8
|
||||
latent_width = image_width // 8
|
||||
assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
|
||||
assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
|
||||
return (latent_height, latent_width)
|
||||
|
||||
def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
min_batch = batch_size if static_batch else self.min_batch
|
||||
max_batch = batch_size if static_batch else self.max_batch
|
||||
latent_height = image_height // 8
|
||||
latent_width = image_width // 8
|
||||
min_image_height = image_height if static_shape else self.min_image_shape
|
||||
max_image_height = image_height if static_shape else self.max_image_shape
|
||||
min_image_width = image_width if static_shape else self.min_image_shape
|
||||
max_image_width = image_width if static_shape else self.max_image_shape
|
||||
min_latent_height = latent_height if static_shape else self.min_latent_shape
|
||||
max_latent_height = latent_height if static_shape else self.max_latent_shape
|
||||
min_latent_width = latent_width if static_shape else self.min_latent_shape
|
||||
max_latent_width = latent_width if static_shape else self.max_latent_shape
|
||||
return (
|
||||
min_batch,
|
||||
max_batch,
|
||||
min_image_height,
|
||||
max_image_height,
|
||||
min_image_width,
|
||||
max_image_width,
|
||||
min_latent_height,
|
||||
max_latent_height,
|
||||
min_latent_width,
|
||||
max_latent_width,
|
||||
)
|
||||
|
||||
|
||||
def getOnnxPath(model_name, onnx_dir, opt=True):
|
||||
return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
|
||||
|
||||
|
||||
def getEnginePath(model_name, engine_dir):
|
||||
return os.path.join(engine_dir, model_name + ".plan")
|
||||
|
||||
|
||||
def build_engines(
|
||||
models: dict,
|
||||
engine_dir,
|
||||
onnx_dir,
|
||||
onnx_opset,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
opt_batch_size=1,
|
||||
force_engine_rebuild=False,
|
||||
static_batch=False,
|
||||
static_shape=True,
|
||||
enable_preview=False,
|
||||
enable_all_tactics=False,
|
||||
timing_cache=None,
|
||||
max_workspace_size=0,
|
||||
):
|
||||
built_engines = {}
|
||||
if not os.path.isdir(onnx_dir):
|
||||
os.makedirs(onnx_dir)
|
||||
if not os.path.isdir(engine_dir):
|
||||
os.makedirs(engine_dir)
|
||||
|
||||
# Export models to ONNX
|
||||
for model_name, model_obj in models.items():
|
||||
engine_path = getEnginePath(model_name, engine_dir)
|
||||
if force_engine_rebuild or not os.path.exists(engine_path):
|
||||
logger.warning("Building Engines...")
|
||||
logger.warning("Engine build can take a while to complete")
|
||||
onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = getOnnxPath(model_name, onnx_dir)
|
||||
if force_engine_rebuild or not os.path.exists(onnx_opt_path):
|
||||
if force_engine_rebuild or not os.path.exists(onnx_path):
|
||||
logger.warning(f"Exporting model: {onnx_path}")
|
||||
model = model_obj.get_model()
|
||||
with torch.inference_mode(), torch.autocast("cuda"):
|
||||
inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
inputs,
|
||||
onnx_path,
|
||||
export_params=True,
|
||||
opset_version=onnx_opset,
|
||||
do_constant_folding=True,
|
||||
input_names=model_obj.get_input_names(),
|
||||
output_names=model_obj.get_output_names(),
|
||||
dynamic_axes=model_obj.get_dynamic_axes(),
|
||||
)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
else:
|
||||
logger.warning(f"Found cached model: {onnx_path}")
|
||||
|
||||
# Optimize onnx
|
||||
if force_engine_rebuild or not os.path.exists(onnx_opt_path):
|
||||
logger.warning(f"Generating optimizing model: {onnx_opt_path}")
|
||||
onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
|
||||
onnx.save(onnx_opt_graph, onnx_opt_path)
|
||||
else:
|
||||
logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
|
||||
|
||||
# Build TensorRT engines
|
||||
for model_name, model_obj in models.items():
|
||||
engine_path = getEnginePath(model_name, engine_dir)
|
||||
engine = Engine(engine_path)
|
||||
onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = getOnnxPath(model_name, onnx_dir)
|
||||
|
||||
if force_engine_rebuild or not os.path.exists(engine.engine_path):
|
||||
engine.build(
|
||||
onnx_opt_path,
|
||||
fp16=True,
|
||||
input_profile=model_obj.get_input_profile(
|
||||
opt_batch_size,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
static_batch=static_batch,
|
||||
static_shape=static_shape,
|
||||
),
|
||||
enable_preview=enable_preview,
|
||||
timing_cache=timing_cache,
|
||||
workspace_size=max_workspace_size,
|
||||
)
|
||||
built_engines[model_name] = engine
|
||||
|
||||
# Load and activate TensorRT engines
|
||||
for model_name, model_obj in models.items():
|
||||
engine = built_engines[model_name]
|
||||
engine.load()
|
||||
engine.activate()
|
||||
|
||||
return built_engines
|
||||
|
||||
|
||||
def runEngine(engine, feed_dict, stream):
|
||||
return engine.infer(feed_dict, stream)
|
||||
|
||||
|
||||
class CLIP(BaseModel):
|
||||
def __init__(self, model, device, max_batch_size, embedding_dim):
|
||||
super(CLIP, self).__init__(
|
||||
model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
|
||||
)
|
||||
self.name = "CLIP"
|
||||
|
||||
def get_input_names(self):
|
||||
return ["input_ids"]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["text_embeddings", "pooler_output"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
self.check_dims(batch_size, image_height, image_width)
|
||||
min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
|
||||
batch_size, image_height, image_width, static_batch, static_shape
|
||||
)
|
||||
return {
|
||||
"input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
|
||||
}
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
self.check_dims(batch_size, image_height, image_width)
|
||||
return {
|
||||
"input_ids": (batch_size, self.text_maxlen),
|
||||
"text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
|
||||
}
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
self.check_dims(batch_size, image_height, image_width)
|
||||
return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
|
||||
|
||||
def optimize(self, onnx_graph):
|
||||
opt = Optimizer(onnx_graph)
|
||||
opt.select_outputs([0]) # delete graph output#1
|
||||
opt.cleanup()
|
||||
opt.fold_constants()
|
||||
opt.infer_shapes()
|
||||
opt.select_outputs([0], names=["text_embeddings"]) # rename network output
|
||||
opt_onnx_graph = opt.cleanup(return_onnx=True)
|
||||
return opt_onnx_graph
|
||||
|
||||
|
||||
def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
|
||||
return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
|
||||
|
||||
|
||||
class UNet(BaseModel):
|
||||
def __init__(
|
||||
self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
|
||||
):
|
||||
super(UNet, self).__init__(
|
||||
model=model,
|
||||
fp16=fp16,
|
||||
device=device,
|
||||
max_batch_size=max_batch_size,
|
||||
embedding_dim=embedding_dim,
|
||||
text_maxlen=text_maxlen,
|
||||
)
|
||||
self.unet_dim = unet_dim
|
||||
self.name = "UNet"
|
||||
|
||||
def get_input_names(self):
|
||||
return ["sample", "timestep", "encoder_hidden_states"]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["latent"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {
|
||||
"sample": {0: "2B", 2: "H", 3: "W"},
|
||||
"encoder_hidden_states": {0: "2B"},
|
||||
"latent": {0: "2B", 2: "H", 3: "W"},
|
||||
}
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
(
|
||||
min_batch,
|
||||
max_batch,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
min_latent_height,
|
||||
max_latent_height,
|
||||
min_latent_width,
|
||||
max_latent_width,
|
||||
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
|
||||
return {
|
||||
"sample": [
|
||||
(2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
|
||||
(2 * batch_size, self.unet_dim, latent_height, latent_width),
|
||||
(2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
|
||||
],
|
||||
"encoder_hidden_states": [
|
||||
(2 * min_batch, self.text_maxlen, self.embedding_dim),
|
||||
(2 * batch_size, self.text_maxlen, self.embedding_dim),
|
||||
(2 * max_batch, self.text_maxlen, self.embedding_dim),
|
||||
],
|
||||
}
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
return {
|
||||
"sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
|
||||
"encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
|
||||
"latent": (2 * batch_size, 4, latent_height, latent_width),
|
||||
}
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
dtype = torch.float16 if self.fp16 else torch.float32
|
||||
return (
|
||||
torch.randn(
|
||||
2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
|
||||
),
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
|
||||
)
|
||||
|
||||
|
||||
def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
|
||||
return UNet(
|
||||
model,
|
||||
fp16=True,
|
||||
device=device,
|
||||
max_batch_size=max_batch_size,
|
||||
embedding_dim=embedding_dim,
|
||||
unet_dim=(9 if inpaint else 4),
|
||||
)
|
||||
|
||||
|
||||
class VAE(BaseModel):
|
||||
def __init__(self, model, device, max_batch_size, embedding_dim):
|
||||
super(VAE, self).__init__(
|
||||
model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
|
||||
)
|
||||
self.name = "VAE decoder"
|
||||
|
||||
def get_input_names(self):
|
||||
return ["latent"]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["images"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
(
|
||||
min_batch,
|
||||
max_batch,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
min_latent_height,
|
||||
max_latent_height,
|
||||
min_latent_width,
|
||||
max_latent_width,
|
||||
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
|
||||
return {
|
||||
"latent": [
|
||||
(min_batch, 4, min_latent_height, min_latent_width),
|
||||
(batch_size, 4, latent_height, latent_width),
|
||||
(max_batch, 4, max_latent_height, max_latent_width),
|
||||
]
|
||||
}
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
return {
|
||||
"latent": (batch_size, 4, latent_height, latent_width),
|
||||
"images": (batch_size, 3, image_height, image_width),
|
||||
}
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
|
||||
|
||||
|
||||
def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
|
||||
return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
|
||||
|
||||
|
||||
class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.
|
||||
|
||||
This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae"],
|
||||
image_height: int = 768,
|
||||
image_width: int = 768,
|
||||
max_batch_size: int = 16,
|
||||
# ONNX export parameters
|
||||
onnx_opset: int = 17,
|
||||
onnx_dir: str = "onnx",
|
||||
# TensorRT engine build parameters
|
||||
engine_dir: str = "engine",
|
||||
force_engine_rebuild: bool = False,
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
|
||||
self.stages = stages
|
||||
self.image_height, self.image_width = image_height, image_width
|
||||
self.inpaint = False
|
||||
self.onnx_opset = onnx_opset
|
||||
self.onnx_dir = onnx_dir
|
||||
self.engine_dir = engine_dir
|
||||
self.force_engine_rebuild = force_engine_rebuild
|
||||
self.timing_cache = timing_cache
|
||||
self.build_static_batch = False
|
||||
self.build_dynamic_shape = False
|
||||
self.build_preview_features = False
|
||||
|
||||
self.max_batch_size = max_batch_size
|
||||
# TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
|
||||
if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
|
||||
self.max_batch_size = 4
|
||||
|
||||
self.stream = None # loaded in loadResources()
|
||||
self.models = {} # loaded in __loadModels()
|
||||
self.engine = {} # loaded in build_engines()
|
||||
|
||||
def __loadModels(self):
|
||||
# Load pipeline models
|
||||
self.embedding_dim = self.text_encoder.config.hidden_size
|
||||
models_args = {
|
||||
"device": self.torch_device,
|
||||
"max_batch_size": self.max_batch_size,
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"inpaint": self.inpaint,
|
||||
}
|
||||
if "clip" in self.stages:
|
||||
self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
|
||||
if "unet" in self.stages:
|
||||
self.models["unet"] = make_UNet(self.unet, **models_args)
|
||||
if "vae" in self.stages:
|
||||
self.models["vae"] = make_VAE(self.vae, **models_args)
|
||||
|
||||
@classmethod
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
cls.cached_folder = (
|
||||
pretrained_model_name_or_path
|
||||
if os.path.isdir(pretrained_model_name_or_path)
|
||||
else snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
|
||||
super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)
|
||||
|
||||
self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
|
||||
self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
|
||||
self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
|
||||
|
||||
# set device
|
||||
self.torch_device = self._execution_device
|
||||
logger.warning(f"Running inference on device: {self.torch_device}")
|
||||
|
||||
# load models
|
||||
self.__loadModels()
|
||||
|
||||
# build engines
|
||||
self.engine = build_engines(
|
||||
self.models,
|
||||
self.engine_dir,
|
||||
self.onnx_dir,
|
||||
self.onnx_opset,
|
||||
opt_image_height=self.image_height,
|
||||
opt_image_width=self.image_width,
|
||||
force_engine_rebuild=self.force_engine_rebuild,
|
||||
static_batch=self.build_static_batch,
|
||||
static_shape=not self.build_dynamic_shape,
|
||||
enable_preview=self.build_preview_features,
|
||||
timing_cache=self.timing_cache,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __encode_prompt(self, prompt, negative_prompt):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
# Tokenize prompt
|
||||
text_input_ids = (
|
||||
self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
.input_ids.type(torch.int32)
|
||||
.to(self.torch_device)
|
||||
)
|
||||
|
||||
text_input_ids_inp = device_view(text_input_ids)
|
||||
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
|
||||
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
|
||||
"text_embeddings"
|
||||
].clone()
|
||||
|
||||
# Tokenize negative prompt
|
||||
uncond_input_ids = (
|
||||
self.tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
.input_ids.type(torch.int32)
|
||||
.to(self.torch_device)
|
||||
)
|
||||
uncond_input_ids_inp = device_view(uncond_input_ids)
|
||||
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
|
||||
"text_embeddings"
|
||||
]
|
||||
|
||||
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def __denoise_latent(
|
||||
self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
|
||||
):
|
||||
if not isinstance(timesteps, torch.Tensor):
|
||||
timesteps = self.scheduler.timesteps
|
||||
for step_index, timestep in enumerate(timesteps):
|
||||
# Expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
|
||||
if isinstance(mask, torch.Tensor):
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
# Predict the noise residual
|
||||
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
|
||||
|
||||
sample_inp = device_view(latent_model_input)
|
||||
timestep_inp = device_view(timestep_float)
|
||||
embeddings_inp = device_view(text_embeddings)
|
||||
noise_pred = runEngine(
|
||||
self.engine["unet"],
|
||||
{"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
|
||||
self.stream,
|
||||
)["latent"]
|
||||
|
||||
# Perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
|
||||
latents = 1.0 / 0.18215 * latents
|
||||
return latents
|
||||
|
||||
def __decode_latent(self, latents):
|
||||
images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
|
||||
images = (images / 2 + 0.5).clamp(0, 1)
|
||||
return images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
def __loadResources(self, image_height, image_width, batch_size):
|
||||
self.stream = cuda.Stream()
|
||||
|
||||
# Allocate buffers for TensorRT engine bindings
|
||||
for model_name, obj in self.models.items():
|
||||
self.engine[model_name].allocate_buffers(
|
||||
shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
|
||||
"""
|
||||
self.generator = generator
|
||||
self.denoising_steps = num_inference_steps
|
||||
self.guidance_scale = guidance_scale
|
||||
|
||||
# Pre-compute latent input scales and linear multistep coefficients
|
||||
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
|
||||
|
||||
# Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
prompt = [prompt]
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = [""] * batch_size
|
||||
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
|
||||
if batch_size > self.max_batch_size:
|
||||
raise ValueError(
|
||||
f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
|
||||
)
|
||||
|
||||
# load resources
|
||||
self.__loadResources(self.image_height, self.image_width, batch_size)
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
|
||||
# CLIP text encoder
|
||||
text_embeddings = self.__encode_prompt(prompt, negative_prompt)
|
||||
|
||||
# Pre-initialize latents
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
self.image_height,
|
||||
self.image_width,
|
||||
torch.float32,
|
||||
self.torch_device,
|
||||
generator,
|
||||
)
|
||||
|
||||
# UNet denoiser
|
||||
latents = self.__denoise_latent(latents, text_embeddings)
|
||||
|
||||
# VAE decode latent
|
||||
images = self.__decode_latent(latents)
|
||||
|
||||
images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
|
||||
images = self.numpy_to_pil(images)
|
||||
return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -372,9 +372,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
num_channels_latents = self.decoder.config.in_channels
|
||||
height = self.decoder.config.sample_size
|
||||
width = self.decoder.config.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
@@ -425,9 +425,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
channels = self.super_res_first.config.in_channels // 2
|
||||
height = self.super_res_first.config.sample_size
|
||||
width = self.super_res_first.config.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
|
||||
@@ -452,9 +452,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
num_channels_latents = self.decoder.config.in_channels
|
||||
height = self.decoder.config.sample_size
|
||||
width = self.decoder.config.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
@@ -505,9 +505,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
channels = self.super_res_first.config.in_channels // 2
|
||||
height = self.super_res_first.config.sample_size
|
||||
width = self.super_res_first.config.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
|
||||
@@ -96,6 +96,29 @@ accelerate launch train_controlnet.py \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=4 \
|
||||
--mixed_precision="fp16" \
|
||||
--tracker_project_name="controlnet-demo" \
|
||||
--report_to=wandb
|
||||
```
|
||||
|
||||
## Example results
|
||||
|
||||
#### After 300 steps with batch size 8
|
||||
@@ -284,9 +307,9 @@ TPU_TYPE=v4-8
|
||||
VM_NAME=hg_flax
|
||||
|
||||
gcloud alpha compute tpus tpu-vm create $VM_NAME \
|
||||
--zone $ZONE \
|
||||
--accelerator-type $TPU_TYPE \
|
||||
--version tpu-vm-v4-base
|
||||
--zone $ZONE \
|
||||
--accelerator-type $TPU_TYPE \
|
||||
--version tpu-vm-v4-base
|
||||
|
||||
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
|
||||
```
|
||||
@@ -326,6 +349,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
|
||||
pip install wandb
|
||||
```
|
||||
|
||||
|
||||
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
|
||||
|
||||
```
|
||||
@@ -343,8 +367,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="control_out"
|
||||
export HUB_MODEL_ID="fill-circle-controlnet"
|
||||
export OUTPUT_DIR="runs/fill-circle-{timestamp}"
|
||||
export HUB_MODEL_ID="controlnet-fill-circle"
|
||||
```
|
||||
|
||||
And finally start the training
|
||||
@@ -363,32 +387,36 @@ python3 train_controlnet_flax.py \
|
||||
--revision="non-ema" \
|
||||
--from_pt \
|
||||
--report_to="wandb" \
|
||||
--max_train_steps=10000 \
|
||||
--tracker_project_name=$HUB_MODEL_ID \
|
||||
--num_train_epochs=11 \
|
||||
--push_to_hub \
|
||||
--hub_model_id=$HUB_MODEL_ID
|
||||
```
|
||||
|
||||
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
|
||||
|
||||
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
|
||||
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
|
||||
export HUB_MODEL_ID="controlnet-uncanny-faces"
|
||||
|
||||
python3 train_controlnet_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
|
||||
--streaming \
|
||||
--conditioning_image_column=spiga_seg \
|
||||
--image_column=image \
|
||||
--caption_column=image_caption \
|
||||
--resolution=512 \
|
||||
--max_train_samples 50 \
|
||||
--max_train_steps 5 \
|
||||
--learning_rate=1e-5 \
|
||||
--validation_steps=2 \
|
||||
--train_batch_size=1 \
|
||||
--revision="flax" \
|
||||
--report_to="wandb"
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
|
||||
--streaming \
|
||||
--conditioning_image_column=spiga_seg \
|
||||
--image_column=image \
|
||||
--caption_column=image_caption \
|
||||
--resolution=512 \
|
||||
--max_train_samples 100000 \
|
||||
--learning_rate=1e-5 \
|
||||
--train_batch_size=1 \
|
||||
--revision="flax" \
|
||||
--report_to="wandb" \
|
||||
--tracker_project_name=$HUB_MODEL_ID
|
||||
```
|
||||
|
||||
Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:
|
||||
@@ -400,16 +428,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
|
||||
When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500
|
||||
|
||||
You can then start your training from this saved checkpoint with
|
||||
|
||||
```bash
|
||||
--controlnet_model_name_or_path="./control_out/500"
|
||||
--controlnet_model_name_or_path="./control_out/500"
|
||||
```
|
||||
|
||||
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
|
||||
|
||||
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
|
||||
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
|
||||
|
||||
You can **profile your code** with:
|
||||
|
||||
```bash
|
||||
--profile_steps==5
|
||||
```
|
||||
|
||||
Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
|
||||
|
||||
```bash
|
||||
pip install tensorflow tensorboard-plugin-profile
|
||||
tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
|
||||
```
|
||||
|
||||
The profile can then be inspected at http://localhost:6006/#profile
|
||||
|
||||
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
|
||||
|
||||
Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
|
||||
|
||||
@@ -55,7 +55,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -525,6 +525,11 @@ def parse_args(input_args=None):
|
||||
" or the same number of `--validation_prompt`s and `--validation_image`s"
|
||||
)
|
||||
|
||||
if args.resolution % 8 != 0:
|
||||
raise ValueError(
|
||||
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -607,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
@@ -615,6 +621,7 @@ def make_train_dataset(args, tokenizer, accelerator):
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import jax
|
||||
@@ -58,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -75,20 +76,11 @@ def image_grid(imgs, rows, cols):
|
||||
return grid
|
||||
|
||||
|
||||
def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
|
||||
logger.info("Running validation... ")
|
||||
def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
|
||||
logger.info("Running validation...")
|
||||
|
||||
pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
from_pt=args.from_pt,
|
||||
)
|
||||
params = jax_utils.replicate(params)
|
||||
params["controlnet"] = controlnet_params
|
||||
pipeline_params = pipeline_params.copy()
|
||||
pipeline_params["controlnet"] = controlnet_params
|
||||
|
||||
num_samples = jax.device_count()
|
||||
prng_seed = jax.random.split(rng, jax.device_count())
|
||||
@@ -120,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
|
||||
images = pipeline(
|
||||
prompt_ids=prompt_ids,
|
||||
image=processed_image,
|
||||
params=params,
|
||||
params=pipeline_params,
|
||||
prng_seed=prng_seed,
|
||||
num_inference_steps=50,
|
||||
jit=True,
|
||||
@@ -175,6 +167,7 @@ tags:
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- controlnet
|
||||
- jax-diffusers-event
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
@@ -220,6 +213,28 @@ def parse_args():
|
||||
default=None,
|
||||
help="Revision of controlnet model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="How many training steps to profile in the beginning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_validation",
|
||||
action="store_true",
|
||||
help="Whether to profile the (last) validation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_memory",
|
||||
action="store_true",
|
||||
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ccache",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Enables compilation cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_from_pt",
|
||||
action="store_true",
|
||||
@@ -234,8 +249,9 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="controlnet-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
default="runs/{timestamp}",
|
||||
help="The output directory where the model predictions and checkpoints will be written. "
|
||||
"Can contain placeholders: {timestamp}.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
@@ -317,15 +333,6 @@ def parse_args():
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_steps",
|
||||
type=int,
|
||||
@@ -459,6 +466,8 @@ def parse_args():
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -783,6 +792,17 @@ def main():
|
||||
]:
|
||||
controlnet_params[key] = unet_params[key]
|
||||
|
||||
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
from_pt=args.from_pt,
|
||||
)
|
||||
pipeline_params = jax_utils.replicate(pipeline_params)
|
||||
|
||||
# Optimization
|
||||
if args.scale_lr:
|
||||
args.learning_rate = args.learning_rate * total_train_batch_size
|
||||
@@ -952,6 +972,11 @@ def main():
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
|
||||
def l2(xs):
|
||||
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
|
||||
|
||||
metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
|
||||
|
||||
return new_state, metrics, new_train_rng
|
||||
|
||||
# Create parallel version of the train step
|
||||
@@ -983,32 +1008,38 @@ def main():
|
||||
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
|
||||
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
|
||||
|
||||
if jax.process_index() == 0:
|
||||
if jax.process_index() == 0 and args.report_to == "wandb":
|
||||
wandb.define_metric("*", step_metric="train/step")
|
||||
wandb.define_metric("train/step", step_metric="walltime")
|
||||
wandb.config.update(
|
||||
{
|
||||
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
|
||||
"total_train_batch_size": total_train_batch_size,
|
||||
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
|
||||
"num_devices": jax.device_count(),
|
||||
"controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
|
||||
}
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
global_step = step0 = 0
|
||||
epochs = tqdm(
|
||||
range(args.num_train_epochs),
|
||||
desc="Epoch ... ",
|
||||
position=0,
|
||||
disable=jax.process_index() > 0,
|
||||
)
|
||||
if args.profile_memory:
|
||||
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
|
||||
t00 = t0 = time.monotonic()
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
|
||||
train_metrics = []
|
||||
train_metric = None
|
||||
|
||||
steps_per_epoch = (
|
||||
args.max_train_samples // total_train_batch_size
|
||||
if args.streaming
|
||||
if args.streaming or args.max_train_samples
|
||||
else len(train_dataset) // total_train_batch_size
|
||||
)
|
||||
train_step_progress_bar = tqdm(
|
||||
@@ -1020,10 +1051,18 @@ def main():
|
||||
)
|
||||
# train
|
||||
for batch in train_dataloader:
|
||||
if args.profile_steps and global_step == 1:
|
||||
train_metric["loss"].block_until_ready()
|
||||
jax.profiler.start_trace(args.output_dir)
|
||||
if args.profile_steps and global_step == 1 + args.profile_steps:
|
||||
train_metric["loss"].block_until_ready()
|
||||
jax.profiler.stop_trace()
|
||||
|
||||
batch = shard(batch)
|
||||
state, train_metric, train_rngs = p_train_step(
|
||||
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
|
||||
)
|
||||
with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
|
||||
state, train_metric, train_rngs = p_train_step(
|
||||
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
|
||||
)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_step_progress_bar.update(1)
|
||||
@@ -1037,17 +1076,25 @@ def main():
|
||||
and global_step % args.validation_steps == 0
|
||||
and jax.process_index() == 0
|
||||
):
|
||||
_ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
|
||||
_ = log_validation(
|
||||
pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
|
||||
)
|
||||
|
||||
if global_step % args.logging_steps == 0 and jax.process_index() == 0:
|
||||
if args.report_to == "wandb":
|
||||
train_metrics = jax_utils.unreplicate(train_metrics)
|
||||
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
|
||||
wandb.log(
|
||||
{
|
||||
"walltime": time.monotonic() - t00,
|
||||
"train/step": global_step,
|
||||
"train/epoch": epoch,
|
||||
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
|
||||
"train/epoch": global_step / dataset_length,
|
||||
"train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
|
||||
**{f"train/{k}": v for k, v in train_metrics.items()},
|
||||
}
|
||||
)
|
||||
t0, step0 = time.monotonic(), global_step
|
||||
train_metrics = []
|
||||
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
|
||||
controlnet.save_pretrained(
|
||||
f"{args.output_dir}/{global_step}",
|
||||
@@ -1058,10 +1105,16 @@ def main():
|
||||
train_step_progress_bar.close()
|
||||
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
# Final validation & store model.
|
||||
if jax.process_index() == 0:
|
||||
if args.validation_prompt is not None:
|
||||
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
|
||||
if args.profile_validation:
|
||||
jax.profiler.start_trace(args.output_dir)
|
||||
image_logs = log_validation(
|
||||
pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
|
||||
)
|
||||
if args.profile_validation:
|
||||
jax.profiler.stop_trace()
|
||||
else:
|
||||
image_logs = None
|
||||
|
||||
@@ -1084,6 +1137,10 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
if args.profile_memory:
|
||||
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
|
||||
logger.info("Finished training.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
280
examples/custom_diffusion/README.md
Normal file
280
examples/custom_diffusion/README.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# Custom Diffusion training example
|
||||
|
||||
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
|
||||
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install clip-retrieval
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
### Cat example 😺
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
|
||||
```
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="./data/cat"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>"
|
||||
```
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
|
||||
* `num_validation_images`
|
||||
* `validation_steps`
|
||||
|
||||
Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>" \
|
||||
--validation_prompt="<new1> cat sitting in a bucket" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
|
||||
|
||||
If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
|
||||
|
||||
### Training on multiple concepts 🐱🪵
|
||||
|
||||
Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
|
||||
|
||||
To collect the real images run this command for each concept in the json file.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
|
||||
```
|
||||
|
||||
And then we're ready to start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--concepts_list=./concept_list.json \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--num_class_images=200 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>+<new2>"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
|
||||
|
||||
### Training on human faces
|
||||
|
||||
For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
|
||||
|
||||
To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
|
||||
```
|
||||
|
||||
Then start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="path-to-images"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_person/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="person" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> person" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--scale_lr --hflip --noaug \
|
||||
--freeze_model crossattn \
|
||||
--modifier_token "<new1>" \
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \<new1\> in above example) in your prompt.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.unet.load_attn_procs(
|
||||
"path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin"
|
||||
)
|
||||
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
It's possible to directly load these parameters from a Hub repository:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
|
||||
"cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
Here is an example of performing inference with multiple concepts:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
|
||||
"cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
|
||||
|
||||
image = pipe(
|
||||
"the <new1> cat sculpture in the style of a <new2> wooden pot",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("multi-subject.png")
|
||||
```
|
||||
|
||||
Here, `cat` and `wooden pot` refer to the multiple concepts.
|
||||
|
||||
### Inference from a training checkpoint
|
||||
|
||||
You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument.
|
||||
|
||||
TODO.
|
||||
|
||||
## Set grads to none
|
||||
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
|
||||
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
6
examples/custom_diffusion/requirements.txt
Normal file
6
examples/custom_diffusion/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
87
examples/custom_diffusion/retrieve.py
Normal file
87
examples/custom_diffusion/retrieve.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2023 Custom Diffusion authors. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from clip_retrieval.clip_client import ClipClient
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def retrieve(class_prompt, class_data_dir, num_class_images):
|
||||
factor = 1.5
|
||||
num_images = int(factor * num_class_images)
|
||||
client = ClipClient(
|
||||
url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1
|
||||
)
|
||||
|
||||
os.makedirs(f"{class_data_dir}/images", exist_ok=True)
|
||||
if len(list(Path(f"{class_data_dir}/images").iterdir())) >= num_class_images:
|
||||
return
|
||||
|
||||
while True:
|
||||
class_images = client.query(text=class_prompt)
|
||||
if len(class_images) >= factor * num_class_images or num_images > 1e4:
|
||||
break
|
||||
else:
|
||||
num_images = int(factor * num_images)
|
||||
client = ClipClient(
|
||||
url="https://knn.laion.ai/knn-service",
|
||||
indice_name="laion_400m",
|
||||
num_images=num_images,
|
||||
aesthetic_weight=0.1,
|
||||
)
|
||||
|
||||
count = 0
|
||||
total = 0
|
||||
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
|
||||
|
||||
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
|
||||
f"{class_data_dir}/images.txt", "w"
|
||||
) as f3:
|
||||
while total < num_class_images:
|
||||
images = class_images[count]
|
||||
count += 1
|
||||
try:
|
||||
img = requests.get(images["url"])
|
||||
if img.status_code == 200:
|
||||
_ = Image.open(BytesIO(img.content))
|
||||
with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
|
||||
f.write(img.content)
|
||||
f1.write(images["caption"] + "\n")
|
||||
f2.write(images["url"] + "\n")
|
||||
f3.write(f"{class_data_dir}/images/{total}.jpg" + "\n")
|
||||
total += 1
|
||||
pbar.update(1)
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
return
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("", add_help=False)
|
||||
parser.add_argument("--class_prompt", help="text prompt to retrieve images", required=True, type=str)
|
||||
parser.add_argument("--class_data_dir", help="path to save images", required=True, type=str)
|
||||
parser.add_argument("--num_class_images", help="number of images to download", default=200, type=int)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
retrieve(args.class_prompt, args.class_data_dir, args.num_class_images)
|
||||
1289
examples/custom_diffusion/train_custom_diffusion.py
Normal file
1289
examples/custom_diffusion/train_custom_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -45,15 +45,28 @@ write_basic_config()
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
And launch the training using
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
And launch the training using:
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
@@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples`
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -108,7 +121,7 @@ To install `bitandbytes` please refer to this [readme](https://github.com/TimDet
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
```
|
||||
|
||||
@@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
|
||||
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
|
||||
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
|
||||
|
||||
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
|
||||
enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
|
||||
|
||||
With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
|
||||
|
||||
|
||||
### Inference
|
||||
|
||||
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
|
||||
@@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth_flax.py \
|
||||
@@ -405,7 +424,7 @@ python train_dreambooth_flax.py \
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -429,7 +448,7 @@ python train_dreambooth_flax.py \
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -43,22 +44,23 @@ from diffusers import (
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
@@ -83,6 +85,8 @@ inference: true
|
||||
|
||||
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
|
||||
{img_str}
|
||||
|
||||
LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
@@ -219,6 +223,11 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
@@ -547,7 +556,13 @@ def main(args):
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
# TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -691,7 +706,7 @@ def main(args):
|
||||
# => 32 layers
|
||||
|
||||
# Set correct lora layers
|
||||
lora_attn_procs = {}
|
||||
unet_lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
@@ -703,12 +718,33 @@ def main(args):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
unet_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
accelerator.register_for_checkpointing(unet_lora_layers)
|
||||
|
||||
accelerator.register_for_checkpointing(lora_layers)
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
|
||||
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
|
||||
text_encoder_lora_layers = None
|
||||
if args.train_text_encoder:
|
||||
text_lora_attn_procs = {}
|
||||
for name, module in text_encoder.named_modules():
|
||||
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
|
||||
text_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=module.out_features, cross_attention_dim=None
|
||||
)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
temp_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, text_encoder=text_encoder
|
||||
)
|
||||
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
|
||||
text_encoder = temp_pipeline.text_encoder
|
||||
accelerator.register_for_checkpointing(text_encoder_lora_layers)
|
||||
del temp_pipeline
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
@@ -739,8 +775,13 @@ def main(args):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# Optimizer creation
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet_lora_layers.parameters()
|
||||
)
|
||||
optimizer = optimizer_class(
|
||||
lora_layers.parameters(),
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -784,9 +825,14 @@ def main(args):
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet_lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -845,6 +891,8 @@ def main(args):
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
@@ -900,7 +948,11 @@ def main(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = lora_layers.parameters()
|
||||
params_to_clip = (
|
||||
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet_lora_layers.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
@@ -914,7 +966,14 @@ def main(args):
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
# We combine the text encoder and UNet LoRA parameters with a simple
|
||||
# custom logic. `accelerator.save_state()` won't know that. So,
|
||||
# use `LoraLoaderMixin.save_lora_weights()`.
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
@@ -970,7 +1029,12 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unet.to(torch.float32)
|
||||
unet.save_attn_procs(args.output_dir)
|
||||
text_encoder = text_encoder.to(torch.float32)
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
@@ -981,7 +1045,7 @@ def main(args):
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
|
||||
# load attention processors
|
||||
pipeline.unet.load_attn_procs(args.output_dir)
|
||||
pipeline.load_attn_procs(args.output_dir)
|
||||
|
||||
# run inference
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
@@ -1010,6 +1074,7 @@ def main(args):
|
||||
repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
prompt=args.instance_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
|
||||
@@ -113,6 +113,27 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
|
||||
|
||||
***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
|
||||
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
|
||||
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
|
||||
--use_ema \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--resolution=512 --random_flip \
|
||||
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
|
||||
--learning_rate=5e-05 --lr_warmup_steps=0 \
|
||||
--conditioning_dropout_prob=0.05 \
|
||||
--mixed_precision=fp16 \
|
||||
--seed=42
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once training is complete, we can perform inference:
|
||||
|
||||
@@ -51,7 +51,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -735,7 +735,7 @@ def main():
|
||||
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
).to(dtype=weight_dtype)
|
||||
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Distillation for quantization on Textual Inversion models to personalize text2image
|
||||
|
||||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images._By using just 3-5 images new concepts can be taught to Stable Diffusion and the model personalized on your own images_
|
||||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
We have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.
|
||||
|
||||
## Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Prepare Datasets
|
||||
|
||||
One picture which is from the huggingface datasets [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed, and save it to the `./dicoo` directory. The picture is shown below:
|
||||
|
||||
<a href="https://huggingface.co/sd-concepts-library/dicoo2/blob/main/concept_images/1.jpeg">
|
||||
<img src="https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg" width = "300" height="300">
|
||||
</a>
|
||||
|
||||
## Get a FP32 Textual Inversion model
|
||||
|
||||
Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export DATA_DIR="./dicoo"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=3000 \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="dicoo_model"
|
||||
```
|
||||
|
||||
## Do distillation for quantization
|
||||
|
||||
Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Provided a FP32 model, the distillation for quantization approach will take this model itself as the teacher model and transfer the knowledges of the specified layers to the student model, i.e. quantized version of the FP32 model, during the quantization aware training process.
|
||||
|
||||
Once you have the FP32 Textual Inversion model, the following command will take the FP32 Textual Inversion model as input to do distillation for quantization and generate the INT8 Textual Inversion model.
|
||||
|
||||
```bash
|
||||
export FP32_MODEL_NAME="./dicoo_model"
|
||||
export DATA_DIR="./dicoo"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$FP32_MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--use_ema --learnable_property="object" \
|
||||
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=300 \
|
||||
--learning_rate=5.0e-04 --max_grad_norm=3 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="int8_model" \
|
||||
--do_quantization --do_distillation --verify_loading
|
||||
```
|
||||
|
||||
After the distillation for quantization process, the quantized UNet would be 4 times smaller (3279MB -> 827MB).
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a INT8 model with the above command, the inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.
|
||||
|
||||
```bash
|
||||
export INT8_MODEL_NAME="./int8_model"
|
||||
|
||||
python text2images.py \
|
||||
--pretrained_model_name_or_path=$INT8_MODEL_NAME \
|
||||
--caption "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings." \
|
||||
--images_num 4
|
||||
```
|
||||
|
||||
Here is the comparison of images generated by the FP32 model (left) and INT8 model (right) respectively:
|
||||
|
||||
<p float="left">
|
||||
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/FP32.png" width = "300" height = "300" alt="FP32" align=center />
|
||||
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/INT8.png" width = "300" height = "300" alt="INT8" align=center />
|
||||
</p>
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.0
|
||||
ftfy
|
||||
tensorboard
|
||||
modelcards
|
||||
neural-compressor
|
||||
@@ -0,0 +1,112 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from neural_compressor.utils.pytorch import load
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--caption",
|
||||
type=str,
|
||||
default="robotic cat with wings",
|
||||
help="Text used to generate images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--images_num",
|
||||
type=int,
|
||||
default=4,
|
||||
help="How much images to generate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Seed for random process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ci",
|
||||
"--cuda_id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="cuda_id.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
if not len(imgs) == rows * cols:
|
||||
raise ValueError("The specified number of rows and columns are not correct.")
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
grid_w, grid_h = grid.size
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
return grid
|
||||
|
||||
|
||||
def generate_images(
|
||||
pipeline,
|
||||
prompt="robotic cat with wings",
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
num_images_per_prompt=1,
|
||||
seed=42,
|
||||
):
|
||||
generator = torch.Generator(pipeline.device).manual_seed(seed)
|
||||
images = pipeline(
|
||||
prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
_rows = int(math.sqrt(num_images_per_prompt))
|
||||
grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
|
||||
return grid, images
|
||||
|
||||
|
||||
args = parse_args()
|
||||
# Load models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
|
||||
)
|
||||
pipeline.safety_checker = lambda images, clip_input: (images, False)
|
||||
if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
|
||||
unet = load(args.pretrained_model_name_or_path, model=unet)
|
||||
unet.eval()
|
||||
setattr(pipeline, "unet", unet)
|
||||
else:
|
||||
unet = unet.to(torch.device("cuda", args.cuda_id))
|
||||
pipeline = pipeline.to(unet.device)
|
||||
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
|
||||
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
|
||||
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
for idx, image in enumerate(images):
|
||||
image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,7 @@
|
||||
## Multi Token Textual Inversion
|
||||
## [Deprecated] Multi Token Textual Inversion
|
||||
|
||||
**IMPORTART: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the officail textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
|
||||
|
||||
The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issue and PRs as well as @patrickvonplaten.
|
||||
|
||||
We add multi token support to textual inversion. I added
|
||||
|
||||
@@ -23,6 +23,7 @@ import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
from diffusers import DiffusionPipeline, UNet2DConditionModel
|
||||
@@ -104,6 +105,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
--learnable_property object
|
||||
--placeholder_token <cat-toy>
|
||||
--initializer_token a
|
||||
--validation_prompt <cat-toy>
|
||||
--validation_steps 1
|
||||
--save_steps 1
|
||||
--num_vectors 2
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
@@ -221,6 +226,92 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_dreambooth_lora(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_with_text_encoder(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--train_text_encoder
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
|
||||
|
||||
# the names of the keys of the state dict should either start with `unet`
|
||||
# or `text_encoder`.
|
||||
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
|
||||
keys = lora_state_dict.keys()
|
||||
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
|
||||
self.assertTrue(is_correct_naming)
|
||||
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt <new1>
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 1.0e-05
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--modifier_token <new1>
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
|
||||
|
||||
def test_text_to_image(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -111,6 +111,31 @@ image = pipe(prompt="yoda").images[0]
|
||||
image.save("yoda-pokemon.png")
|
||||
```
|
||||
|
||||
#### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$dataset_name \
|
||||
--use_ema \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
|
||||
#### Training with Min-SNR weighting
|
||||
|
||||
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -64,8 +64,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
safety_checker=None,
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -340,11 +340,10 @@ def main():
|
||||
|
||||
return examples
|
||||
|
||||
if jax.process_index() == 0:
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
|
||||
@@ -47,7 +47,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -36,32 +36,33 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
|
||||
accelerate config
|
||||
```
|
||||
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
First, let's login so that we can upload the checkpoint to the Hub during training:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
|
||||
|
||||
<br>
|
||||
Let's first download it locally:
|
||||
|
||||
Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
|
||||
```py
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
And launch the training using
|
||||
local_dir = "./cat"
|
||||
snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
|
||||
```
|
||||
|
||||
This will be our training data.
|
||||
Now we can launch the training using
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
export DATA_DIR="./cat"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
@@ -75,11 +76,24 @@ accelerate launch textual_inversion.py \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--push_to_hub \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
|
||||
A full training run takes ~1 hour on one V100 GPU.
|
||||
|
||||
**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618)
|
||||
only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
|
||||
However, one can also add multiple embedding vectors for the placeholder token
|
||||
to inclease the number of fine-tuneable parameters. This can help the model to learn
|
||||
more complex details. To use multiple embedding vectors, you can should define `--num_vectors`
|
||||
to a number larger than one, *e.g.*:
|
||||
```
|
||||
--num_vectors 5
|
||||
```
|
||||
|
||||
The saved textual inversion vectors will then be larger in size compared to the default case.
|
||||
|
||||
### Inference
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
|
||||
|
||||
@@ -77,11 +77,39 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- textual_inversion
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# Textual inversion text2image fine-tuning - {repo_id}
|
||||
These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
|
||||
{img_str}
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
@@ -94,6 +122,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
@@ -124,11 +153,16 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
return images
|
||||
|
||||
|
||||
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
|
||||
def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
|
||||
learned_embeds = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
|
||||
)
|
||||
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
@@ -144,9 +178,15 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--only_save_embeds",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=True,
|
||||
help="Save only the embeddings for the new concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_vectors",
|
||||
type=int,
|
||||
default=1,
|
||||
help="How many textual inversion vectors shall be used to learn the concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
@@ -581,8 +621,19 @@ def main():
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
placeholder_tokens = [args.placeholder_token]
|
||||
|
||||
if args.num_vectors < 1:
|
||||
raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
|
||||
|
||||
# add dummy tokens for multi-vector
|
||||
additional_tokens = []
|
||||
for i in range(1, args.num_vectors):
|
||||
additional_tokens.append(f"{args.placeholder_token}_{i}")
|
||||
placeholder_tokens += additional_tokens
|
||||
|
||||
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
|
||||
if num_added_tokens != args.num_vectors:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
@@ -595,14 +646,16 @@ def main():
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
with torch.no_grad():
|
||||
for token_id in placeholder_token_ids:
|
||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||
|
||||
# Freeze vae and unet
|
||||
vae.requires_grad_(False)
|
||||
@@ -810,7 +863,9 @@ def main():
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
|
||||
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
@@ -818,11 +873,12 @@ def main():
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
images = []
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
@@ -831,7 +887,9 @@ def main():
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
|
||||
images = log_validation(
|
||||
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
|
||||
)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
@@ -858,9 +916,15 @@ def main():
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
|
||||
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
## Training examples
|
||||
## Training an unconditional diffusion model
|
||||
|
||||
Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
|
||||
|
||||
@@ -76,6 +76,27 @@ A full training run takes 2 hours on 4xV100 GPUs.
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />
|
||||
|
||||
### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
|
||||
--dataset_name="huggan/pokemon" \
|
||||
--resolution=64 --center_crop --random_flip \
|
||||
--output_dir="ddpm-ema-pokemon-64" \
|
||||
--train_batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--use_ema \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_warmup_steps=500 \
|
||||
--mixed_precision="fp16" \
|
||||
--logger="wandb"
|
||||
```
|
||||
|
||||
To be able to use Weights and Biases (`wandb`) as a logger you need to install the library: `pip install wandb`.
|
||||
|
||||
### Using your own data
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
1257
scripts/convert_if.py
Normal file
1257
scripts/convert_if.py
Normal file
File diff suppressed because it is too large
Load Diff
2
setup.py
2
setup.py
@@ -226,7 +226,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.15.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.16.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.15.0"
|
||||
__version__ = "0.16.1"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
@@ -109,12 +109,17 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .loaders import TextualInversionLoaderMixin
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
AudioLDMPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
PaintByExamplePipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
|
||||
@@ -109,6 +109,7 @@ class ConfigMixin:
|
||||
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
||||
# or solve in a more general way.
|
||||
kwargs.pop("kwargs", None)
|
||||
|
||||
if not hasattr(self, "_internal_dict"):
|
||||
internal_dict = kwargs
|
||||
else:
|
||||
@@ -118,6 +119,24 @@ class ConfigMixin:
|
||||
|
||||
self._internal_dict = FrozenDict(internal_dict)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
||||
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
|
||||
|
||||
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
|
||||
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
"""
|
||||
|
||||
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
||||
is_attribute = name in self.__dict__
|
||||
|
||||
if is_in_config and not is_attribute:
|
||||
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
|
||||
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
|
||||
return self._internal_dict[name]
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
@@ -532,6 +551,9 @@ class ConfigMixin:
|
||||
return value
|
||||
|
||||
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
||||
# Don't save "_ignore_files"
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
|
||||
@@ -13,11 +13,17 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .models.attention_processor import LoRAAttnProcessor
|
||||
from .models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
LoRAAttnProcessor,
|
||||
)
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
@@ -46,6 +52,9 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
||||
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
||||
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
|
||||
class AttnProcsLayers(torch.nn.Module):
|
||||
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||
@@ -213,6 +222,7 @@ class UNet2DConditionLoadersMixin:
|
||||
attn_processors = {}
|
||||
|
||||
is_lora = all("lora" in k for k in state_dict.keys())
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
@@ -229,9 +239,38 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
elif is_custom_diffusion:
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
if len(value) == 0:
|
||||
custom_diffusion_grouped_dict[key] = {}
|
||||
else:
|
||||
if "to_out" in key:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
else:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
||||
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
for key, value_dict in custom_diffusion_grouped_dict.items():
|
||||
if len(value_dict) == 0:
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
||||
)
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
||||
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
||||
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=True,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
else:
|
||||
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
||||
raise ValueError(
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
||||
@@ -285,16 +324,31 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
is_custom_diffusion = any(
|
||||
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
||||
for (_, x) in self.attn_processors.items()
|
||||
)
|
||||
if is_custom_diffusion:
|
||||
model_to_save = AttnProcsLayers(
|
||||
{
|
||||
y: x
|
||||
for (y, x) in self.attn_processors.items()
|
||||
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
||||
}
|
||||
)
|
||||
state_dict = model_to_save.state_dict()
|
||||
for name, attn in self.attn_processors.items():
|
||||
if len(attn.state_dict()) == 0:
|
||||
state_dict[name] = {}
|
||||
else:
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
@@ -356,7 +410,7 @@ class TextualInversionLoaderMixin:
|
||||
replacement = token
|
||||
i = 1
|
||||
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
||||
replacement += f"{token}_{i}"
|
||||
replacement += f" {token}_{i}"
|
||||
i += 1
|
||||
|
||||
prompt = prompt.replace(token, replacement)
|
||||
@@ -431,6 +485,7 @@ class TextualInversionLoaderMixin:
|
||||
Example:
|
||||
|
||||
To load a textual inversion embedding vector in `diffusers` format:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
@@ -456,13 +511,14 @@ class TextualInversionLoaderMixin:
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
pipe.load_textual_inversion("./charturnerv2.pt")
|
||||
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
|
||||
|
||||
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
|
||||
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
image.save("character.png")
|
||||
```
|
||||
|
||||
"""
|
||||
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
|
||||
raise ValueError(
|
||||
@@ -792,7 +848,7 @@ class LoraLoaderMixin:
|
||||
"""
|
||||
# Loop over the original attention modules.
|
||||
for name, _ in self.text_encoder.named_modules():
|
||||
if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
|
||||
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
|
||||
# Retrieve the module and its corresponding LoRA processor.
|
||||
module = self.text_encoder.get_submodule(name)
|
||||
# Construct a new function that performs the LoRA merging. We will monkey patch
|
||||
@@ -1051,3 +1107,197 @@ class LoraLoaderMixin:
|
||||
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
|
||||
class FromCkptMixin:
|
||||
"""This helper class allows to directly load .ckpt stable diffusion file_extension
|
||||
into the respective classes."""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.
|
||||
|
||||
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the .ckpt file on the Hub. Should be in the format
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
use_safetensors (`bool`, *optional* ):
|
||||
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
|
||||
default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if
|
||||
`safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults
|
||||
to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
||||
inference. Non-EMA weights are usually better to continue fine-tuning.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted. This is necessary when running stable
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
|
||||
Base. Use 768 for Stable Diffusion v2.
|
||||
prediction_type (`str`, *optional*):
|
||||
The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
|
||||
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
|
||||
num_in_channels (`int`, *optional*, defaults to None):
|
||||
The number of input channels. If `None`, it will be automatically inferred.
|
||||
scheduler_type (`str`, *optional*, defaults to 'pndm'):
|
||||
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
||||
"ddim"]`.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not. Defaults to `True`.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||
`__init__` method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", 512)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is True:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
controlnet = False
|
||||
|
||||
if pipeline_name == "StableDiffusionControlNetPipeline":
|
||||
model_type = "FrozenCLIPEmbedder"
|
||||
controlnet = True
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
model_type = "FrozenCLIPEmbedder"
|
||||
elif pipeline_name == "StableUnCLIPPipeline":
|
||||
model_type == "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "txt2img"
|
||||
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
||||
model_type == "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "img2img"
|
||||
elif pipeline_name == "PaintByExamplePipeline":
|
||||
model_type == "PaintByExample"
|
||||
elif pipeline_name == "LDMTextToImagePipeline":
|
||||
model_type == "LDMTextToImage"
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
|
||||
file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
controlnet=controlnet,
|
||||
from_safetensors=from_safetensors,
|
||||
extract_ema=extract_ema,
|
||||
image_size=image_size,
|
||||
scheduler_type=scheduler_type,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(torch_dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -60,7 +60,6 @@ class AttentionBlock(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
||||
self.num_head_size = num_head_channels
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
|
||||
# define q,k,v as linear layers
|
||||
@@ -72,20 +71,30 @@ class AttentionBlock(nn.Module):
|
||||
self.proj_attn = nn.Linear(channels, channels, bias=True)
|
||||
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
self._use_2_0_attn = True
|
||||
self._attention_op = None
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.num_heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3)
|
||||
if merge_head_and_batch:
|
||||
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
|
||||
head_size = self.num_heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
|
||||
if unmerge_head_and_batch:
|
||||
batch_head_size, seq_len, dim = tensor.shape
|
||||
batch_size = batch_head_size // head_size
|
||||
|
||||
tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
|
||||
else:
|
||||
batch_size, _, seq_len, dim = tensor.shape
|
||||
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
@@ -134,14 +143,24 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
scale = 1 / math.sqrt(self.channels / self.num_heads)
|
||||
|
||||
query_proj = self.reshape_heads_to_batch_dim(query_proj)
|
||||
key_proj = self.reshape_heads_to_batch_dim(key_proj)
|
||||
value_proj = self.reshape_heads_to_batch_dim(value_proj)
|
||||
_use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
|
||||
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
|
||||
|
||||
query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
|
||||
key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
|
||||
value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
|
||||
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
# Memory efficient attention
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
|
||||
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
|
||||
)
|
||||
hidden_states = hidden_states.to(query_proj.dtype)
|
||||
elif use_torch_2_0_attn:
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.to(query_proj.dtype)
|
||||
else:
|
||||
@@ -162,7 +181,7 @@ class AttentionBlock(nn.Module):
|
||||
hidden_states = torch.bmm(attention_probs, value_proj)
|
||||
|
||||
# reshape hidden_states
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
|
||||
@@ -149,6 +149,9 @@ class Attention(nn.Module):
|
||||
is_lora = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
|
||||
)
|
||||
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
||||
)
|
||||
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if self.added_kv_proj_dim is not None:
|
||||
@@ -192,6 +195,17 @@ class Attention(nn.Module):
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
||||
elif is_custom_diffusion:
|
||||
processor = CustomDiffusionXFormersAttnProcessor(
|
||||
train_kv=self.processor.train_kv,
|
||||
train_q_out=self.processor.train_q_out,
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
attention_op=attention_op,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
else:
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
@@ -203,6 +217,16 @@ class Attention(nn.Module):
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
||||
elif is_custom_diffusion:
|
||||
processor = CustomDiffusionAttnProcessor(
|
||||
train_kv=self.processor.train_kv,
|
||||
train_q_out=self.processor.train_q_out,
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
else:
|
||||
processor = AttnProcessor()
|
||||
|
||||
@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionAttnProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_kv=True,
|
||||
train_q_out=True,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.train_kv = train_kv
|
||||
self.train_q_out = train_q_out
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
# `_custom_diffusion` id for easy serialization and loading.
|
||||
if self.train_kv:
|
||||
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
if self.train_q_out:
|
||||
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.to_out_custom_diffusion = nn.ModuleList([])
|
||||
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
||||
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
crossattn = True
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if crossattn:
|
||||
detach = torch.ones_like(key)
|
||||
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
||||
key = detach * key + (1 - detach) * key.detach()
|
||||
value = detach * value + (1 - detach) * value.detach()
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if self.train_q_out:
|
||||
# linear proj
|
||||
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
||||
else:
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnAddedKVProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
residual = hidden_states
|
||||
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_kv=True,
|
||||
train_q_out=False,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
dropout=0.0,
|
||||
attention_op: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.train_kv = train_kv
|
||||
self.train_q_out = train_q_out
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_op = attention_op
|
||||
|
||||
# `_custom_diffusion` id for easy serialization and loading.
|
||||
if self.train_kv:
|
||||
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
if self.train_q_out:
|
||||
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.to_out_custom_diffusion = nn.ModuleList([])
|
||||
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
||||
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
crossattn = True
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if crossattn:
|
||||
detach = torch.ones_like(key)
|
||||
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
||||
key = detach * key + (1 - detach) * key.detach()
|
||||
value = detach * value + (1 - detach) * value.detach()
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if self.train_q_out:
|
||||
# linear proj
|
||||
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
||||
else:
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicedAttnProcessor:
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
]
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, apply_forward_hook, deprecate
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
@@ -123,16 +123,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
@property
|
||||
def block_out_channels(self):
|
||||
deprecate(
|
||||
"block_out_channels",
|
||||
"1.0.0",
|
||||
"Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.block_out_channels
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, Decoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -119,6 +119,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -456,6 +457,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
# check channel order
|
||||
@@ -556,8 +558,20 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample *= conditioning_scale
|
||||
if guess_mode:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0
|
||||
scales *= conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample *= scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample *= conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
@@ -377,3 +377,69 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
|
||||
conditioning = timesteps_emb + class_labels # (N, D)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class TextTimeEmbedding(nn.Module):
|
||||
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(encoder_dim)
|
||||
self.pool = AttentionPooling(num_heads, encoder_dim)
|
||||
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
||||
self.norm2 = nn.LayerNorm(time_embed_dim)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.pool(hidden_states)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionPooling(nn.Module):
|
||||
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
||||
|
||||
def __init__(self, num_heads, embed_dim, dtype=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
||||
self.num_heads = num_heads
|
||||
self.dim_per_head = embed_dim // self.num_heads
|
||||
|
||||
def forward(self, x):
|
||||
bs, length, width = x.size()
|
||||
|
||||
def shape(x):
|
||||
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
||||
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
||||
x = x.transpose(1, 2)
|
||||
return x
|
||||
|
||||
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
||||
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
||||
|
||||
# (bs*n_heads, class_token_length, dim_per_head)
|
||||
q = shape(self.q_proj(class_token))
|
||||
# (bs*n_heads, length+class_token_length, dim_per_head)
|
||||
k = shape(self.k_proj(x))
|
||||
v = shape(self.v_proj(x))
|
||||
|
||||
# (bs*n_heads, class_token_length, length+class_token_length):
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
||||
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
|
||||
# (bs*n_heads, dim_per_head, class_token_length)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
|
||||
# (bs, length+1, width)
|
||||
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
||||
|
||||
return a[:, 0, :] # cls_token
|
||||
|
||||
@@ -110,6 +110,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
.replace("_1", ".1")
|
||||
.replace("_2", ".2")
|
||||
.replace("_3", ".3")
|
||||
.replace("_4", ".4")
|
||||
.replace("_5", ".5")
|
||||
.replace("_6", ".6")
|
||||
.replace("_7", ".7")
|
||||
.replace("_8", ".8")
|
||||
.replace("_9", ".9")
|
||||
)
|
||||
|
||||
flax_key = ".".join(flax_key_tuple_array)
|
||||
|
||||
@@ -15,9 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device
|
||||
@@ -32,6 +33,7 @@ from ..utils import (
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
@@ -59,7 +61,8 @@ if is_safetensors_available():
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).device
|
||||
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
||||
return next(parameters_and_buffers).device
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
@@ -74,7 +77,8 @@ def get_parameter_device(parameter: torch.nn.Module):
|
||||
|
||||
def get_parameter_dtype(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).dtype
|
||||
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
||||
return next(parameters_and_buffers).dtype
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
@@ -156,6 +160,24 @@ class ModelMixin(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
||||
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
||||
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
||||
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
"""
|
||||
|
||||
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
||||
is_attribute = name in self.__dict__
|
||||
|
||||
if is_in_config and not is_attribute:
|
||||
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
||||
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
||||
return self._internal_dict[name]
|
||||
|
||||
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
return super().__getattr__(name)
|
||||
|
||||
@property
|
||||
def is_gradient_checkpointing(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
@@ -190,16 +190,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
@@ -216,16 +216,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -21,9 +21,9 @@ import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, deprecate, logging
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
@@ -97,11 +97,16 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to None):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
time_embedding_type (`str`, *optional*, default to `positional`):
|
||||
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
||||
time_embedding_dim (`int`, *optional*, default to `None`):
|
||||
An optional override for the dimension of the projected time embedding.
|
||||
time_embedding_act_fn (`str`, *optional*, default to `None`):
|
||||
Optional activation function to use on the time embeddings only one time before they as passed to the rest
|
||||
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
||||
@@ -155,12 +160,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_skip_time_act: bool = False,
|
||||
resnet_out_scale_factor: int = 1.0,
|
||||
time_embedding_type: str = "positional",
|
||||
time_embedding_dim: Optional[int] = None,
|
||||
time_embedding_act_fn: Optional[str] = None,
|
||||
timestep_post_act: Optional[str] = None,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
@@ -170,6 +177,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
class_embeddings_concat: bool = False,
|
||||
mid_block_only_cross_attention: Optional[bool] = None,
|
||||
cross_attention_norm: Optional[str] = None,
|
||||
addition_embed_type_num_heads=64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -214,7 +222,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
time_embed_dim = block_out_channels[0] * 2
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
||||
if time_embed_dim % 2 != 0:
|
||||
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
@@ -222,7 +230,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
)
|
||||
timestep_input_dim = time_embed_dim
|
||||
elif time_embedding_type == "positional":
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
||||
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
@@ -248,7 +256,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
@@ -273,6 +281,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
|
||||
|
||||
if time_embedding_act_fn is None:
|
||||
self.time_embed_act = None
|
||||
elif time_embedding_act_fn == "swish":
|
||||
@@ -437,7 +457,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
if act_fn == "swish":
|
||||
self.conv_act = lambda x: F.silu(x)
|
||||
elif act_fn == "mish":
|
||||
self.conv_act = nn.Mish()
|
||||
elif act_fn == "silu":
|
||||
self.conv_act = nn.SiLU()
|
||||
elif act_fn == "gelu":
|
||||
self.conv_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
@@ -447,16 +478,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
||||
)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
@@ -658,7 +679,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
@@ -672,6 +693,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# there might be better ways to encapsulate this.
|
||||
class_labels = class_labels.to(dtype=sample.dtype)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
|
||||
if self.config.class_embeddings_concat:
|
||||
@@ -679,6 +704,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
emb = emb + aug_emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
|
||||
|
||||
@@ -212,6 +212,7 @@ class Decoder(nn.Module):
|
||||
sample = z
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -222,6 +223,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
@@ -229,6 +231,7 @@ class Decoder(nn.Module):
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
|
||||
@@ -44,6 +44,14 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
|
||||
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
|
||||
|
||||
class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
base_model_prefix = "roberta"
|
||||
config_class = RobertaSeriesConfig
|
||||
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.roberta = XLMRobertaModel(config)
|
||||
self.transformation = nn.Linear(config.hidden_size, config.project_dim)
|
||||
self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
|
||||
if self.has_pre_transformation:
|
||||
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
if self.has_pre_transformation:
|
||||
sequence_output2 = outputs["hidden_states"][-2]
|
||||
sequence_output2 = self.pre_LN(sequence_output2)
|
||||
projection_state2 = self.transformation_pre(sequence_output2)
|
||||
|
||||
return TransformationModelOutput(
|
||||
projection_state=projection_state,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
return TransformationModelOutput(
|
||||
projection_state=projection_state2,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
return TransformationModelOutput(
|
||||
projection_state=projection_state,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
@@ -503,7 +511,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -293,7 +293,7 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
|
||||
waveform = self.vocoder(mel_spectrogram)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
waveform = waveform.cpu()
|
||||
waveform = waveform.cpu().float()
|
||||
return waveform
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
|
||||
54
src/diffusers/pipelines/deepfloyd_if/__init__.py
Normal file
54
src/diffusers/pipelines/deepfloyd_if/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
|
||||
from .timesteps import (
|
||||
fast27_timesteps,
|
||||
smart27_timesteps,
|
||||
smart50_timesteps,
|
||||
smart100_timesteps,
|
||||
smart185_timesteps,
|
||||
super27_timesteps,
|
||||
super40_timesteps,
|
||||
super100_timesteps,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IFPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
Output class for Stable Diffusion pipelines.
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
nsfw_detected (`List[bool]`)
|
||||
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content or a watermark. `None` if safety checking could not be performed.
|
||||
watermark_detected (`List[bool]`)
|
||||
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
|
||||
checking could not be performed.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
nsfw_detected: Optional[List[bool]]
|
||||
watermark_detected: Optional[List[bool]]
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_if import IFPipeline
|
||||
from .pipeline_if_img2img import IFImg2ImgPipeline
|
||||
from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline
|
||||
from .pipeline_if_inpainting import IFInpaintingPipeline
|
||||
from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline
|
||||
from .pipeline_if_superresolution import IFSuperResolutionPipeline
|
||||
from .safety_checker import IFSafetyChecker
|
||||
from .watermark import IFWatermarker
|
||||
854
src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
Normal file
854
src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
Normal file
@@ -0,0 +1,854 @@
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import IFPipelineOutput
|
||||
from .safety_checker import IFSafetyChecker
|
||||
from .watermark import IFWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
|
||||
>>> from diffusers.utils import pt_to_pil
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
|
||||
>>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
|
||||
|
||||
>>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
|
||||
|
||||
>>> # save intermediate image
|
||||
>>> pil_image = pt_to_pil(image)
|
||||
>>> pil_image[0].save("./if_stage_I.png")
|
||||
|
||||
>>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
|
||||
... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> super_res_1_pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = super_res_1_pipe(
|
||||
... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt"
|
||||
... ).images
|
||||
|
||||
>>> # save intermediate image
|
||||
>>> pil_image = pt_to_pil(image)
|
||||
>>> pil_image[0].save("./if_stage_I.png")
|
||||
|
||||
>>> safety_modules = {
|
||||
... "feature_extractor": pipe.feature_extractor,
|
||||
... "safety_checker": pipe.safety_checker,
|
||||
... "watermarker": pipe.watermarker,
|
||||
... }
|
||||
>>> super_res_2_pipe = DiffusionPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> super_res_2_pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = super_res_2_pipe(
|
||||
... prompt=prompt,
|
||||
... image=image,
|
||||
... ).images
|
||||
>>> image[0].save("./if_stage_II.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class IFPipeline(DiffusionPipeline):
|
||||
tokenizer: T5Tokenizer
|
||||
text_encoder: T5EncoderModel
|
||||
|
||||
unet: UNet2DConditionModel
|
||||
scheduler: DDPMScheduler
|
||||
|
||||
feature_extractor: Optional[CLIPImageProcessor]
|
||||
safety_checker: Optional[IFSafetyChecker]
|
||||
|
||||
watermarker: Optional[IFWatermarker]
|
||||
|
||||
bad_punct_regex = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
safety_checker: Optional[IFSafetyChecker],
|
||||
feature_extractor: Optional[CLIPImageProcessor],
|
||||
watermarker: Optional[IFWatermarker],
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the IF license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
watermarker=watermarker,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.text_encoder,
|
||||
self.unet,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
|
||||
if self.text_encoder is not None:
|
||||
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
||||
|
||||
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
||||
# previous model. This will cause both models to be present on the device at the same time.
|
||||
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
||||
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
||||
# the GPU.
|
||||
self.text_encoder_offload_hook = hook
|
||||
|
||||
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
||||
|
||||
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
||||
self.unet_offload_hook = hook
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
||||
max_length = 77
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
dtype = self.unet.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, nsfw_detected, watermark_detected = self.safety_checker(
|
||||
images=image,
|
||||
clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
|
||||
)
|
||||
else:
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
|
||||
return intermediate_images
|
||||
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 100,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
||||
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
||||
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
||||
or watermarked content, according to the `safety_checker`.
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
# 2. Define call parameters
|
||||
height = height or self.unet.config.sample_size
|
||||
width = width or self.unet.config.sample_size
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if timesteps is not None:
|
||||
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare intermediate images
|
||||
intermediate_images = self.prepare_intermediate_images(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.unet.config.in_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# HACK: see comment in `enable_model_cpu_offload`
|
||||
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
||||
self.text_encoder_offload_hook.offload()
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
model_input = (
|
||||
torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
|
||||
)
|
||||
model_input = self.scheduler.scale_model_input(model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
intermediate_images = self.scheduler.step(
|
||||
noise_pred, t, intermediate_images, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, intermediate_images)
|
||||
|
||||
image = intermediate_images
|
||||
|
||||
if output_type == "pil":
|
||||
# 8. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# 9. Run safety checker
|
||||
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
# 11. Apply watermark
|
||||
if self.watermarker is not None:
|
||||
self.watermarker.apply_watermark(image, self.unet.config.sample_size)
|
||||
elif output_type == "pt":
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
else:
|
||||
# 8. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# 9. Run safety checker
|
||||
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, nsfw_detected, watermark_detected)
|
||||
|
||||
return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
|
||||
979
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
Normal file
979
src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
Normal file
@@ -0,0 +1,979 @@
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import IFPipelineOutput
|
||||
from .safety_checker import IFSafetyChecker
|
||||
from .watermark import IFWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
|
||||
w, h = images.size
|
||||
|
||||
coef = w / h
|
||||
|
||||
w, h = img_size, img_size
|
||||
|
||||
if coef >= 1:
|
||||
w = int(round(img_size / 8 * coef) * 8)
|
||||
else:
|
||||
h = int(round(img_size / 8 / coef) * 8)
|
||||
|
||||
images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline
|
||||
>>> from diffusers.utils import pt_to_pil
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from io import BytesIO
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
>>> response = requests.get(url)
|
||||
>>> original_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
>>> original_image = original_image.resize((768, 512))
|
||||
|
||||
>>> pipe = IFImg2ImgPipeline.from_pretrained(
|
||||
... "DeepFloyd/IF-I-XL-v1.0",
|
||||
... variant="fp16",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "A fantasy landscape in style minecraft"
|
||||
>>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
|
||||
|
||||
>>> image = pipe(
|
||||
... image=original_image,
|
||||
... prompt_embeds=prompt_embeds,
|
||||
... negative_prompt_embeds=negative_embeds,
|
||||
... output_type="pt",
|
||||
... ).images
|
||||
|
||||
>>> # save intermediate image
|
||||
>>> pil_image = pt_to_pil(image)
|
||||
>>> pil_image[0].save("./if_stage_I.png")
|
||||
|
||||
>>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
|
||||
... "DeepFloyd/IF-II-L-v1.0",
|
||||
... text_encoder=None,
|
||||
... variant="fp16",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> super_res_1_pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = super_res_1_pipe(
|
||||
... image=image,
|
||||
... original_image=original_image,
|
||||
... prompt_embeds=prompt_embeds,
|
||||
... negative_prompt_embeds=negative_embeds,
|
||||
... ).images
|
||||
>>> image[0].save("./if_stage_II.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class IFImg2ImgPipeline(DiffusionPipeline):
|
||||
tokenizer: T5Tokenizer
|
||||
text_encoder: T5EncoderModel
|
||||
|
||||
unet: UNet2DConditionModel
|
||||
scheduler: DDPMScheduler
|
||||
|
||||
feature_extractor: Optional[CLIPImageProcessor]
|
||||
safety_checker: Optional[IFSafetyChecker]
|
||||
|
||||
watermarker: Optional[IFWatermarker]
|
||||
|
||||
bad_punct_regex = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
safety_checker: Optional[IFSafetyChecker],
|
||||
feature_extractor: Optional[CLIPImageProcessor],
|
||||
watermarker: Optional[IFWatermarker],
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the IF license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
watermarker=watermarker,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.text_encoder,
|
||||
self.unet,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
|
||||
if self.text_encoder is not None:
|
||||
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
||||
|
||||
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
||||
# previous model. This will cause both models to be present on the device at the same time.
|
||||
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
||||
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
||||
# the GPU.
|
||||
self.text_encoder_offload_hook = hook
|
||||
|
||||
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
||||
|
||||
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
||||
self.unet_offload_hook = hook
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
||||
max_length = 77
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
dtype = self.unet.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, nsfw_detected, watermark_detected = self.safety_checker(
|
||||
images=image,
|
||||
clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
|
||||
)
|
||||
else:
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
batch_size,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if isinstance(image, list):
|
||||
check_image_type = image[0]
|
||||
else:
|
||||
check_image_type = image
|
||||
|
||||
if (
|
||||
not isinstance(check_image_type, torch.Tensor)
|
||||
and not isinstance(check_image_type, PIL.Image.Image)
|
||||
and not isinstance(check_image_type, np.ndarray)
|
||||
):
|
||||
raise ValueError(
|
||||
"`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
|
||||
f" {type(check_image_type)}"
|
||||
)
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image_batch_size = image.shape[0]
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image_batch_size = 1
|
||||
elif isinstance(image, np.ndarray):
|
||||
image_batch_size = image.shape[0]
|
||||
else:
|
||||
assert False
|
||||
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor:
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
def numpy_to_pt(images):
|
||||
if images.ndim == 3:
|
||||
images = images[..., None]
|
||||
|
||||
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
||||
return images
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
new_image = []
|
||||
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = resize(image_, self.unet.sample_size)
|
||||
image_ = np.array(image_)
|
||||
image_ = image_.astype(np.float32)
|
||||
image_ = image_ / 127.5 - 1
|
||||
new_image.append(image_)
|
||||
|
||||
image = new_image
|
||||
|
||||
image = np.stack(image, axis=0) # to np
|
||||
image = numpy_to_pt(image) # to pt
|
||||
|
||||
elif isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
||||
image = numpy_to_pt(image)
|
||||
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
||||
|
||||
return image
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_intermediate_images(
|
||||
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
|
||||
):
|
||||
_, channels, height, width = image.shape
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
shape = (batch_size, channels, height, width)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
image = image.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
image = self.scheduler.add_noise(image, noise, timestep)
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[
|
||||
PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
|
||||
] = None,
|
||||
strength: float = 0.7,
|
||||
num_inference_steps: int = 80,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 10.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
||||
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
||||
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
||||
or watermarked content, according to the `safety_checker`.
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
self.check_inputs(
|
||||
prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
dtype = prompt_embeds.dtype
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if timesteps is not None:
|
||||
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
|
||||
|
||||
# 5. Prepare intermediate images
|
||||
image = self.preprocess_image(image)
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
noise_timestep = timesteps[0:1]
|
||||
noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
intermediate_images = self.prepare_intermediate_images(
|
||||
image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# HACK: see comment in `enable_model_cpu_offload`
|
||||
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
||||
self.text_encoder_offload_hook.offload()
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
model_input = (
|
||||
torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
|
||||
)
|
||||
model_input = self.scheduler.scale_model_input(model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
intermediate_images = self.scheduler.step(
|
||||
noise_pred, t, intermediate_images, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, intermediate_images)
|
||||
|
||||
image = intermediate_images
|
||||
|
||||
if output_type == "pil":
|
||||
# 8. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# 9. Run safety checker
|
||||
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
# 11. Apply watermark
|
||||
if self.watermarker is not None:
|
||||
self.watermarker.apply_watermark(image, self.unet.config.sample_size)
|
||||
elif output_type == "pt":
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
else:
|
||||
# 8. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# 9. Run safety checker
|
||||
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, nsfw_detected, watermark_detected)
|
||||
|
||||
return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
|
||||
File diff suppressed because it is too large
Load Diff
1098
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
Normal file
1098
src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,947 @@
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import IFPipelineOutput
|
||||
from .safety_checker import IFSafetyChecker
|
||||
from .watermark import IFWatermarker
|
||||
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline
|
||||
>>> from diffusers.utils import pt_to_pil
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"'
|
||||
>>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt)
|
||||
|
||||
>>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
|
||||
|
||||
>>> # save intermediate image
|
||||
>>> pil_image = pt_to_pil(image)
|
||||
>>> pil_image[0].save("./if_stage_I.png")
|
||||
|
||||
>>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained(
|
||||
... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> super_res_1_pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = super_res_1_pipe(
|
||||
... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
|
||||
... ).images
|
||||
>>> image[0].save("./if_stage_II.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class IFSuperResolutionPipeline(DiffusionPipeline):
|
||||
tokenizer: T5Tokenizer
|
||||
text_encoder: T5EncoderModel
|
||||
|
||||
unet: UNet2DConditionModel
|
||||
scheduler: DDPMScheduler
|
||||
image_noising_scheduler: DDPMScheduler
|
||||
|
||||
feature_extractor: Optional[CLIPImageProcessor]
|
||||
safety_checker: Optional[IFSafetyChecker]
|
||||
|
||||
watermarker: Optional[IFWatermarker]
|
||||
|
||||
bad_punct_regex = re.compile(
|
||||
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDPMScheduler,
|
||||
image_noising_scheduler: DDPMScheduler,
|
||||
safety_checker: Optional[IFSafetyChecker],
|
||||
feature_extractor: Optional[CLIPImageProcessor],
|
||||
watermarker: Optional[IFWatermarker],
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the IF license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
if unet.config.in_channels != 6:
|
||||
logger.warn(
|
||||
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
image_noising_scheduler=image_noising_scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
watermarker=watermarker,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.text_encoder,
|
||||
self.unet,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
|
||||
if self.text_encoder is not None:
|
||||
_, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
|
||||
|
||||
# Accelerate will move the next model to the device _before_ calling the offload hook of the
|
||||
# previous model. This will cause both models to be present on the device at the same time.
|
||||
# IF uses T5 for its text encoder which is really large. We can manually call the offload
|
||||
# hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
|
||||
# the GPU.
|
||||
self.text_encoder_offload_hook = hook
|
||||
|
||||
_, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
|
||||
|
||||
# if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
|
||||
self.unet_offload_hook = hook
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warn("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
do_classifier_free_guidance=True,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
if type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
|
||||
max_length = 77
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
dtype = self.unet.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, nsfw_detected, watermark_detected = self.safety_checker(
|
||||
images=image,
|
||||
clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
|
||||
)
|
||||
else:
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
batch_size,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
|
||||
)
|
||||
|
||||
if isinstance(image, list):
|
||||
check_image_type = image[0]
|
||||
else:
|
||||
check_image_type = image
|
||||
|
||||
if (
|
||||
not isinstance(check_image_type, torch.Tensor)
|
||||
and not isinstance(check_image_type, PIL.Image.Image)
|
||||
and not isinstance(check_image_type, np.ndarray)
|
||||
):
|
||||
raise ValueError(
|
||||
"`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
|
||||
f" {type(check_image_type)}"
|
||||
)
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image_batch_size = image.shape[0]
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image_batch_size = 1
|
||||
elif isinstance(image, np.ndarray):
|
||||
image_batch_size = image.shape[0]
|
||||
else:
|
||||
assert False
|
||||
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images
|
||||
def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
|
||||
return intermediate_images
|
||||
|
||||
def preprocess_image(self, image, num_images_per_prompt, device):
|
||||
if not isinstance(image, torch.Tensor) and not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
|
||||
|
||||
image = np.stack(image, axis=0) # to np
|
||||
torch.from_numpy(image.transpose(0, 3, 1, 2))
|
||||
elif isinstance(image[0], np.ndarray):
|
||||
image = np.stack(image, axis=0) # to np
|
||||
if image.ndim == 5:
|
||||
image = image[0]
|
||||
|
||||
image = torch.from_numpy(image.transpose(0, 3, 1, 2))
|
||||
elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
|
||||
dims = image[0].ndim
|
||||
|
||||
if dims == 3:
|
||||
image = torch.stack(image, dim=0)
|
||||
elif dims == 4:
|
||||
image = torch.concat(image, dim=0)
|
||||
else:
|
||||
raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
|
||||
|
||||
image = image.to(device=device, dtype=self.unet.dtype)
|
||||
|
||||
image = image.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
noise_level: int = 250,
|
||||
clean_caption: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
|
||||
The image to be upscaled.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
noise_level (`int`, *optional*, defaults to 250):
|
||||
The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
||||
returning a tuple, the first element is a list with the generated images, and the second element is a list
|
||||
of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
|
||||
or watermarked content, according to the `safety_checker`.
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
batch_size,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
|
||||
height = self.unet.config.sample_size
|
||||
width = self.unet.config.sample_size
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
clean_caption=clean_caption,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if timesteps is not None:
|
||||
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare intermediate images
|
||||
num_channels = self.unet.config.in_channels // 2
|
||||
intermediate_images = self.prepare_intermediate_images(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare upscaled image and noise level
|
||||
image = self.preprocess_image(image, num_images_per_prompt, device)
|
||||
upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)
|
||||
|
||||
noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
|
||||
noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
|
||||
upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_level = torch.cat([noise_level] * 2)
|
||||
|
||||
# HACK: see comment in `enable_model_cpu_offload`
|
||||
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
|
||||
self.text_encoder_offload_hook.offload()
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
model_input = torch.cat([intermediate_images, upscaled], dim=1)
|
||||
|
||||
model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
|
||||
model_input = self.scheduler.scale_model_input(model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
class_labels=noise_level,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
intermediate_images = self.scheduler.step(
|
||||
noise_pred, t, intermediate_images, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, intermediate_images)
|
||||
|
||||
image = intermediate_images
|
||||
|
||||
if output_type == "pil":
|
||||
# 9. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# 10. Run safety checker
|
||||
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
# 12. Apply watermark
|
||||
if self.watermarker is not None:
|
||||
self.watermarker.apply_watermark(image, self.unet.config.sample_size)
|
||||
elif output_type == "pt":
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
else:
|
||||
# 9. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
# 10. Run safety checker
|
||||
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, nsfw_detected, watermark_detected)
|
||||
|
||||
return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
|
||||
59
src/diffusers/pipelines/deepfloyd_if/safety_checker.py
Normal file
59
src/diffusers/pipelines/deepfloyd_if/safety_checker.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel
|
||||
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class IFSafetyChecker(PreTrainedModel):
|
||||
config_class = CLIPConfig
|
||||
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = CLIPVisionModelWithProjection(config.vision_config)
|
||||
|
||||
self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
|
||||
self.w_head = nn.Linear(config.vision_config.projection_dim, 1)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
|
||||
image_embeds = self.vision_model(clip_input)[0]
|
||||
|
||||
nsfw_detected = self.p_head(image_embeds)
|
||||
nsfw_detected = nsfw_detected.flatten()
|
||||
nsfw_detected = nsfw_detected > p_threshold
|
||||
nsfw_detected = nsfw_detected.tolist()
|
||||
|
||||
if any(nsfw_detected):
|
||||
logger.warning(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
|
||||
" Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
for idx, nsfw_detected_ in enumerate(nsfw_detected):
|
||||
if nsfw_detected_:
|
||||
images[idx] = np.zeros(images[idx].shape)
|
||||
|
||||
watermark_detected = self.w_head(image_embeds)
|
||||
watermark_detected = watermark_detected.flatten()
|
||||
watermark_detected = watermark_detected > w_threshold
|
||||
watermark_detected = watermark_detected.tolist()
|
||||
|
||||
if any(watermark_detected):
|
||||
logger.warning(
|
||||
"Potential watermarked content was detected in one or more images. A black image will be returned instead."
|
||||
" Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
for idx, watermark_detected_ in enumerate(watermark_detected):
|
||||
if watermark_detected_:
|
||||
images[idx] = np.zeros(images[idx].shape)
|
||||
|
||||
return images, nsfw_detected, watermark_detected
|
||||
579
src/diffusers/pipelines/deepfloyd_if/timesteps.py
Normal file
579
src/diffusers/pipelines/deepfloyd_if/timesteps.py
Normal file
@@ -0,0 +1,579 @@
|
||||
fast27_timesteps = [
|
||||
999,
|
||||
800,
|
||||
799,
|
||||
600,
|
||||
599,
|
||||
500,
|
||||
400,
|
||||
399,
|
||||
377,
|
||||
355,
|
||||
333,
|
||||
311,
|
||||
288,
|
||||
266,
|
||||
244,
|
||||
222,
|
||||
200,
|
||||
199,
|
||||
177,
|
||||
155,
|
||||
133,
|
||||
111,
|
||||
88,
|
||||
66,
|
||||
44,
|
||||
22,
|
||||
0,
|
||||
]
|
||||
|
||||
smart27_timesteps = [
|
||||
999,
|
||||
976,
|
||||
952,
|
||||
928,
|
||||
905,
|
||||
882,
|
||||
858,
|
||||
857,
|
||||
810,
|
||||
762,
|
||||
715,
|
||||
714,
|
||||
572,
|
||||
429,
|
||||
428,
|
||||
286,
|
||||
285,
|
||||
238,
|
||||
190,
|
||||
143,
|
||||
142,
|
||||
118,
|
||||
95,
|
||||
71,
|
||||
47,
|
||||
24,
|
||||
0,
|
||||
]
|
||||
|
||||
smart50_timesteps = [
|
||||
999,
|
||||
988,
|
||||
977,
|
||||
966,
|
||||
955,
|
||||
944,
|
||||
933,
|
||||
922,
|
||||
911,
|
||||
900,
|
||||
899,
|
||||
879,
|
||||
859,
|
||||
840,
|
||||
820,
|
||||
800,
|
||||
799,
|
||||
766,
|
||||
733,
|
||||
700,
|
||||
699,
|
||||
650,
|
||||
600,
|
||||
599,
|
||||
500,
|
||||
499,
|
||||
400,
|
||||
399,
|
||||
350,
|
||||
300,
|
||||
299,
|
||||
266,
|
||||
233,
|
||||
200,
|
||||
199,
|
||||
179,
|
||||
159,
|
||||
140,
|
||||
120,
|
||||
100,
|
||||
99,
|
||||
88,
|
||||
77,
|
||||
66,
|
||||
55,
|
||||
44,
|
||||
33,
|
||||
22,
|
||||
11,
|
||||
0,
|
||||
]
|
||||
|
||||
smart100_timesteps = [
|
||||
999,
|
||||
995,
|
||||
992,
|
||||
989,
|
||||
985,
|
||||
981,
|
||||
978,
|
||||
975,
|
||||
971,
|
||||
967,
|
||||
964,
|
||||
961,
|
||||
957,
|
||||
956,
|
||||
951,
|
||||
947,
|
||||
942,
|
||||
937,
|
||||
933,
|
||||
928,
|
||||
923,
|
||||
919,
|
||||
914,
|
||||
913,
|
||||
908,
|
||||
903,
|
||||
897,
|
||||
892,
|
||||
887,
|
||||
881,
|
||||
876,
|
||||
871,
|
||||
870,
|
||||
864,
|
||||
858,
|
||||
852,
|
||||
846,
|
||||
840,
|
||||
834,
|
||||
828,
|
||||
827,
|
||||
820,
|
||||
813,
|
||||
806,
|
||||
799,
|
||||
792,
|
||||
785,
|
||||
784,
|
||||
777,
|
||||
770,
|
||||
763,
|
||||
756,
|
||||
749,
|
||||
742,
|
||||
741,
|
||||
733,
|
||||
724,
|
||||
716,
|
||||
707,
|
||||
699,
|
||||
698,
|
||||
688,
|
||||
677,
|
||||
666,
|
||||
656,
|
||||
655,
|
||||
645,
|
||||
634,
|
||||
623,
|
||||
613,
|
||||
612,
|
||||
598,
|
||||
584,
|
||||
570,
|
||||
569,
|
||||
555,
|
||||
541,
|
||||
527,
|
||||
526,
|
||||
505,
|
||||
484,
|
||||
483,
|
||||
462,
|
||||
440,
|
||||
439,
|
||||
396,
|
||||
395,
|
||||
352,
|
||||
351,
|
||||
308,
|
||||
307,
|
||||
264,
|
||||
263,
|
||||
220,
|
||||
219,
|
||||
176,
|
||||
132,
|
||||
88,
|
||||
44,
|
||||
0,
|
||||
]
|
||||
|
||||
smart185_timesteps = [
|
||||
999,
|
||||
997,
|
||||
995,
|
||||
992,
|
||||
990,
|
||||
988,
|
||||
986,
|
||||
984,
|
||||
981,
|
||||
979,
|
||||
977,
|
||||
975,
|
||||
972,
|
||||
970,
|
||||
968,
|
||||
966,
|
||||
964,
|
||||
961,
|
||||
959,
|
||||
957,
|
||||
956,
|
||||
954,
|
||||
951,
|
||||
949,
|
||||
946,
|
||||
944,
|
||||
941,
|
||||
939,
|
||||
936,
|
||||
934,
|
||||
931,
|
||||
929,
|
||||
926,
|
||||
924,
|
||||
921,
|
||||
919,
|
||||
916,
|
||||
914,
|
||||
913,
|
||||
910,
|
||||
907,
|
||||
905,
|
||||
902,
|
||||
899,
|
||||
896,
|
||||
893,
|
||||
891,
|
||||
888,
|
||||
885,
|
||||
882,
|
||||
879,
|
||||
877,
|
||||
874,
|
||||
871,
|
||||
870,
|
||||
867,
|
||||
864,
|
||||
861,
|
||||
858,
|
||||
855,
|
||||
852,
|
||||
849,
|
||||
846,
|
||||
843,
|
||||
840,
|
||||
837,
|
||||
834,
|
||||
831,
|
||||
828,
|
||||
827,
|
||||
824,
|
||||
821,
|
||||
817,
|
||||
814,
|
||||
811,
|
||||
808,
|
||||
804,
|
||||
801,
|
||||
798,
|
||||
795,
|
||||
791,
|
||||
788,
|
||||
785,
|
||||
784,
|
||||
780,
|
||||
777,
|
||||
774,
|
||||
770,
|
||||
766,
|
||||
763,
|
||||
760,
|
||||
756,
|
||||
752,
|
||||
749,
|
||||
746,
|
||||
742,
|
||||
741,
|
||||
737,
|
||||
733,
|
||||
730,
|
||||
726,
|
||||
722,
|
||||
718,
|
||||
714,
|
||||
710,
|
||||
707,
|
||||
703,
|
||||
699,
|
||||
698,
|
||||
694,
|
||||
690,
|
||||
685,
|
||||
681,
|
||||
677,
|
||||
673,
|
||||
669,
|
||||
664,
|
||||
660,
|
||||
656,
|
||||
655,
|
||||
650,
|
||||
646,
|
||||
641,
|
||||
636,
|
||||
632,
|
||||
627,
|
||||
622,
|
||||
618,
|
||||
613,
|
||||
612,
|
||||
607,
|
||||
602,
|
||||
596,
|
||||
591,
|
||||
586,
|
||||
580,
|
||||
575,
|
||||
570,
|
||||
569,
|
||||
563,
|
||||
557,
|
||||
551,
|
||||
545,
|
||||
539,
|
||||
533,
|
||||
527,
|
||||
526,
|
||||
519,
|
||||
512,
|
||||
505,
|
||||
498,
|
||||
491,
|
||||
484,
|
||||
483,
|
||||
474,
|
||||
466,
|
||||
457,
|
||||
449,
|
||||
440,
|
||||
439,
|
||||
428,
|
||||
418,
|
||||
407,
|
||||
396,
|
||||
395,
|
||||
381,
|
||||
366,
|
||||
352,
|
||||
351,
|
||||
330,
|
||||
308,
|
||||
307,
|
||||
286,
|
||||
264,
|
||||
263,
|
||||
242,
|
||||
220,
|
||||
219,
|
||||
176,
|
||||
175,
|
||||
132,
|
||||
131,
|
||||
88,
|
||||
44,
|
||||
0,
|
||||
]
|
||||
|
||||
super27_timesteps = [
|
||||
999,
|
||||
991,
|
||||
982,
|
||||
974,
|
||||
966,
|
||||
958,
|
||||
950,
|
||||
941,
|
||||
933,
|
||||
925,
|
||||
916,
|
||||
908,
|
||||
900,
|
||||
899,
|
||||
874,
|
||||
850,
|
||||
825,
|
||||
800,
|
||||
799,
|
||||
700,
|
||||
600,
|
||||
500,
|
||||
400,
|
||||
300,
|
||||
200,
|
||||
100,
|
||||
0,
|
||||
]
|
||||
|
||||
super40_timesteps = [
|
||||
999,
|
||||
992,
|
||||
985,
|
||||
978,
|
||||
971,
|
||||
964,
|
||||
957,
|
||||
949,
|
||||
942,
|
||||
935,
|
||||
928,
|
||||
921,
|
||||
914,
|
||||
907,
|
||||
900,
|
||||
899,
|
||||
879,
|
||||
859,
|
||||
840,
|
||||
820,
|
||||
800,
|
||||
799,
|
||||
766,
|
||||
733,
|
||||
700,
|
||||
699,
|
||||
650,
|
||||
600,
|
||||
599,
|
||||
500,
|
||||
499,
|
||||
400,
|
||||
399,
|
||||
300,
|
||||
299,
|
||||
200,
|
||||
199,
|
||||
100,
|
||||
99,
|
||||
0,
|
||||
]
|
||||
|
||||
super100_timesteps = [
|
||||
999,
|
||||
996,
|
||||
992,
|
||||
989,
|
||||
985,
|
||||
982,
|
||||
979,
|
||||
975,
|
||||
972,
|
||||
968,
|
||||
965,
|
||||
961,
|
||||
958,
|
||||
955,
|
||||
951,
|
||||
948,
|
||||
944,
|
||||
941,
|
||||
938,
|
||||
934,
|
||||
931,
|
||||
927,
|
||||
924,
|
||||
920,
|
||||
917,
|
||||
914,
|
||||
910,
|
||||
907,
|
||||
903,
|
||||
900,
|
||||
899,
|
||||
891,
|
||||
884,
|
||||
876,
|
||||
869,
|
||||
861,
|
||||
853,
|
||||
846,
|
||||
838,
|
||||
830,
|
||||
823,
|
||||
815,
|
||||
808,
|
||||
800,
|
||||
799,
|
||||
788,
|
||||
777,
|
||||
766,
|
||||
755,
|
||||
744,
|
||||
733,
|
||||
722,
|
||||
711,
|
||||
700,
|
||||
699,
|
||||
688,
|
||||
677,
|
||||
666,
|
||||
655,
|
||||
644,
|
||||
633,
|
||||
622,
|
||||
611,
|
||||
600,
|
||||
599,
|
||||
585,
|
||||
571,
|
||||
557,
|
||||
542,
|
||||
528,
|
||||
514,
|
||||
500,
|
||||
499,
|
||||
485,
|
||||
471,
|
||||
457,
|
||||
442,
|
||||
428,
|
||||
414,
|
||||
400,
|
||||
399,
|
||||
379,
|
||||
359,
|
||||
340,
|
||||
320,
|
||||
300,
|
||||
299,
|
||||
279,
|
||||
259,
|
||||
240,
|
||||
220,
|
||||
200,
|
||||
199,
|
||||
166,
|
||||
133,
|
||||
100,
|
||||
99,
|
||||
66,
|
||||
33,
|
||||
0,
|
||||
]
|
||||
46
src/diffusers/pipelines/deepfloyd_if/watermark.py
Normal file
46
src/diffusers/pipelines/deepfloyd_if/watermark.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import List
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...configuration_utils import ConfigMixin
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import PIL_INTERPOLATION
|
||||
|
||||
|
||||
class IFWatermarker(ModelMixin, ConfigMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
|
||||
self.watermark_image_as_pil = None
|
||||
|
||||
def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
|
||||
# copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287
|
||||
|
||||
h = images[0].height
|
||||
w = images[0].width
|
||||
|
||||
sample_size = sample_size or h
|
||||
|
||||
coef = min(h / sample_size, w / sample_size)
|
||||
img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w)
|
||||
|
||||
S1, S2 = 1024**2, img_w * img_h
|
||||
K = (S2 / S1) ** 0.5
|
||||
wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K)
|
||||
|
||||
if self.watermark_image_as_pil is None:
|
||||
watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy()
|
||||
watermark_image = Image.fromarray(watermark_image, mode="RGBA")
|
||||
self.watermark_image_as_pil = watermark_image
|
||||
|
||||
wm_img = self.watermark_image_as_pil.resize(
|
||||
(wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
|
||||
)
|
||||
|
||||
for pil_img in images:
|
||||
pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1])
|
||||
|
||||
return images
|
||||
@@ -19,6 +19,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -29,7 +30,6 @@ import PIL
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, model_info, snapshot_download
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import diffusers
|
||||
@@ -55,6 +55,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
|
||||
|
||||
@@ -200,24 +201,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
# -00001-of-00002
|
||||
transformers_index_format = "\d{5}-of-\d{5}"
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
|
||||
if variant is not None:
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
|
||||
variant_file_re = re.compile(
|
||||
f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.fp16.json`
|
||||
variant_index_re = re.compile(
|
||||
f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
|
||||
non_variant_file_re = re.compile(
|
||||
f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
@@ -507,7 +508,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
setattr(self, name, module)
|
||||
|
||||
def __setattr__(self, name: str, value: Any):
|
||||
if hasattr(self, name) and hasattr(self.config, name):
|
||||
if name in self.__dict__ and hasattr(self.config, name):
|
||||
# We need to overwrite the config if name exists in config
|
||||
if isinstance(getattr(self.config, name), (tuple, list)):
|
||||
if value is not None and self.config[name][0] is not None:
|
||||
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = dict(self.config)
|
||||
model_index_dict.pop("_class_name")
|
||||
model_index_dict.pop("_diffusers_version")
|
||||
model_index_dict.pop("_class_name", None)
|
||||
model_index_dict.pop("_diffusers_version", None)
|
||||
model_index_dict.pop("_module", None)
|
||||
|
||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
||||
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
return True
|
||||
|
||||
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
model_cls = sub_model.__class__
|
||||
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
save_method_name = None
|
||||
# search for the model's base class in LOADABLE_CLASSES
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
library = importlib.import_module(library_name)
|
||||
if library_name in sys.modules:
|
||||
library = importlib.import_module(library_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
|
||||
)
|
||||
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if save_method_name is not None:
|
||||
break
|
||||
|
||||
if save_method_name is None:
|
||||
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
|
||||
# make sure that unsaveable components are not tried to be loaded afterward
|
||||
self.register_to_config(**{pipeline_component_name: (None, None)})
|
||||
continue
|
||||
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
|
||||
# Call the save method with the argument safe_serialization only if it's supported
|
||||
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
# finally save the config
|
||||
self.save_config(save_directory)
|
||||
|
||||
def to(
|
||||
self,
|
||||
torch_device: Optional[Union[str, torch.device]] = None,
|
||||
@@ -610,7 +623,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
||||
return hasattr(module, "_hf_hook") and not isinstance(
|
||||
module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
|
||||
)
|
||||
|
||||
def module_is_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
||||
@@ -635,26 +650,38 @@ class DiffusionPipeline(ConfigMixin):
|
||||
)
|
||||
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
||||
for name in module_names:
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
for module in modules:
|
||||
is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
|
||||
|
||||
if is_loaded_in_8bit and torch_dtype is not None:
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
|
||||
)
|
||||
|
||||
if is_loaded_in_8bit and torch_device is not None:
|
||||
logger.warning(
|
||||
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
|
||||
)
|
||||
else:
|
||||
module.to(torch_device, torch_dtype)
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(torch_device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(torch_device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
@@ -664,12 +691,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
`torch.device`: The torch device on which the pipeline is located.
|
||||
"""
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
for module in modules:
|
||||
return module.device
|
||||
|
||||
for name in module_names:
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
return module.device
|
||||
return torch.device("cpu")
|
||||
|
||||
@classmethod
|
||||
@@ -875,6 +902,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
# 2. Define which model components should load variants
|
||||
# We retrieve the information by matching whether variant
|
||||
# model checkpoints exist in the subfolders
|
||||
@@ -1046,7 +1076,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||
if return_cached_folder:
|
||||
message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
|
||||
deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
|
||||
deprecate("return_cached_folder", "0.17.0", message)
|
||||
return model, cached_folder
|
||||
|
||||
return model
|
||||
@@ -1192,12 +1222,19 @@ class DiffusionPipeline(ConfigMixin):
|
||||
)
|
||||
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
ignore_filenames = config_dict.pop("_ignore_files", [])
|
||||
|
||||
# retrieve all folder_names that contain relevant files
|
||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
|
||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
||||
|
||||
# remove ignored filenames
|
||||
model_filenames = set(model_filenames) - set(ignore_filenames)
|
||||
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
||||
|
||||
# if the whole pipeline is cached we don't have to ping the Hub
|
||||
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
||||
version.parse(__version__).base_version
|
||||
@@ -1358,16 +1395,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
if images.shape[-1] == 1:
|
||||
# special case for grayscale (single channel) images
|
||||
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
||||
else:
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
return numpy_to_pil(images)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
if not hasattr(self, "_progress_bar_config"):
|
||||
@@ -1438,13 +1466,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
for module_name in module_names:
|
||||
module = getattr(self, module_name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
fn_recursive_set_mem_eff(module)
|
||||
for module in modules:
|
||||
fn_recursive_set_mem_eff(module)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
@@ -1471,10 +1498,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def set_attention_slice(self, slice_size: Optional[int]):
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]
|
||||
|
||||
for module_name in module_names:
|
||||
module = getattr(self, module_name)
|
||||
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size)
|
||||
for module in modules:
|
||||
module.set_attention_slice(slice_size)
|
||||
|
||||
@@ -31,33 +31,30 @@ from transformers import (
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
PriorTransformer,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
PriorTransformer,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
UnCLIPScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
|
||||
from ...utils import is_omegaconf_available, is_safetensors_available, logging
|
||||
from ...utils.import_utils import BACKENDS_MAPPING
|
||||
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from ..paint_by_example import PaintByExampleImageEncoder
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -990,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
clip_stats_path: Optional[str] = None,
|
||||
controlnet: Optional[bool] = None,
|
||||
load_safety_checker: bool = True,
|
||||
) -> StableDiffusionPipeline:
|
||||
pipeline_class: DiffusionPipeline = None,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||
config file.
|
||||
@@ -1018,6 +1016,8 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
model_type (`str`, *optional*, defaults to `None`):
|
||||
The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
|
||||
"FrozenCLIPEmbedder", "PaintByExample"]`.
|
||||
is_img2img (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model should be loaded as an img2img pipeline.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
|
||||
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
||||
@@ -1026,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
Whether the attention computation should always be upcasted. This is necessary when running stable
|
||||
diffusion 2.1.
|
||||
device (`str`, *optional*, defaults to `None`):
|
||||
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
|
||||
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
|
||||
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
The device to use. Pass `None` to determine automatically.
|
||||
from_safetensors (`str`, *optional*, defaults to `False`):
|
||||
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not. Defaults to `True`.
|
||||
pipeline_class (`str`, *optional*, defaults to `None`):
|
||||
The pipeline class to use. Pass `None` to determine automatically.
|
||||
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
"""
|
||||
|
||||
# import pipelines here to avoid circular import error when using from_ckpt method
|
||||
from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
PaintByExamplePipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
|
||||
if pipeline_class is None:
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
|
||||
if prediction_type == "v-prediction":
|
||||
prediction_type = "v_prediction"
|
||||
|
||||
@@ -1193,7 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -1293,7 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -1320,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
|
||||
upcast_attention: Optional[bool] = None,
|
||||
device: str = None,
|
||||
from_safetensors: bool = False,
|
||||
) -> StableDiffusionPipeline:
|
||||
) -> DiffusionPipeline:
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
|
||||
@@ -528,7 +528,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
|
||||
... )
|
||||
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
|
||||
... )
|
||||
>>> params["controlnet"] = controlnet_params
|
||||
|
||||
|
||||
@@ -56,7 +56,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
scheduler: Any,
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
watermarker=None,
|
||||
max_noise_level=max_noise_level,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
@@ -20,7 +20,7 @@ from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -76,7 +76,7 @@ class AttentionStore:
|
||||
|
||||
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
||||
if self.cur_att_layer >= 0 and is_cross:
|
||||
if attn.shape[1] == self.attn_res**2:
|
||||
if attn.shape[1] == np.prod(self.attn_res):
|
||||
self.step_store[place_in_unet].append(attn)
|
||||
|
||||
self.cur_att_layer += 1
|
||||
@@ -98,7 +98,7 @@ class AttentionStore:
|
||||
attention_maps = self.get_average_attention()
|
||||
for location in from_where:
|
||||
for item in attention_maps[location]:
|
||||
cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
|
||||
cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
|
||||
out.append(cross_maps)
|
||||
out = torch.cat(out, dim=0)
|
||||
out = out.sum(0) / out.shape[0]
|
||||
@@ -109,7 +109,7 @@ class AttentionStore:
|
||||
self.step_store = self.get_empty_store()
|
||||
self.attention_store = {}
|
||||
|
||||
def __init__(self, attn_res=16):
|
||||
def __init__(self, attn_res):
|
||||
"""
|
||||
Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
|
||||
process
|
||||
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
max_iter_to_alter: int = 25,
|
||||
thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
|
||||
scale_factor: int = 20,
|
||||
attn_res: int = 16,
|
||||
attn_res: Optional[Tuple[int]] = (16, 16),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
|
||||
scale_factor (`int`, *optional*, default to 20):
|
||||
Scale factor that controls the step size of each Attend and Excite update.
|
||||
attn_res (`int`, *optional*, default to 16):
|
||||
The resolution of most semantic attention map.
|
||||
attn_res (`tuple`, *optional*, default computed from width and height):
|
||||
The 2D resolution of the semantic attention map.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
self.attention_store = AttentionStore(attn_res=attn_res)
|
||||
if attn_res is None:
|
||||
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
|
||||
self.attention_store = AttentionStore(attn_res)
|
||||
self.register_attention_control()
|
||||
|
||||
# default config for step size from original repo
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user