Compare commits

..

15 Commits

Author SHA1 Message Date
sayakpaul
c980c53a8a fix-copies 2024-03-07 17:51:18 +05:30
sayakpaul
a0fcca6dd5 remove device. 2024-03-07 17:46:23 +05:30
sayakpaul
c2cc79c2c0 fix more 2024-03-07 17:42:52 +05:30
sayakpaul
ec64d34d5c checking 2024-03-07 17:41:57 +05:30
sayakpaul
43fbc3aec5 debug 2024-03-07 17:40:24 +05:30
Sayak Paul
196835695e fix: support for loading playground v2.5 single file checkpoint. (#7230)
* fix: support for loading playground v2.5 single file checkpoint.

* remove is_playground_model.

* fix: edm key

* apply Dhruv's comments but errors.

* fix: things.

* delegate model_type inference to a function.

* address Dhruv's comment.

* address rest of the comments.

* fix: kwargs

* fix

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
2024-03-07 15:31:03 +05:30
Sayak Paul
0d4dfbbd0a [Examples] fix: prior preservation setting in DreamBooth LoRA SDXL script. (#7242)
fix: prior preservation setting in DreamBooth LoRA SDXL script.

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2024-03-07 15:19:58 +05:30
Rinne
ada3bb941b fix: remove duplicated code in TemporalBasicTransformerBlock. (#7212)
fix: remove duplicate code in TemporalBasicTransformerBlock.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-07 13:25:22 +05:30
Linoy Tsaban
b5814c5555 add DoRA training feature to sdxl dreambooth lora script (#7235)
* dora in canonical script

* add mention of DoRA to readme
2024-03-07 11:43:37 +05:30
Paakhhi
9940573618 Refactor Prompt2Prompt: Inherit from DiffusionPipeline (#7211)
refactor: inherit from DiffusionPipeline instead of StableDiffusionPipeline
2024-03-06 19:34:40 -10:00
Sayak Paul
59433ca1ae [Core] move out the utilities from pipeline_utils.py (#7234)
move out the utilities from pipeline_utils.py
2024-03-07 08:45:24 +05:30
Nate Landman
534f5d54fa Update train_dreambooth_lora_sdxl_advanced.py (#7227)
adding the `type` argument gives you:

```
TypeError: _StoreTrueAction.__init__() got an unexpected keyword argument 'type'
```
2024-03-06 12:41:48 +01:00
Kashif Rasul
40aa47b998 [Pipeline] Wuerstchen v3 aka Stable Cascade pipeline (#6487)
* initial diffNext v3

* move to v3 folder

* imports

* dry up the unets

* no switch_level

* fix init

* add switch_level to config

* Fixed some things

* Added pooled text embeddings

* Initial work on adding image encoder

* changes from @dome272

* Stuff for the image encoder processing and variable naming in decoder

* fix arg name

* inference fixes

* inference fixes

* default TimestepBlock without conds

* c_skip=0 by default

* fix bfloat16 to cpu

* use config

* undo temp change

* fix gen_c_embeddings args

* change text encoding

* text encoding

* undo print

* undo .gitignore change

* Allow WuerstchenV3PriorPipeline to use the base DDPM & DDIM schedulers

* use WuerstchenV3Unet in both pipelines

* fix imports

* initial failing tests

* cleanup

* use scheduler.timesteps

* some fixes to the tests, still not fully working

* fix tests

* fix prior tests

* add dropout to the model_kwargs

* more tests passing

* update expected_slice

* initial rename

* rename tests

* rename class names

* make fix-copies

* initial docs

* autodocs

* typos

* fix arg docs

* add text_encoder info

* combined pipeline has optional image arg

* fix documentation

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* use self.config

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* c_in -> in_channels

* removed kwargs from unet's forward

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* remove older callback api

* removed kwargs and fixed decoder guidance > 1

* decoder takes embeds

* check and use image_embeds

* fixed all but one decoder test

* fix decoder tests

* update callback api

* fix some more combined tests

* push combined pipeline

* initial docs

* fix doc_string

* update combined api

* no test_callback_inputs test for combined pipeline

* add optional components

* fix ordering of components

* fix combined tests

* update convert script

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fix imports

* move effnet out of denoising loop

* prompt_embeds_pooled only when doing guidance

* Fix repeat shape

* move StableCascadeUnet to models/unets/

* more descriptive names

* converted when numpy()

* StableCascadePriorPipelineOutput docs

* rename StableCascadeUNet

* add slow tests

* fix slow tests

* update

* update

* updated model_path

* add args for weights

* set push_to_hub to false

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Dominic Rampas <d6582533@gmail.com>
Co-authored-by: Pablo Pernias <pablo@pernias.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: 99991 <99991@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-03-06 15:07:25 +05:30
Jinay Jain
1bc0d37ffe [bug] Fix float/int guidance scale not working in StableVideoDiffusionPipeline (#7143)
* [bug] Fix float/int guidance scale not working in `StableVideoDiffusionPipeline`

* Add test to disable CFG on SVD

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-03-06 14:05:07 +05:30
bram-w
eb942b866a SDXL Turbo support and example launch (#6473)
* support and example launch for sdxl turbo

* White space fixes

* Trailing whitespace character

* ruff format

* fix guidance_scale and steps for turbo mode

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Radames Ajna <radamajna@gmail.com>
2024-03-06 11:51:01 +05:30
42 changed files with 4961 additions and 1116 deletions

View File

@@ -318,6 +318,8 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview

View File

@@ -0,0 +1,88 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Cascade
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture, and its main
difference from other models like Stable Diffusion is that it works in a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
1024x1024 image to 24x24 while maintaining crisp reconstructions. The text-conditional model is then trained in this
highly compressed latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable
Diffusion 1.5.
Therefore, this kind of model is well suited for use cases where efficiency is important. Furthermore, all known extensions
like finetuning, LoRA, ControlNet, IP-Adapter, LCM, etc. are possible with this method as well.
The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).
## Model Overview
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
hence the name "Stable Cascade".
Stages A and B are used to compress images, similar to the role of the VAE in Stable Diffusion.
However, with this setup a much higher compression of images can be achieved. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with a resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24 while still being able to accurately decode the
image, which greatly reduces the cost of training and inference. Stage C, in turn, is responsible
for generating the small 24 x 24 latents given a text prompt.
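In `diffusers`, Stage C corresponds to the prior pipeline and Stages A & B to the decoder pipeline added in this PR. Below is a minimal two-stage sketch; the checkpoint repository ids are placeholders for wherever the converted weights are published, and the `image_embeddings` argument name is an assumption based on the prior pipeline's output class.
```python
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prompt = "an astronaut riding a horse, cinematic lighting"

# Stage C: generate the highly compressed 24x24 image latents from the text prompt.
# The repo ids below are placeholders for the converted checkpoints.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.float16
).to("cuda")
prior_output = prior(prompt=prompt)

# Stages A & B: decode the prior's latents back into a full-resolution image.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")
image = decoder(image_embeddings=prior_output.image_embeddings, prompt=prompt).images[0]
image.save("astronaut.png")
```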
## Uses
### Direct Use
The model is intended for research purposes for now. Possible research areas and tasks include
- Research on generative models.
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.
Excluded uses are described below.
### Out-of-Scope Use
The model was not trained to produce factual or true representations of people or events,
and therefore using the model to generate such content is out of scope for its abilities.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
## Limitations and Bias
### Limitations
- Faces and people in general may not be generated properly.
- The autoencoding part of the model is lossy.
## StableCascadeCombinedPipeline
[[autodoc]] StableCascadeCombinedPipeline
- all
- __call__
## StableCascadePriorPipeline
[[autodoc]] StableCascadePriorPipeline
- all
- __call__
## StableCascadePriorPipelineOutput
[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput
## StableCascadeDecoderPipeline
[[autodoc]] StableCascadeDecoderPipeline
- all
- __call__

View File

@@ -666,7 +666,6 @@ def parse_args(input_args=None):
)
parser.add_argument(
"--use_dora",
type=bool,
action="store_true",
default=False,
help=(

View File

@@ -15,18 +15,47 @@
from __future__ import annotations
import abc
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from diffusers.pipelines.stable_diffusion import (
StableDiffusionPipeline,
StableDiffusionPipelineOutput,
from packaging import version
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
from diffusers.configuration_utils import FrozenDict, deprecate
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models.attention import Attention
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
@@ -43,34 +72,486 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
class Prompt2PromptPipeline(StableDiffusionPipeline):
class Prompt2PromptPipeline(
DiffusionPipeline,
TextualInversionLoaderMixin,
LoraLoaderMixin,
IPAdapterMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
Prompt-to-Prompt-Pipeline for text-to-image generation using Stable Diffusion. This model inherits from
[`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for
all the pipelines (such as downloading or saving, running on a particular device, etc.)
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler
([`SchedulerMixin`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
output_hidden_states=True,
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,

View File

@@ -243,3 +243,29 @@ accelerate launch train_dreambooth_lora_sdxl.py \
> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any "variant".
### DoRA training
The script now supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.
> [!NOTE]
> 💡DoRA training is still _experimental_
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
> Specifically, we've noticed 2 differences to take into account in your training:
> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA)
> 2. **DoRA quality is superior to LoRA, especially at lower ranks**: the difference in quality between a DoRA of rank 8 and a LoRA of rank 8 appears to be more significant than when training at ranks of 32 or 64, for example.
> This is also aligned with some of the quantitative analysis shown in the paper.
**Usage**
1. To use DoRA you need to install `peft` from main:
```bash
pip install git+https://github.com/huggingface/peft.git
```
2. Enable DoRA training by adding this flag
```bash
--use_dora
```
**Inference**
The inference is the same as if you train a regular LoRA 🤗
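For example, a minimal inference sketch (the adapter path below is a placeholder for the `--output_dir` you trained with):
```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# "path/to/output_dir" is a placeholder; DoRA adapters load like regular LoRA weights.
pipe.load_lora_weights("path/to/output_dir")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dora_sample.png")
```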

View File

@@ -647,6 +647,15 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--use_dora",
action="store_true",
default=False,
help=(
"Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -868,6 +877,8 @@ def collate_fn(examples, with_prior_preservation=False):
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
original_sizes += [example["original_size"] for example in examples]
crop_top_lefts += [example["crop_top_left"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -1147,6 +1158,7 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -1158,6 +1170,7 @@ def main(args):
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],

View File

@@ -61,6 +61,34 @@ accelerate launch train_diffusion_dpo_sdxl.py \
--push_to_hub
```
## SDXL Turbo training command
```bash
accelerate launch train_diffusion_dpo_sdxl.py \
--pretrained_model_name_or_path=stabilityai/sdxl-turbo \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-turbo-dpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--report_to="wandb" \
--is_turbo --resolution 512 \
--push_to_hub
```
## Acknowledgements
This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.

View File

@@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
images = []
context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()
guidance_scale = 5.0
num_inference_steps = 25
if args.is_turbo:
guidance_scale = 0.0
num_inference_steps = 4
for prompt in VALIDATION_PROMPTS:
with context:
image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
image = pipeline(
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
).images[0]
images.append(image)
tracker_key = "test" if is_final_validation else "validation"
@@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
if is_final_validation:
pipeline.disable_lora()
no_lora_images = [
pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
pipeline(
prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
).images[0]
for prompt in VALIDATION_PROMPTS
]
for tracker in accelerator.trackers:
@@ -423,6 +433,11 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--is_turbo",
action="store_true",
help=("Use if tuning SDXL Turbo instead of SDXL"),
)
parser.add_argument(
"--rank",
type=int,
@@ -444,6 +459,9 @@ def parse_args(input_args=None):
if args.dataset_name is None:
raise ValueError("Must provide a `dataset_name`.")
if args.is_turbo:
assert "turbo" in args.pretrained_model_name_or_path
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -560,6 +578,36 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
def enforce_zero_terminal_snr(scheduler):
# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
# Original implementation https://arxiv.org/pdf/2305.08891.pdf
# Turbo needs zero terminal SNR
# Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
# Convert betas to alphas_bar_sqrt
alphas = 1 - scheduler.betas
alphas_bar = alphas.cumprod(0)
alphas_bar_sqrt = alphas_bar.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so first timestep is back to old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
alphas_bar = alphas_bar_sqrt**2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
alphas_cumprod = torch.cumprod(alphas, dim=0)
scheduler.alphas_cumprod = alphas_cumprod
return
if args.is_turbo:
enforce_zero_terminal_snr(noise_scheduler)
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
@@ -909,6 +957,10 @@ def main(args):
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
).repeat(2)
if args.is_turbo:
# Learn a 4 timestep schedule
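# i.e. with timesteps % 4 in {0, 1, 2, 3}, training is restricted to timesteps {249, 499, 749, 999}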
timesteps_0_to_3 = timesteps % 4
timesteps = 250 * timesteps_0_to_3 + 249
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)

View File

@@ -0,0 +1,215 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
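# Example invocation (the script filename here is a placeholder):
#   python convert_stable_cascade.py --model_path ../StableCascade --use_safetensors --save_org diffusers --push_to_hub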
import argparse
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
CLIPConfig,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
from diffusers import (
DDPMWuerstchenScheduler,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
args = parser.parse_args()
model_path = args.model_path
device = "cpu"
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# Prior
if args.use_safetensors:
orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict = {}
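# Split the fused attention projections (in_proj_weight / in_proj_bias) into the separate
# to_q / to_k / to_v parameters used by diffusers, and rename out_proj to to_out.0;
# all other keys are copied through unchanged.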
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
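# Instantiate the prior UNet on the meta device (no memory allocated for the weights),
# then load the remapped state dict directly into it.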
with accelerate.init_empty_weights():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
load_model_dict_into_meta(prior_model, state_dict)
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
# Decoder
if args.use_safetensors:
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
load_model_dict_into_meta(decoder, state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)

View File

@@ -86,6 +86,7 @@ else:
"MotionAdapter",
"MultiAdapter",
"PriorTransformer",
"StableCascadeUNet",
"T2IAdapter",
"T5FilmDecoder",
"Transformer2DModel",
@@ -259,6 +260,9 @@ else:
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
"StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionControlNetImg2ImgPipeline",
@@ -626,6 +630,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline,

View File

@@ -143,5 +143,4 @@ class FromOriginalVAEMixin:
if torch_dtype is not None:
vae = vae.to(torch_dtype)
vae.eval()
return vae

View File

@@ -133,5 +133,4 @@ class FromOriginalControlNetMixin:
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)
controlnet.eval()
return controlnet

View File

@@ -63,13 +63,20 @@ def build_sub_model_components(
num_in_channels=num_in_channels,
image_size=image_size,
torch_dtype=torch_dtype,
model_type=model_type,
)
return unet_components
if component_name == "vae":
scaling_factor = kwargs.get("scaling_factor", None)
vae_components = create_diffusers_vae_model_from_ldm(
pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
pipeline_class_name,
original_config,
checkpoint,
image_size,
scaling_factor,
torch_dtype,
model_type=model_type,
)
return vae_components
@@ -124,11 +131,12 @@ def build_sub_model_components(
def set_additional_components(
pipeline_class_name,
original_config,
checkpoint=None,
model_type=None,
):
components = {}
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(original_config, model_type=model_type)
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{

View File

@@ -28,6 +28,7 @@ from ..schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EDMDPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
@@ -175,6 +176,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
@@ -305,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
return original_config
def infer_model_type(original_config, model_type=None):
def infer_model_type(original_config, checkpoint=None, model_type=None):
if model_type is not None:
return model_type
@@ -323,7 +325,9 @@ def infer_model_type(original_config, model_type=None):
elif has_network_config:
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
if context_dim == 2048:
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
model_type = "Playground"
elif context_dim == 2048:
model_type = "SDXL"
else:
model_type = "SDXL-Refiner"
@@ -344,13 +348,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
return image_size
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
model_type = infer_model_type(original_config, model_type)
model_type = infer_model_type(original_config, checkpoint, model_type)
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
return image_size
elif model_type in ["SDXL", "SDXL-Refiner"]:
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
image_size = 1024
return image_size
@@ -506,12 +510,14 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
return controlnet_config
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
scaling_factor = original_config["model"]["params"]["scale_factor"]
elif scaling_factor is None:
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
@@ -531,6 +537,8 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": scaling_factor,
}
if latents_mean is not None and latents_std is not None:
config.update({"latents_mean": latents_mean, "latents_std": latents_std})
return config
@@ -1172,6 +1180,7 @@ def create_diffusers_unet_model_from_ldm(
extract_ema=False,
image_size=None,
torch_dtype=None,
model_type=None,
):
from ..models import UNet2DConditionModel
@@ -1190,7 +1199,9 @@ def create_diffusers_unet_model_from_ldm(
else:
num_in_channels = 4
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
image_size = set_image_size(
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
@@ -1223,14 +1234,40 @@ def create_diffusers_unet_model_from_ldm(
 def create_diffusers_vae_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
+    pipeline_class_name,
+    original_config,
+    checkpoint,
+    image_size=None,
+    scaling_factor=None,
+    torch_dtype=None,
+    model_type=None,
 ):
     # import here to avoid circular imports
     from ..models import AutoencoderKL

-    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
+    image_size = set_image_size(
+        pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
+    )
+    model_type = infer_model_type(original_config, checkpoint, model_type)

-    vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
+    if model_type == "Playground":
+        edm_mean = (
+            checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
+        )
+        edm_std = (
+            checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
+        )
+    else:
+        edm_mean = None
+        edm_std = None
+
+    vae_config = create_vae_diffusers_config(
+        original_config,
+        image_size=image_size,
+        scaling_factor=scaling_factor,
+        latents_mean=edm_mean,
+        latents_std=edm_std,
+    )
     diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

     ctx = init_empty_weights if is_accelerate_available() else nullcontext
@@ -1265,7 +1302,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
     local_files_only=False,
     torch_dtype=None,
 ):
-    model_type = infer_model_type(original_config, model_type=model_type)
+    model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)

     if model_type == "FrozenOpenCLIPEmbedder":
         config_name = "stabilityai/stable-diffusion-2"
@@ -1332,7 +1369,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
"text_encoder_2": text_encoder_2,
}
elif model_type == "SDXL":
elif model_type in ["SDXL", "Playground"]:
try:
config_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
@@ -1383,7 +1420,7 @@ def create_scheduler_from_ldm(
     model_type=None,
 ):
     scheduler_config = get_default_scheduler_config()
-    model_type = infer_model_type(original_config, model_type=model_type)
+    model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)

     global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

@@ -1406,7 +1443,8 @@ def create_scheduler_from_ldm(
     if model_type in ["SDXL", "SDXL-Refiner"]:
         scheduler_type = "euler"
+    elif model_type == "Playground":
+        scheduler_type = "edm_dpm_solver_multistep"
     else:
         beta_start = original_config["model"]["params"].get("linear_start", 0.02)
         beta_end = original_config["model"]["params"].get("linear_end", 0.085)
@@ -1438,6 +1476,26 @@ def create_scheduler_from_ldm(
     elif scheduler_type == "ddim":
         scheduler = DDIMScheduler.from_config(scheduler_config)

+    elif scheduler_type == "edm_dpm_solver_multistep":
+        scheduler_config = {
+            "algorithm_type": "dpmsolver++",
+            "dynamic_thresholding_ratio": 0.995,
+            "euler_at_final": False,
+            "final_sigmas_type": "zero",
+            "lower_order_final": True,
+            "num_train_timesteps": 1000,
+            "prediction_type": "epsilon",
+            "rho": 7.0,
+            "sample_max_value": 1.0,
+            "sigma_data": 0.5,
+            "sigma_max": 80.0,
+            "sigma_min": 0.002,
+            "solver_order": 2,
+            "solver_type": "midpoint",
+            "thresholding": False,
+        }
+        scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
+
     else:
         raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

View File

@@ -47,6 +47,7 @@ if is_torch_available():
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"]
@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
I2VGenXLUNet,
Kandinsky3UNet,
MotionAdapter,
StableCascadeUNet,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,

View File

@@ -440,7 +440,6 @@ class TemporalBasicTransformerBlock(nn.Module):
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,

View File

@@ -39,7 +39,6 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
-    is_single_file_checkpoint,
     is_torch_version,
     logging,
 )
@@ -49,8 +48,6 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
 logger = logging.get_logger(__name__)

-SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}

 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
 else:
@@ -500,90 +497,102 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
if is_single_file_checkpoint(pretrained_model_name_or_path):
if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
)
logger.info("Single file checkpoint detected...")
model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
model = model.eval()
return model
else:
cache_dir = kwargs.pop("cache_dir", None)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
from_flax = kwargs.pop("from_flax", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
cache_dir = kwargs.pop("cache_dir", None)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
from_flax = kwargs.pop("from_flax", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
" `device_map=None`. You can install accelerate with `pip install accelerate`."
)
if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
" `device_map=None`. You can install accelerate with `pip install accelerate`."
)
# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
# Check if we can handle device_map and dispatching the weights
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
# load config
config, unused_kwargs, commit_hash = cls.load_config(
config_path,
# load config
config, unused_kwargs, commit_hash = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
user_agent=user_agent,
**kwargs,
)
# load model
model_file = None
if from_flax:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
@@ -591,62 +600,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
token=token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
user_agent=user_agent,
**kwargs,
commit_hash=commit_hash,
)
model = cls.from_config(config, **unused_kwargs)
# load model
model_file = None
if from_flax:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
model = cls.from_config(config, **unused_kwargs)
# Convert the weights
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
# Convert the weights
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
if use_safetensors:
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
except IOError as e:
if not allow_pickle:
raise e
pass
if model_file is None:
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
if use_safetensors:
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@@ -658,48 +626,95 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
user_agent=user_agent,
commit_hash=commit_hash,
)
except IOError as e:
if not allow_pickle:
raise e
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **unused_kwargs)
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_name_or_path,
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
param_device = "cpu"
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
" those weights or else make sure your checkpoint file is correct."
)
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_name_or_path,
)
if len(unexpected_keys) > 0:
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
try:
accelerate.load_checkpoint_and_dispatch(
model,
model_file,
device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
)
except AttributeError as e:
# When using accelerate loading, we do not have the ability to load the state
# dict and rename the weight names manually. Additionally, accelerate skips
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
# (which look like they should be private variables?), so we can't use the standard hooks
# to rename parameters on load. We need to mimic the original weight names so the correct
# attributes are available. After we have loaded the weights, we convert the deprecated
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
# the weights so we don't have to do this again.
if "'Attention' object has no attribute" in str(e):
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
" please also re-upload it or open a PR on the original repository."
)
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
try:
model._temp_convert_self_to_deprecated_attention_blocks()
accelerate.load_checkpoint_and_dispatch(
model,
model_file,
@@ -709,80 +724,52 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
)
except AttributeError as e:
# When using accelerate loading, we do not have the ability to load the state
# dict and rename the weight names manually. Additionally, accelerate skips
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
# (which look like they should be private variables?), so we can't use the standard hooks
# to rename parameters on load. We need to mimic the original weight names so the correct
# attributes are available. After we have loaded the weights, we convert the deprecated
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
# the weights so we don't have to do this again.
model._undo_temp_convert_self_to_deprecated_attention_blocks()
else:
raise e
if "'Attention' object has no attribute" in str(e):
logger.warn(
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
" please also re-upload it or open a PR on the original repository."
)
model._temp_convert_self_to_deprecated_attention_blocks()
accelerate.load_checkpoint_and_dispatch(
model,
model_file,
device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()
else:
raise e
loading_info = {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
else:
model = cls.from_config(config, **unused_kwargs)
loading_info = {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
else:
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
model_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
)
elif torch_dtype is not None:
model = model.to(torch_dtype)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
return model, loading_info
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
elif torch_dtype is not None:
model = model.to(torch_dtype)
return model
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
return model, loading_info
return model
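
The net effect of this change is that `ModelMixin.from_pretrained` no longer sniffs for single-file checkpoints; single-file loading stays explicit through `from_single_file` on the classes that support it (previously listed in `SINGLE_FILE_LOADABLE_CLASSES`). A minimal sketch with a placeholder filename:

```python
from diffusers import AutoencoderKL

# Placeholder filename; `from_pretrained` no longer auto-detects single-file checkpoints.
vae = AutoencoderKL.from_single_file("my-vae-checkpoint.safetensors")
```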
@classmethod
def _load_pretrained_model(

View File

@@ -10,6 +10,7 @@ if is_torch_available():
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .unet_stable_cascade import StableCascadeUNet
from .uvit_2d import UVit2DModel

View File

@@ -0,0 +1,609 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin
# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm
class SDCascadeLayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = super().forward(x)
return x.permute(0, 3, 1, 2)
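
A tiny shape check for the class above, assuming it is in scope (e.g. imported from this module): it applies `nn.LayerNorm` over the channel dimension of an NCHW tensor by permuting to channels-last and back.

```python
import torch

norm = SDCascadeLayerNorm(16, elementwise_affine=False, eps=1e-6)
x = torch.randn(2, 16, 8, 8)   # (batch, channels, height, width)
out = norm(x)                  # normalized over the 16 channels at every spatial position
assert out.shape == x.shape
```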
class SDCascadeTimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=[]):
super().__init__()
linear_cls = nn.Linear
self.mapper = linear_cls(c_timestep, c * 2)
self.conds = conds
for cname in conds:
setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
def forward(self, x, t):
t = t.chunk(len(self.conds) + 1, dim=1)
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
for i, c in enumerate(self.conds):
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc
return x * (1 + a) + b
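
The block applies a FiLM-style per-channel scale and shift. With `conds=("sca", "crp")` it expects the base timestep-ratio embedding plus one embedding per extra condition, concatenated along `dim=1`, which `forward` chunks back apart. A small sketch with arbitrary sizes, assuming the class above is in scope:

```python
import torch

c, c_timestep = 32, 64
block = SDCascadeTimestepBlock(c, c_timestep, conds=["sca", "crp"])
x = torch.randn(1, c, 8, 8)
t = torch.randn(1, c_timestep * 3)  # base embedding + "sca" + "crp", concatenated
out = block(x, t)                   # x * (1 + a) + b, with a and b summed over all conditions
assert out.shape == x.shape
```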
class SDCascadeResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(c + c_skip, c * 4),
nn.GELU(),
GlobalResponseNorm(c * 4),
nn.Dropout(dropout),
nn.Linear(c * 4, c),
)
def forward(self, x, x_skip=None):
x_res = x
x = self.norm(self.depthwise(x))
if x_skip is not None:
x = torch.cat([x, x_skip], dim=1)
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x + x_res
# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
class GlobalResponseNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * stand_div_norm) + self.beta + x
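
Global Response Normalization (from ConvNeXt-V2) works on channels-last tensors: it takes the per-channel spatial L2 norm, divides by its mean over channels, and uses that to rescale the input with a learned residual. A numeric sketch, assuming the class above is in scope:

```python
import torch

grn = GlobalResponseNorm(dim=8)
x = torch.randn(2, 4, 4, 8)                         # channels-last: (batch, height, width, dim)
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)   # per-channel spatial L2 norm -> (2, 1, 1, 8)
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)    # normalize across channels
assert torch.allclose(grn(x), grn.gamma * (x * nx) + grn.beta + x)
```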
class SDCascadeAttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
linear_cls = nn.Linear
self.self_attn = self_attn
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
def forward(self, x, kv):
kv = self.kv_mapper(kv)
norm_x = self.norm(x)
if self.self_attn:
batch_size, channel, _, _ = x.shape
kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
x = x + self.attention(norm_x, encoder_hidden_states=kv)
return x
class UpDownBlock2d(nn.Module):
def __init__(self, in_channels, out_channels, mode, enabled=True):
super().__init__()
if mode not in ["up", "down"]:
raise ValueError(f"{mode} not supported")
interpolation = (
nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True)
if enabled
else nn.Identity()
)
mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@dataclass
class StableCascadeUNetOutput(BaseOutput):
sample: torch.FloatTensor = None
class StableCascadeUNet(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1,
conditioning_dim: int = 2048,
block_out_channels: Tuple[int] = (2048, 2048),
num_attention_heads: Tuple[int] = (32, 32),
down_num_layers_per_block: Tuple[int] = (8, 24),
up_num_layers_per_block: Tuple[int] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1,
1,
),
up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1),
block_types_per_layer: Tuple[Tuple[str]] = (
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
),
clip_text_in_channels: Optional[int] = None,
clip_text_pooled_in_channels=1280,
clip_image_in_channels: Optional[int] = None,
clip_seq=4,
effnet_in_channels: Optional[int] = None,
pixel_mapper_in_channels: Optional[int] = None,
kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True,
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None,
):
"""
Parameters:
in_channels (`int`, defaults to 16):
Number of channels in the input sample.
out_channels (`int`, defaults to 16):
Number of channels in the output sample.
timestep_ratio_embedding_dim (`int`, defaults to 64):
Dimension of the projected time embedding.
patch_size (`int`, defaults to 1):
Patch size to use for pixel unshuffling layer
conditioning_dim (`int`, defaults to 2048):
Dimension of the image and text conditional embedding.
block_out_channels (Tuple[int], defaults to (2048, 2048)):
Tuple of output channels for each block.
num_attention_heads (Tuple[int], defaults to (32, 32)):
Number of attention heads in each attention block. Set to -1 if the block types in a layer do not use attention.
down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
Number of layers in each down block.
up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
Number of layers in each up block.
down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
Number of 1x1 Convolutional layers to repeat in each down block.
up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
Number of 1x1 Convolutional layers to repeat in each up block.
block_types_per_layer (Tuple[Tuple[str]], optional,
defaults to (
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
):
Block types used in each layer of the up/down blocks.
clip_text_in_channels (`int`, *optional*, defaults to `None`):
Number of input channels for CLIP based text conditioning.
clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
Number of input channels for pooled CLIP text embeddings.
clip_image_in_channels (`int`, *optional*):
Number of input channels for CLIP based image conditioning.
clip_seq (`int`, *optional*, defaults to 4):
Number of conditioning tokens each pooled CLIP embedding is mapped to (the mapper projects to `conditioning_dim * clip_seq` features and reshapes them into `clip_seq` tokens).
effnet_in_channels (`int`, *optional*, defaults to `None`):
Number of input channels for effnet conditioning.
pixel_mapper_in_channels (`int`, *optional*, defaults to `None`):
Number of input channels for pixel mapper conditioning.
kernel_size (`int`, *optional*, defaults to 3):
Kernel size to use in the block convolutional layers.
dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)):
Dropout to use per block.
self_attn (Union[bool, Tuple[bool]]):
Tuple of booleans that determine whether to use self attention in a block or not.
timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")):
Timestep conditioning type.
switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`):
Tuple that indicates whether upsampling or downsampling should be applied in a block
"""
super().__init__()
if len(block_out_channels) != len(down_num_layers_per_block):
raise ValueError(
f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
)
elif len(block_out_channels) != len(up_num_layers_per_block):
raise ValueError(
f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
)
elif len(block_out_channels) != len(down_blocks_repeat_mappers):
raise ValueError(
f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
)
elif len(block_out_channels) != len(up_blocks_repeat_mappers):
raise ValueError(
f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
)
elif len(block_out_channels) != len(block_types_per_layer):
raise ValueError(
f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}"
)
if isinstance(dropout, float):
dropout = (dropout,) * len(block_out_channels)
if isinstance(self_attn, bool):
self_attn = (self_attn,) * len(block_out_channels)
# CONDITIONING
if effnet_in_channels is not None:
self.effnet_mapper = nn.Sequential(
nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1),
nn.GELU(),
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
)
if pixel_mapper_in_channels is not None:
self.pixels_mapper = nn.Sequential(
nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1),
nn.GELU(),
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
)
self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq)
if clip_text_in_channels is not None:
self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim)
if clip_image_in_channels is not None:
self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq)
self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6)
self.embedding = nn.Sequential(
nn.PixelUnshuffle(patch_size),
nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1),
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
)
def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True):
if block_type == "SDCascadeResBlock":
return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout)
elif block_type == "SDCascadeAttnBlock":
return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout)
elif block_type == "SDCascadeTimestepBlock":
return SDCascadeTimestepBlock(
in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type
)
else:
raise ValueError(f"Block type {block_type} not supported")
# BLOCKS
# -- down blocks
self.down_blocks = nn.ModuleList()
self.down_downscalers = nn.ModuleList()
self.down_repeat_mappers = nn.ModuleList()
for i in range(len(block_out_channels)):
if i > 0:
self.down_downscalers.append(
nn.Sequential(
SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(
block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1]
)
if switch_level is not None
else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2),
)
)
else:
self.down_downscalers.append(nn.Identity())
down_block = nn.ModuleList()
for _ in range(down_num_layers_per_block[i]):
for block_type in block_types_per_layer[i]:
block = get_block(
block_type,
block_out_channels[i],
num_attention_heads[i],
dropout=dropout[i],
self_attn=self_attn[i],
)
down_block.append(block)
self.down_blocks.append(down_block)
if down_blocks_repeat_mappers is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(down_blocks_repeat_mappers[i] - 1):
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
self.down_repeat_mappers.append(block_repeat_mappers)
# -- up blocks
self.up_blocks = nn.ModuleList()
self.up_upscalers = nn.ModuleList()
self.up_repeat_mappers = nn.ModuleList()
for i in reversed(range(len(block_out_channels))):
if i > 0:
self.up_upscalers.append(
nn.Sequential(
SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6),
UpDownBlock2d(
block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1]
)
if switch_level is not None
else nn.ConvTranspose2d(
block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2
),
)
)
else:
self.up_upscalers.append(nn.Identity())
up_block = nn.ModuleList()
for j in range(up_num_layers_per_block[::-1][i]):
for k, block_type in enumerate(block_types_per_layer[i]):
c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0
block = get_block(
block_type,
block_out_channels[i],
num_attention_heads[i],
c_skip=c_skip,
dropout=dropout[i],
self_attn=self_attn[i],
)
up_block.append(block)
self.up_blocks.append(up_block)
if up_blocks_repeat_mappers is not None:
block_repeat_mappers = nn.ModuleList()
for _ in range(up_blocks_repeat_mappers[::-1][i] - 1):
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
self.up_repeat_mappers.append(block_repeat_mappers)
# OUTPUT
self.clf = nn.Sequential(
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1),
nn.PixelShuffle(patch_size),
)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, value=False):
self.gradient_checkpointing = value
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)
nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None
nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None
if hasattr(self, "effnet_mapper"):
nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
if hasattr(self, "pixels_mapper"):
nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
nn.init.constant_(self.clf[1].weight, 0) # outputs
# blocks
for level_block in self.down_blocks + self.up_blocks:
for block in level_block:
if isinstance(block, SDCascadeResBlock):
block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0]))
elif isinstance(block, SDCascadeTimestepBlock):
nn.init.constant_(block.mapper.weight, 0)
def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000):
r = timestep_ratio * max_positions
half_dim = self.config.timestep_ratio_embedding_dim // 2
emb = math.log(max_positions) / (half_dim - 1)
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
emb = r[:, None] * emb[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), mode="constant")
return emb.to(dtype=r.dtype)
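
This is the standard sinusoidal embedding applied to `timestep_ratio * 10000`. A standalone sketch of the same math for `timestep_ratio_embedding_dim = 64`:

```python
import math

import torch

timestep_ratio = torch.tensor([0.25, 0.75])  # one ratio in [0, 1] per sample
max_positions, dim = 10000, 64
half_dim = dim // 2
freqs = torch.arange(half_dim).float().mul(-math.log(max_positions) / (half_dim - 1)).exp()
emb = (timestep_ratio * max_positions)[:, None] * freqs[None, :]
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
assert emb.shape == (2, 64)
```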
def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None):
if len(clip_txt_pooled.shape) == 2:
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)  # add a sequence dim so the .view() below sees (batch, seq, channels)
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1
)
if clip_txt is not None and clip_img is not None:
clip_txt = self.clip_txt_mapper(clip_txt)
if len(clip_img.shape) == 2:
clip_img = clip_img.unsqueeze(1)
clip_img = self.clip_img_mapper(clip_img).view(
clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1
)
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
else:
clip = clip_txt_pool
return self.clip_norm(clip)
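
The pooled text embedding is projected to `conditioning_dim * clip_seq` features and reshaped into `clip_seq` conditioning tokens per input token. A standalone shape sketch using the default config values (1280-dim pooled embeddings, `conditioning_dim=2048`, `clip_seq=4`):

```python
import torch
import torch.nn as nn

clip_txt_pooled_mapper = nn.Linear(1280, 2048 * 4)
clip_txt_pooled = torch.randn(2, 1, 1280)            # (batch, 1, channels)
tokens = clip_txt_pooled_mapper(clip_txt_pooled).view(2, 1 * 4, -1)
assert tokens.shape == (2, 4, 2048)                  # four conditioning tokens per sample
```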
def _down_encode(self, x, r_embed, clip):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, SDCascadeResBlock):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
elif isinstance(block, SDCascadeAttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, clip, use_reentrant=False
)
elif isinstance(block, SDCascadeTimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, use_reentrant=False
)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
else:
for down_block, downscaler, repmap in block_group:
x = downscaler(x)
for i in range(len(repmap) + 1):
for block in down_block:
if isinstance(block, SDCascadeResBlock):
x = block(x)
elif isinstance(block, SDCascadeAttnBlock):
x = block(x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
if i < len(repmap):
x = repmap[i](x)
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, SDCascadeResBlock):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, skip, use_reentrant=False
)
elif isinstance(block, SDCascadeAttnBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, clip, use_reentrant=False
)
elif isinstance(block, SDCascadeTimestepBlock):
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), x, r_embed, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
else:
for i, (up_block, upscaler, repmap) in enumerate(block_group):
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
if isinstance(block, SDCascadeResBlock):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
x = torch.nn.functional.interpolate(
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
)
x = block(x, skip)
elif isinstance(block, SDCascadeAttnBlock):
x = block(x, clip)
elif isinstance(block, SDCascadeTimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x
def forward(
self,
sample,
timestep_ratio,
clip_text_pooled,
clip_text=None,
clip_img=None,
effnet=None,
pixels=None,
sca=None,
crp=None,
return_dict=True,
):
if pixels is None:
pixels = sample.new_zeros(sample.size(0), 3, 8, 8)
# Process the conditioning embeddings
timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio)
for c in self.config.timestep_conditioning_type:
if c == "sca":
cond = sca
elif c == "crp":
cond = crp
else:
cond = None
t_cond = cond if cond is not None else torch.zeros_like(timestep_ratio)
timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1)
clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img)
# Model Blocks
x = self.embedding(sample)
if hasattr(self, "effnet_mapper") and effnet is not None:
x = x + self.effnet_mapper(
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
)
if hasattr(self, "pixels_mapper"):
x = x + nn.functional.interpolate(
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
)
level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
sample = self.clf(x)
if not return_dict:
return (sample,)
return StableCascadeUNetOutput(sample=sample)
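
A smoke-test-sized instantiation of the model above, on a branch where `StableCascadeUNet` is exported from `diffusers.models`; the config values here are arbitrary small numbers chosen so the forward pass runs quickly, not the Stable Cascade defaults:

```python
import torch

from diffusers.models import StableCascadeUNet

unet = StableCascadeUNet(
    in_channels=16,
    out_channels=16,
    block_out_channels=(32, 64),
    num_attention_heads=(2, 2),
    down_num_layers_per_block=(1, 1),
    up_num_layers_per_block=(1, 1),
    conditioning_dim=32,
    clip_text_pooled_in_channels=32,
    effnet_in_channels=8,
)
sample = torch.randn(1, 16, 24, 24)
timestep_ratio = torch.tensor([0.5])
clip_text_pooled = torch.randn(1, 1, 32)
effnet = torch.randn(1, 8, 24, 24)

out = unet(sample, timestep_ratio, clip_text_pooled, effnet=effnet).sample
assert out.shape == sample.shape
```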

View File

@@ -176,6 +176,11 @@ else:
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_cascade"] = [
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
]
_import_structure["stable_diffusion"].extend(
[
"CLIPImageProjection",
@@ -424,6 +429,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pixart_alpha import PixArtAlphaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_cascade import (
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
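
These exports back the two-stage Stable Cascade flow: the prior produces image embeddings from the prompt, and the decoder turns them into an image. A rough usage sketch; the checkpoint ids, dtypes, and output attribute names follow the Würstchen-style API and should be double-checked against the final pipeline docs:

```python
import torch

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.bfloat16)

prompt = "an astronaut riding a horse"
image_embeddings = prior(prompt=prompt).image_embeddings
image = decoder(image_embeddings=image_embeddings, prompt=prompt).images[0]
```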
from .stable_diffusion import (
CLIPImageProjection,
StableDiffusionDepth2ImgPipeline,

View File

@@ -0,0 +1,508 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from huggingface_hub import (
model_info,
)
from packaging import version
from ..utils import (
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
get_class_from_dynamic_module,
is_peft_available,
is_transformers_available,
logging,
)
from ..utils.torch_utils import is_compiled_module
if is_transformers_available():
import transformers
from transformers import PreTrainedModel
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
CONNECTED_PIPES_KEYS = ["prior"]
logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
"onnxruntime.training": {
"ORTModule": ["save_pretrained", "from_pretrained"],
},
}
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
"""
Checking for safetensors compatibility:
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
files to know which safetensors files are needed.
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
Converting default pytorch serialized filenames to safetensors serialized filenames:
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
pt_filenames = []
sf_filenames = set()
passed_components = passed_components or []
for filename in filenames:
_, extension = os.path.splitext(filename)
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
continue
if extension == ".bin":
pt_filenames.append(os.path.normpath(filename))
elif extension == ".safetensors":
sf_filenames.add(os.path.normpath(filename))
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
path, filename = os.path.split(filename)
filename, extension = os.path.splitext(filename)
if filename.startswith("pytorch_model"):
filename = filename.replace("pytorch_model", "model")
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors"
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
return True
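
Concretely, the check maps every `.bin` filename to the safetensors name it expects and looks for it among the `.safetensors` files; only transformers-style `pytorch_model*` names get renamed to `model*`. A small worked example, assuming the function above is importable:

```python
filenames = [
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.safetensors",
    "text_encoder/pytorch_model.bin",
    "text_encoder/model.safetensors",  # "pytorch_model" -> "model" for transformers weights
]
assert is_safetensors_compatible(filenames)
assert not is_safetensors_compatible(["unet/diffusion_pytorch_model.bin"])  # no safetensors sibling
```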
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# -00001-of-00002
transformers_index_format = r"\d{5}-of-\d{5}"
if variant is not None:
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
# `text_encoder/pytorch_model.bin.index.fp16.json`
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
non_variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
if variant is not None:
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes
# all variant filenames will be used by default
usable_filenames = set(variant_filenames)
def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
for f in non_variant_filenames:
variant_filename = convert_to_variant(f)
if variant_filename not in usable_filenames:
usable_filenames.add(f)
return usable_filenames, variant_filenames
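
In short: variant files are always kept, and a non-variant file is only added when no variant sibling exists for it. A small worked example, assuming the function above is importable:

```python
filenames = {
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.fp16.bin",
    "text_encoder/pytorch_model.bin",
}
usable, variant = variant_compatible_siblings(filenames, variant="fp16")
# variant == {"unet/diffusion_pytorch_model.fp16.bin"}
# usable  == {"unet/diffusion_pytorch_model.fp16.bin", "text_encoder/pytorch_model.bin"}
```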
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
info = model_info(
pretrained_model_name_or_path,
token=token,
revision=None,
)
filenames = {sibling.rfilename for sibling in info.siblings}
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
if set(model_filenames).issubset(set(comp_model_filenames)):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
else:
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
FutureWarning,
)
def _unwrap_model(model):
"""Unwraps a model."""
if is_compiled_module(model):
model = model._orig_mod
if is_peft_available():
from peft import PeftModel
if isinstance(model, PeftModel):
model = model.base_model.model
return model
def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
unwrapped_sub_model = _unwrap_model(sub_model)
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
)
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
return class_obj, class_candidates
def _get_pipeline_class(
class_obj,
config=None,
load_connected_pipeline=False,
custom_pipeline=None,
repo_id=None,
hub_revision=None,
class_name=None,
cache_dir=None,
revision=None,
):
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
elif repo_id is not None:
file_name = f"{custom_pipeline}.py"
custom_pipeline = repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
if repo_id is not None and hub_revision is not None:
# if we load the pipeline code from the Hub
# make sure to overwrite the `revision`
revision = hub_revision
return get_class_from_dynamic_module(
custom_pipeline,
module_file=file_name,
class_name=class_name,
cache_dir=cache_dir,
revision=revision,
)
if class_obj.__name__ != "DiffusionPipeline":
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
class_name = class_name or config["_class_name"]
if not class_name:
raise ValueError(
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
)
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
pipeline_cls = getattr(diffusers_module, class_name)
if load_connected_pipeline:
from .auto_pipeline import _get_connected_pipeline
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
if connected_pipeline_cls is not None:
logger.info(
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
)
else:
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
pipeline_cls = connected_pipeline_cls or pipeline_cls
return pipeline_cls
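A small sketch of the path decomposition used above when `custom_pipeline` points at a local `.py` file (the path below is hypothetical):
```py
from pathlib import Path

custom_pipeline = "my_pipelines/my_custom_pipeline.py"  # hypothetical local path
path = Path(custom_pipeline)
file_name = path.name                   # "my_custom_pipeline.py", passed as `module_file`
module_folder = path.parent.absolute()  # folder handed to get_class_from_dynamic_module
print(file_name, module_folder)
```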
def load_sub_model(
library_name: str,
class_name: str,
importable_classes: List[Any],
pipelines: Any,
is_pipeline_module: bool,
pipeline_class: Any,
torch_dtype: torch.dtype,
provider: Any,
sess_options: Any,
device_map: Optional[Union[Dict[str, torch.device], str]],
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
offload_folder: Optional[Union[str, os.PathLike]],
offload_state_dict: bool,
model_variants: Dict[str, str],
name: str,
from_flax: bool,
variant: str,
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
)
load_method_name = None
# retrieve load method name
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
# if load method name is None, then we have a dummy module -> raise Error
if load_method_name is None:
none_module = class_obj.__module__
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
TRANSFORMERS_DUMMY_MODULES_FOLDER
)
if is_dummy_path and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
load_method = getattr(class_obj, load_method_name)
# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["max_memory"] = max_memory
loading_kwargs["offload_folder"] = offload_folder
loading_kwargs["offload_state_dict"] = offload_state_dict
loading_kwargs["variant"] = model_variants.pop(name, None)
if from_flax:
loading_kwargs["from_flax"] = True
# the following can be deleted once the minimum required `transformers` version
# is higher than 4.27
if (
is_transformers_model
and loading_kwargs["variant"] is not None
and transformers_version < version.parse("4.27.0")
):
raise ImportError(
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
)
elif is_transformers_model and loading_kwargs["variant"] is None:
loading_kwargs.pop("variant")
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
if not (from_flax and is_transformers_model):
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
else:
loading_kwargs["low_cpu_mem_usage"] = False
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
return loaded_sub_model
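A toy sketch (with hypothetical classes, not real diffusers components) of the candidate-matching loop above that picks the load method name out of `importable_classes`:
```py
class Base:
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        return cls()

class Component(Base):
    pass

# Mirrors the importable-classes convention: class name -> [save method name, load method name].
importable_classes = {"Base": ["save_pretrained", "from_pretrained"]}
class_candidates = {"Base": Base}

load_method_name = None
for class_name, class_candidate in class_candidates.items():
    if class_candidate is not None and issubclass(Component, class_candidate):
        load_method_name = importable_classes[class_name][1]

load_method = getattr(Component, load_method_name)
print(load_method_name)  # from_pretrained
```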
def _fetch_class_library_tuple(module):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
# register the config from the original module, not the dynamo compiled one
not_compiled_module = _unwrap_model(module)
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__
# retrieve class_name
class_name = not_compiled_module.__class__.__name__
return (library, class_name)
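A minimal sketch of the module-path splitting above, using an example module path:
```py
# Example module path; for pipeline modules, the pipeline folder is the second-to-last item.
module_path = "diffusers.pipelines.stable_cascade.pipeline_stable_cascade"
items = module_path.split(".")
library = items[0]                                    # "diffusers"
pipeline_dir = items[-2] if len(items) > 2 else None  # "stable_cascade"
print(library, pipeline_dir)
```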

File diff suppressed because it is too large


@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"]
_import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"]
_import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
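A minimal stand-in (not diffusers' `_LazyModule` itself, just the deferred-import idea) showing how an `_import_structure`-style mapping lets the heavy submodule be imported only when a symbol is first requested:
```py
import importlib

# module name -> exported symbols, mirroring the _import_structure dict above
_import_structure = {"math": ["sqrt"]}

def lazy_get(name):
    for module_name, exported in _import_structure.items():
        if name in exported:
            return getattr(importlib.import_module(module_name), name)
    raise AttributeError(name)

print(lazy_get("sqrt")(9.0))  # "math" is only imported at this point
```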


@@ -0,0 +1,465 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
... ).to("cuda")
>>> gen_pipe = StableCascadeDecoderPipeline.from_pretrained(
... "stabilityai/stable-cascade", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)
```
"""
class StableCascadeDecoderPipeline(DiffusionPipeline):
"""
Pipeline for generating images from the Stable Cascade model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer (`CLIPTokenizer`):
The CLIP tokenizer.
text_encoder (`CLIPTextModel`):
The CLIP text encoder.
decoder ([`StableCascadeUNet`]):
The Stable Cascade decoder unet.
vqgan ([`PaellaVQModel`]):
The VQGAN model.
scheduler ([`DDPMWuerstchenScheduler`]):
A scheduler to be used in combination with the `decoder` to denoise the latents generated from the image embeddings.
latent_dim_scale (`float`, *optional*, defaults to 10.67):
Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
width=int(24*10.67)=256 in order to match the training conditions.
"""
unet_name = "decoder"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
_callback_tensor_inputs = [
"latents",
"prompt_embeds_pooled",
"negative_prompt_embeds",
"image_embeddings",
]
def __init__(
self,
decoder: StableCascadeUNet,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
latent_dim_scale: float = 10.67,
) -> None:
super().__init__()
self.register_modules(
decoder=decoder,
tokenizer=tokenizer,
text_encoder=text_encoder,
scheduler=scheduler,
vqgan=vqgan,
)
self.register_to_config(latent_dim_scale=latent_dim_scale)
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
batch_size, channels, height, width = image_embeddings.shape
latents_shape = (
batch_size * num_images_per_prompt,
4,
int(height * self.config.latent_dim_scale),
int(width * self.config.latent_dim_scale),
)
if latents is None:
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
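The latent shape computed above follows directly from `latent_dim_scale`; a small numeric sketch (channel count is an example) matching the class docstring's 24 -> 256 mapping:
```py
batch_size, channels, height, width = 1, 16, 24, 24  # example image-embedding shape
num_images_per_prompt = 2
latent_dim_scale = 10.67  # default from the class config

latents_shape = (
    batch_size * num_images_per_prompt,
    4,
    int(height * latent_dim_scale),
    int(width * latent_dim_scale),
)
print(latents_shape)  # (2, 4, 256, 256)
```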
def encode_prompt(
self,
device,
batch_size,
num_images_per_prompt,
do_classifier_free_guidance,
prompt=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
):
if prompt_embeds is None:
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
)
prompt_embeds = text_encoder_output.hidden_states[-1]
if prompt_embeds_pooled is None:
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
if negative_prompt_embeds is None and do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds_text_encoder_output = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=uncond_input.attention_mask.to(device),
output_hidden_states=True,
)
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds_pooled.shape[1]
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
dtype=self.text_encoder.dtype, device=device
)
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# done duplicates
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
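A small sketch of the batch duplication used above: `repeat_interleave` along dim 0 repeats each prompt's embeddings `num_images_per_prompt` times (the shapes below are examples):
```py
import torch

prompt_embeds = torch.randn(2, 77, 1280)  # example (batch, seq_len, hidden_dim)
num_images_per_prompt = 3
repeated = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
print(repeated.shape)  # torch.Size([6, 77, 1280])
```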
def check_inputs(
self,
prompt,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
prompt: Union[str, List[str]] = None,
num_inference_steps: int = 10,
guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
Args:
image_embeddings (`torch.FloatTensor` or `List[torch.FloatTensor]`):
Image embeddings either extracted from an image or generated by a prior model.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_inference_steps (`int`, *optional*, defaults to 10):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of the [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Define commonly used variables
device = self._execution_device
dtype = self.decoder.dtype
self._guidance_scale = guidance_scale
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
if isinstance(image_embeddings, list):
image_embeddings = torch.cat(image_embeddings, dim=0)
batch_size = image_embeddings.shape[0]
# 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None:
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
prompt=prompt,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
prompt_embeds_pooled = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if self.do_classifier_free_guidance
else image_embeddings
)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents
latents = self.prepare_latents(
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
)
# 6. Run denoising loop
self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
timestep_ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise latents
predicted_latents = self.decoder(
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
clip_text_pooled=prompt_embeds_pooled,
effnet=effnet,
return_dict=False,
)[0]
# 8. Check for classifier free guidance and apply it
if self.do_classifier_free_guidance:
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
# 9. Renoise latents to next timestep
latents = self.scheduler.step(
model_output=predicted_latents,
timestep=timestep_ratio,
sample=latents,
generator=generator,
).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
)
if not output_type == "latent":
# 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)
if output_type == "np":
images = images.permute(0, 2, 3, 1).cpu().float().numpy()  # float() because bfloat16 -> numpy conversion doesn't work
elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().float().numpy()  # float() because bfloat16 -> numpy conversion doesn't work
images = self.numpy_to_pil(images)
else:
images = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return images
return ImagePipelineOutput(images)
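For reference, the classifier-free-guidance step in the denoising loop above combines the two predictions with `torch.lerp`, i.e. `uncond + w * (text - uncond)`; a tiny numeric sketch:
```py
import torch

predicted_latents_text = torch.tensor([1.0, 2.0])
predicted_latents_uncond = torch.tensor([0.0, 0.0])
guidance_scale = 2.0

guided = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale)
print(guided)  # tensor([2., 4.])
```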


@@ -0,0 +1,294 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableCascadeCombinedPipeline
>>> pipe = StableCascadeCombinedPipeline.from_pretrained(
...     "stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
```
"""
class StableCascadeCombinedPipeline(DiffusionPipeline):
"""
Combined Pipeline for text-to-image generation using Stable Cascade.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer (`CLIPTokenizer`):
The decoder tokenizer to be used for text inputs.
text_encoder (`CLIPTextModel`):
The decoder text encoder to be used for text inputs.
decoder (`StableCascadeUNet`):
The decoder model to be used for decoder image generation pipeline.
scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for decoder image generation pipeline.
vqgan (`PaellaVQModel`):
The VQGAN model to be used for decoder image generation pipeline.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
prior_prior (`StableCascadeUNet`):
The prior model to be used for prior pipeline.
prior_scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for prior pipeline.
"""
_load_connected_pipes = True
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
decoder: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
prior_prior: StableCascadeUNet,
prior_text_encoder: CLIPTextModel,
prior_tokenizer: CLIPTokenizer,
prior_scheduler: DDPMWuerstchenScheduler,
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None,
):
super().__init__()
self.register_modules(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
prior_text_encoder=prior_text_encoder,
prior_tokenizer=prior_tokenizer,
prior_prior=prior_prior,
prior_scheduler=prior_scheduler,
prior_feature_extractor=prior_feature_extractor,
prior_image_encoder=prior_image_encoder,
)
self.prior_pipe = StableCascadePriorPipeline(
prior=prior_prior,
text_encoder=prior_text_encoder,
tokenizer=prior_tokenizer,
scheduler=prior_scheduler,
image_encoder=prior_image_encoder,
feature_extractor=prior_feature_extractor,
)
self.decoder_pipe = StableCascadeDecoderPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
)
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
def set_progress_bar_config(self, **kwargs):
self.prior_pipe.set_progress_bar_config(**kwargs)
self.decoder_pipe.set_progress_bar_config(**kwargs)
@torch.no_grad()
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
height: int = 512,
width: int = 512,
prior_num_inference_steps: int = 60,
prior_timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation for the prior and decoder.
images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*):
The images to guide the image generation for the prior.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`prior_guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are
closely linked to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 60):
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized
`prior_timesteps`
num_inference_steps (`int`, *optional*, defaults to 12):
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. For more specific timestep spacing, you can pass customized
`timesteps`
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`decoder_guidance_scale` is defined as `w` of equation 2 of the [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`decoder_guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are
closely linked to the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
int, callback_kwargs: Dict)`.
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeline class.
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
images=images,
height=height,
width=width,
num_inference_steps=prior_num_inference_steps,
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
output_type="pt",
return_dict=True,
callback_on_step_end=prior_callback_on_step_end,
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
)
image_embeddings = prior_outputs.image_embeddings
prompt_embeds = prior_outputs.get("prompt_embeds", None)
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
outputs = self.decoder_pipe(
image_embeddings=image_embeddings,
prompt=prompt if prompt_embeds is None else None,
num_inference_steps=num_inference_steps,
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
generator=generator,
output_type=output_type,
return_dict=return_dict,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
return outputs


@@ -0,0 +1,614 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import ceil
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
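The default Stage C schedule above is 30 timestep ratios: 20 points from 1.0 down to 2/3, then the remaining 10 points from 2/3 down to 0.0 (the duplicated 2/3 is dropped). A quick check:
```py
import numpy as np

default_stage_c_timesteps = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
print(len(default_stage_c_timesteps))                               # 30
print(default_stage_c_timesteps[0], default_stage_c_timesteps[-1])  # 1.0 0.0
```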
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableCascadePriorPipeline
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
... ).to("cuda")
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> prior_output = prior_pipe(prompt)
```
"""
@dataclass
class StableCascadePriorPipelineOutput(BaseOutput):
"""
Output class for StableCascadePriorPipeline.
Args:
image_embeddings (`torch.FloatTensor` or `np.ndarray`):
Prior image embeddings for text prompt
prompt_embeds (`torch.FloatTensor`):
Text embeddings for the prompt.
negative_prompt_embeds (`torch.FloatTensor`):
Text embeddings for the negative prompt.
"""
image_embeddings: Union[torch.FloatTensor, np.ndarray]
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
class StableCascadePriorPipeline(DiffusionPipeline):
"""
Pipeline for generating image prior for Stable Cascade.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
prior ([`StableCascadeUNet`]):
The Stable Cascade prior to approximate the image embedding from the text and/or image embedding.
text_encoder ([`CLIPTextModelWithProjection`]):
Frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`DDPMWuerstchenScheduler`]):
A scheduler to be used in combination with `prior` to generate image embedding.
resolution_multiple (`float`, *optional*, defaults to 42.67):
Ratio between the requested pixel resolution and the prior latent resolution; the latent grid is
`ceil(height / resolution_multiple)` by `ceil(width / resolution_multiple)`.
"""
unet_name = "prior"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "image_encoder->text_encoder->prior"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
prior: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
resolution_multiple: float = 42.67,
feature_extractor: Optional[CLIPImageProcessor] = None,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
) -> None:
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
prior=prior,
scheduler=scheduler,
)
self.register_to_config(resolution_multiple=resolution_multiple)
def prepare_latents(
self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler
):
latent_shape = (
num_images_per_prompt * batch_size,
self.prior.config.in_channels,
ceil(height / self.config.resolution_multiple),
ceil(width / self.config.resolution_multiple),
)
if latents is None:
latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != latent_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
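A numeric sketch of the prior latent shape above: the spatial size is `ceil(pixels / resolution_multiple)`, so the default 1024x1024 resolution maps to a 24x24 latent grid (the channel count below is an example only; the real value comes from `self.prior.config.in_channels`):
```py
from math import ceil

height = width = 1024
resolution_multiple = 42.67  # default from the class config
in_channels = 16             # example value only

latent_shape = (1, in_channels, ceil(height / resolution_multiple), ceil(width / resolution_multiple))
print(latent_shape)  # (1, 16, 24, 24)
```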
def encode_prompt(
self,
device,
batch_size,
num_images_per_prompt,
do_classifier_free_guidance,
prompt=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
):
if prompt_embeds is None:
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
)
prompt_embeds = text_encoder_output.hidden_states[-1]
if prompt_embeds_pooled is None:
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
if negative_prompt_embeds is None and do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds_text_encoder_output = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=uncond_input.attention_mask.to(device),
output_hidden_states=True,
)
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds_pooled.shape[1]
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
dtype=self.text_encoder.dtype, device=device
)
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# done duplicates
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt):
image_embeds = []
for image in images:
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embed = self.image_encoder(image).image_embeds.unsqueeze(1)
image_embeds.append(image_embed)
image_embeds = torch.cat(image_embeds, dim=1)
image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
negative_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, negative_image_embeds
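A small sketch of the image-embedding batching above (the embedding dimension is an example): each image contributes one token along dim 1, the result is repeated per prompt, and the unconditional embedding is zeros of the same shape:
```py
import torch

per_image_embeds = [torch.randn(1, 1, 768) for _ in range(2)]  # 768 is an example CLIP dim
image_embeds = torch.cat(per_image_embeds, dim=1)              # (1, 2, 768)
image_embeds = image_embeds.repeat(2 * 1, 1, 1)                # batch_size=2, num_images_per_prompt=1
negative_image_embeds = torch.zeros_like(image_embeds)
print(image_embeds.shape, negative_image_embeds.shape)  # torch.Size([2, 2, 768]) twice
```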
def check_inputs(
self,
prompt,
images=None,
image_embeds=None,
negative_prompt=None,
prompt_embeds=None,
prompt_embeds_pooled=None,
negative_prompt_embeds=None,
negative_prompt_embeds_pooled=None,
callback_on_step_end_tensor_inputs=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
raise ValueError(
"`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed"
f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !="
f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}."
)
if image_embeds is not None and images is not None:
raise ValueError(
f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to"
" only forward one of the two."
)
if images:
for i, image in enumerate(images):
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
raise TypeError(
f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image, but got"
f"{type(image)} for image number {i}."
)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
def get_t_condioning(self, t, alphas_cumprod):
s = torch.tensor([0.003])
clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
var = alphas_cumprod[t]
var = var.clamp(*clamp_range)
s, min_var = s.to(var.device), min_var.to(var.device)
ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
return ratio
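The helper above inverts the cosine noise schedule `alpha_bar(r) = cos((r + s)/(1 + s) * pi/2)^2 / cos(s/(1 + s) * pi/2)^2` (with `s = 0.003`) to recover the continuous timestep ratio `r` in `[0, 1]` from a discrete `alphas_cumprod` value; a standalone sketch of the same computation:
```py
import torch

s = torch.tensor([0.003])
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2

def ratio_from_alpha_bar(alpha_bar):
    alpha_bar = alpha_bar.clamp(0, 1)
    return (((alpha_bar * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s

# alpha_bar near 1 -> ratio near 0 (little noise); alpha_bar near 0 -> ratio near 1 (almost pure noise).
print(ratio_from_alpha_bar(torch.tensor(0.999)), ratio_from_alpha_bar(torch.tensor(1e-4)))
```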
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 20,
timesteps: List[float] = None,
guidance_scale: float = 4.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pt",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of the [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely
linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
argument.
image_embeds (`torch.FloatTensor`, *optional*):
Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting.
If not provided, image embeddings will be generated from the `images` input argument if available.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`StableCascadePriorPipelineOutput`] or `tuple`: [`StableCascadePriorPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated image embeddings.
"""
# 0. Define commonly used variables
device = self._execution_device
dtype = next(self.prior.parameters()).dtype
self._guidance_scale = guidance_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
images=images,
image_embeds=image_embeds,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Encode caption + images
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = self.encode_prompt(
prompt=prompt,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
if images is not None:
image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image(
images=images,
device=device,
dtype=dtype,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
)
elif image_embeds is not None:
image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled)
else:
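# No image conditioning was provided: fall back to zero tensors of shape
# (batch_size * num_images_per_prompt, 1, clip_image_in_channels) as a null CLIP image embedding
# for both the conditional and unconditional branches.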
image_embeds_pooled = torch.zeros(
batch_size * num_images_per_prompt,
1,
self.prior.config.clip_image_in_channels,
device=device,
dtype=dtype,
)
uncond_image_embeds_pooled = torch.zeros(
batch_size * num_images_per_prompt,
1,
self.prior.config.clip_image_in_channels,
device=device,
dtype=dtype,
)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0)
else:
image_embeds = image_embeds_pooled
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_encoder_hidden_states = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
)
text_encoder_pooled = (
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
if negative_prompt_embeds is not None
else prompt_embeds_pooled
)
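# Batch layout: conditional embeddings first, unconditional embeddings second. This ordering has to
# match the `.chunk(2)` split applied to the prior's prediction in the guidance step below.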
# 4. Prepare and set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents
latents = self.prepare_latents(
batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
)
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
timesteps = timesteps[:-1]
else:
if self.scheduler.config.clip_sample:
self.scheduler.config.clip_sample = False  # disable sample clipping
logger.warning("Setting `clip_sample` to False")
# 6. Run denoising loop
if hasattr(self.scheduler, "betas"):
alphas = 1.0 - self.scheduler.betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
else:
alphas_cumprod = []
self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0:
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
else:
timestep_ratio = t.expand(latents.size(0)).to(dtype)
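# `DDPMWuerstchenScheduler` already exposes its timesteps as conditioning ratios, so `t` is used
# directly; other schedulers expose integer timesteps, which are converted to a ratio above, either
# from `alphas_cumprod` (when the scheduler defines `betas`) or by normalizing against the
# scheduler's timesteps.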
# 7. Denoise image embeddings
predicted_image_embedding = self.prior(
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
clip_text_pooled=text_encoder_pooled,
clip_text=text_encoder_hidden_states,
clip_img=image_embeds,
return_dict=False,
)[0]
# 8. Check for classifier free guidance and apply it
if self.do_classifier_free_guidance:
predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
predicted_image_embedding = torch.lerp(
predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
)
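# torch.lerp(uncond, text, w) computes uncond + w * (text - uncond) = (1 - w) * uncond + w * text,
# i.e. the classifier-free guidance combination with guidance weight w described in the docstring.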
# 9. Renoise latents to next timestep
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
timestep_ratio = t
latents = self.scheduler.step(
model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator
).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# Offload all models
self.maybe_free_model_hooks()
if output_type == "np":
latents = latents.cpu().float().numpy()  # cast via float() since bfloat16 -> numpy is not supported
prompt_embeds = prompt_embeds.cpu().float().numpy()  # cast via float() since bfloat16 -> numpy is not supported
negative_prompt_embeds = (
negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None
)  # cast via float() since bfloat16 -> numpy is not supported
if not return_dict:
return (latents, prompt_embeds, negative_prompt_embeds)
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
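For orientation, below is a minimal sketch of how the image embeddings produced by this prior pipeline are typically handed to the decoder. It is modeled on the integration tests added later in this diff; the checkpoint names, dtype, prompt, and step counts are assumptions taken from those tests rather than final, released values.

```python
import torch

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

# Checkpoint names follow the integration tests in this PR and may differ from the released repos.
prior = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
decoder = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16)
prior.enable_model_cpu_offload()
decoder.enable_model_cpu_offload()

prompt = "an astronaut riding a horse, cinematic lighting"
generator = torch.Generator(device="cpu").manual_seed(0)

# The prior denoises in embedding space and returns image embeddings ...
prior_output = prior(prompt=prompt, num_inference_steps=20, generator=generator)

# ... which condition the decoder that renders the final image.
image = decoder(
    prompt=prompt,
    image_embeddings=prior_output.image_embeddings,
    num_inference_steps=10,
    generator=generator,
).images[0]
image.save("stable_cascade.png")
```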

View File

@@ -322,7 +322,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
@property
def do_classifier_free_guidance(self):
if isinstance(self.guidance_scale, (int, float)):
return self.guidance_scale
return self.guidance_scale > 1
return self.guidance_scale.max() > 1
@property
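For context on this change: Stable Video Diffusion interpolates the guidance scale across frames, so `self.guidance_scale` can be a tensor rather than a plain float, and the scalar branch must still return a boolean. A rough, self-contained sketch of the tensor case (the variable names are illustrative):

```python
import torch

num_frames, min_guidance_scale, max_guidance_scale = 14, 1.0, 3.0
# Per-frame guidance weights, roughly how the SVD pipeline constructs them.
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)

do_classifier_free_guidance = guidance_scale.max() > 1  # tensor case: compare the largest weight
print(bool(do_classifier_free_guidance))  # True
```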

View File

@@ -1,18 +1,3 @@
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

View File

@@ -233,7 +233,7 @@ class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
class ResBlockStageB(nn.Module):
def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
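A likely motivation for the signature change above is that the block sizes its layers arithmetically from `c` and `c_skip`, so `0` (no skip channels) is a safer default than `None`. A small, purely hypothetical illustration of the failure mode the integer default avoids:

```python
# Hypothetical illustration: widths derived from `c + c_skip` need an integer default.
c, c_skip = 128, 0
width_ok = c + c_skip  # 128 -> valid layer width

try:
    width_bad = c + None  # what a None default would imply
except TypeError as err:
    print(f"c_skip=None breaks the width computation: {err}")
```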

View File

@@ -349,6 +349,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
text_encoder_hidden_states = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if self.do_classifier_free_guidance
else image_embeddings
)
# 3. Determine latent shape of latents
latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale)
@@ -371,11 +376,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if self.do_classifier_free_guidance
else image_embeddings
)
# 7. Denoise latents
predicted_latents = self.decoder(
torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
@@ -423,9 +423,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)
if output_type == "np":
images = images.permute(0, 2, 3, 1).cpu().numpy()
images = images.permute(0, 2, 3, 1).cpu().float().numpy()
elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().numpy()
images = images.permute(0, 2, 3, 1).cpu().float().numpy()
images = self.numpy_to_pil(images)
else:
images = latents

View File

@@ -508,7 +508,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
self.maybe_free_model_hooks()
if output_type == "np":
latents = latents.cpu().numpy()
latents = latents.cpu().float().numpy()
if not return_dict:
return (latents,)
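The `.float()` calls added in this hunk and in the decoder hunk above work around a conversion limitation: NumPy has no bfloat16 dtype, so torch refuses to convert bfloat16 tensors directly. A quick sketch of the failure and the fix:

```python
import torch

x = torch.zeros(2, dtype=torch.bfloat16)
try:
    x.numpy()  # raises TypeError: NumPy has no bfloat16 dtype
except TypeError as err:
    print(f"direct conversion fails: {err}")

print(x.float().numpy())  # upcast to float32 first, then convert
```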

View File

@@ -19,7 +19,6 @@ from packaging import version
from .. import __version__
from .constants import (
_ACCEPTED_SINGLE_FILE_FORMATS,
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
@@ -84,7 +83,7 @@ from .import_utils import (
is_xformers_available,
requires_backends,
)
from .loading_utils import is_single_file_checkpoint, load_image
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (

View File

@@ -37,7 +37,6 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -752,6 +752,51 @@ class ShapEPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableCascadeCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableCascadeDecoderPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableCascadePriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionAdapterPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -1,28 +1,10 @@
import os
from typing import Callable, Union
from urllib.parse import urlparse
import PIL.Image
import PIL.ImageOps
import requests
from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
def is_single_file_checkpoint(filepath):
def is_valid_url(url):
result = urlparse(url)
if result.scheme and result.netloc:
return True
filepath = str(filepath)
if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
if is_valid_url(filepath):
return True
elif os.path.isfile(filepath):
return True
return False
def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None

View File

@@ -0,0 +1,246 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableCascadeCombinedPipeline
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"prior_guidance_scale",
"decoder_guidance_scale",
"negative_prompt",
"num_inference_steps",
"return_dict",
"prior_num_inference_steps",
"output_type",
]
test_xformers_attention = True
@property
def text_embedder_hidden_size(self):
return 32
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"conditioning_dim": 128,
"block_out_channels": (128, 128),
"num_attention_heads": (2, 2),
"down_num_layers_per_block": (1, 1),
"up_num_layers_per_block": (1, 1),
"clip_image_in_channels": 768,
"switch_level": (False,),
"clip_text_in_channels": self.text_embedder_hidden_size,
"clip_text_pooled_in_channels": self.text_embedder_hidden_size,
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
projection_dim=self.text_embedder_hidden_size,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config).eval()
@property
def dummy_vqgan(self):
torch.manual_seed(0)
model_kwargs = {
"bottleneck_blocks": 1,
"num_vq_embeddings": 2,
}
model = PaellaVQModel(**model_kwargs)
return model.eval()
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"in_channels": 4,
"out_channels": 4,
"conditioning_dim": 128,
"block_out_channels": (16, 32, 64, 128),
"num_attention_heads": (-1, -1, 1, 2),
"down_num_layers_per_block": (1, 1, 1, 1),
"up_num_layers_per_block": (1, 1, 1, 1),
"down_blocks_repeat_mappers": (1, 1, 1, 1),
"up_blocks_repeat_mappers": (3, 3, 2, 2),
"block_types_per_layer": (
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
),
"switch_level": None,
"clip_text_pooled_in_channels": 32,
"dropout": (0.1, 0.1, 0.1, 0.1),
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
def get_dummy_components(self):
prior = self.dummy_prior
scheduler = DDPMWuerstchenScheduler()
tokenizer = self.dummy_tokenizer
text_encoder = self.dummy_text_encoder
decoder = self.dummy_decoder
vqgan = self.dummy_vqgan
prior_text_encoder = self.dummy_text_encoder
prior_tokenizer = self.dummy_tokenizer
components = {
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"decoder": decoder,
"scheduler": scheduler,
"vqgan": vqgan,
"prior_text_encoder": prior_text_encoder,
"prior_tokenizer": prior_tokenizer,
"prior_prior": prior,
"prior_scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"prior_guidance_scale": 4.0,
"decoder_guidance_scale": 4.0,
"num_inference_steps": 2,
"prior_num_inference_steps": 2,
"output_type": "np",
"height": 128,
"width": 128,
}
return inputs
def test_stable_cascade(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
assert (
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_sequential_cpu_offload()
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_model_cpu_offload()
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=2e-2)
@unittest.skip(reason="fp16 not supported")
def test_float16_inference(self):
super().test_float16_inference()
@unittest.skip(reason="no callback test for combined pipeline")
def test_callback_inputs(self):
super().test_callback_inputs()
# def test_callback_cfg(self):
# pass
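For reference, a minimal sketch of how the combined pipeline exercised by these fast tests might be used directly. The argument names mirror `get_dummy_inputs` above; the checkpoint name and resolution are assumptions for illustration only.

```python
import torch

from diffusers import StableCascadeCombinedPipeline

# The checkpoint name is an assumption for illustration; replace it with a real Stable Cascade repo.
pipe = StableCascadeCombinedPipeline.from_pretrained("diffusers/StableCascade", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="an astronaut riding a horse",
    prior_guidance_scale=4.0,
    decoder_guidance_scale=4.0,
    prior_num_inference_steps=20,
    num_inference_steps=10,
    height=1024,
    width=1024,
).images[0]
image.save("stable_cascade_combined.png")
```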

View File

@@ -0,0 +1,249 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
load_pt,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableCascadeDecoderPipeline
params = ["prompt"]
batch_params = ["image_embeddings", "prompt", "negative_prompt"]
required_optional_params = [
"num_images_per_prompt",
"num_inference_steps",
"latents",
"negative_prompt",
"guidance_scale",
"output_type",
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def block_out_channels_0(self):
return self.time_input_dim
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
projection_dim=self.text_embedder_hidden_size,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config).eval()
@property
def dummy_vqgan(self):
torch.manual_seed(0)
model_kwargs = {
"bottleneck_blocks": 1,
"num_vq_embeddings": 2,
}
model = PaellaVQModel(**model_kwargs)
return model.eval()
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"in_channels": 4,
"out_channels": 4,
"conditioning_dim": 128,
"block_out_channels": [16, 32, 64, 128],
"num_attention_heads": [-1, -1, 1, 2],
"down_num_layers_per_block": [1, 1, 1, 1],
"up_num_layers_per_block": [1, 1, 1, 1],
"down_blocks_repeat_mappers": [1, 1, 1, 1],
"up_blocks_repeat_mappers": [3, 3, 2, 2],
"block_types_per_layer": [
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
"switch_level": None,
"clip_text_pooled_in_channels": 32,
"dropout": [0.1, 0.1, 0.1, 0.1],
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
def get_dummy_components(self):
decoder = self.dummy_decoder
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
vqgan = self.dummy_vqgan
scheduler = DDPMWuerstchenScheduler()
components = {
"decoder": decoder,
"vqgan": vqgan,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"scheduler": scheduler,
"latent_dim_scale": 4.0,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"image_embeddings": torch.ones((1, 4, 4, 4), device=device),
"prompt": "horse",
"generator": generator,
"guidance_scale": 2.0,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_wuerstchen_decoder(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@skip_mps
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
@skip_mps
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
test_mean_pixel_difference = False
self._test_attention_slicing_forward_pass(
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
@unittest.skip(reason="fp16 not supported")
def test_float16_inference(self):
super().test_float16_inference()
@slow
@require_torch_gpu
class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_cascade_decoder(self):
pipe = StableCascadeDecoderPipeline.from_pretrained(
"diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
generator = torch.Generator(device="cpu").manual_seed(0)
image_embedding = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
)
image = pipe(
prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator
).images[0]
assert image.size == (1024, 1024)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png"
)
image_processor = VaeImageProcessor()
image_np = image_processor.pil_to_numpy(image)
expected_image_np = image_processor.pil_to_numpy(expected_image)
self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2))

View File

@@ -0,0 +1,308 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models import StableCascadeUNet
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_pt,
require_peft_backend,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
def create_prior_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=unet.config.c,
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableCascadePriorPipeline
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"num_images_per_prompt",
"generator",
"num_inference_steps",
"latents",
"negative_prompt",
"guidance_scale",
"output_type",
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def block_out_channels_0(self):
return self.time_input_dim
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=self.text_embedder_hidden_size,
projection_dim=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config).eval()
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"conditioning_dim": 128,
"block_out_channels": (128, 128),
"num_attention_heads": (2, 2),
"down_num_layers_per_block": (1, 1),
"up_num_layers_per_block": (1, 1),
"switch_level": (False,),
"clip_image_in_channels": 768,
"clip_text_in_channels": self.text_embedder_hidden_size,
"clip_text_pooled_in_channels": self.text_embedder_hidden_size,
"dropout": (0.1, 0.1),
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
def get_dummy_components(self):
prior = self.dummy_prior
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
scheduler = DDPMWuerstchenScheduler()
components = {
"prior": prior,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"scheduler": scheduler,
"feature_extractor": None,
"image_encoder": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"guidance_scale": 4.0,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_wuerstchen_prior(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.image_embeddings
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
image_slice = image[0, 0, 0, -10:]
image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
assert image.shape == (1, 16, 24, 24)
expected_slice = np.array(
[
96.139565,
-20.213179,
-116.40341,
-191.57129,
39.350136,
74.80767,
39.782352,
-184.67352,
-46.426907,
168.41783,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2
@skip_mps
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-1)
@skip_mps
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
test_mean_pixel_difference = False
self._test_attention_slicing_forward_pass(
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
@unittest.skip(reason="fp16 not supported")
def test_float16_inference(self):
super().test_float16_inference()
def check_if_lora_correctly_set(self, model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
def get_lora_components(self):
prior = self.dummy_prior
prior_lora_config = LoraConfig(
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
lora_components = {
"prior_lora_layers": prior_lora_layers,
"prior_lora_attn_procs": prior_lora_attn_procs,
}
return prior, prior_lora_config, lora_components
@require_peft_backend
@unittest.skip(reason="no lora support for now")
def test_inference_with_prior_lora(self):
_, prior_lora_config, _ = self.get_lora_components()
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output_no_lora = pipe(**self.get_dummy_inputs(device))
image_embed = output_no_lora.image_embeddings
self.assertTrue(image_embed.shape == (1, 16, 24, 24))
pipe.prior.add_adapter(prior_lora_config)
self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
output_lora = pipe(**self.get_dummy_inputs(device))
lora_image_embed = output_lora.image_embeddings
self.assertTrue(image_embed.shape == lora_image_embed.shape)
@slow
@require_torch_gpu
class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_cascade_prior(self):
pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(prompt, num_inference_steps=10, generator=generator)
image_embedding = output.image_embeddings
expected_image_embedding = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
)
assert image_embedding.shape == (1, 16, 24, 24)
self.assertTrue(
np.allclose(
image_embedding.cpu().float().numpy(), expected_image_embedding.cpu().float().numpy(), atol=5e-2
)
)

View File

@@ -496,6 +496,22 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
def test_disable_cfg(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
inputs["max_guidance_scale"] = 1.0
output = pipe(**inputs).frames
self.assertEqual(len(output.shape), 5)
@slow
@require_torch_gpu

View File

@@ -99,14 +99,13 @@ class SDFunctionTesterMixin:
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# make sure here that pndm scheduler skips prk
if "safety_checker" in components:
components["safety_checker"] = None
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -126,7 +125,7 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(device)
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):

View File

@@ -45,7 +45,6 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
"return_dict",
"prior_num_inference_steps",
"output_type",
"return_dict",
]
test_xformers_attention = True