mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-07 04:54:47 +08:00

Compare commits: 15 commits (support-si ... fix-vae-ti)

Commits (SHA1):
- c980c53a8a
- a0fcca6dd5
- c2cc79c2c0
- ec64d34d5c
- 43fbc3aec5
- 196835695e
- 0d4dfbbd0a
- ada3bb941b
- b5814c5555
- 9940573618
- 59433ca1ae
- 534f5d54fa
- 40aa47b998
- 1bc0d37ffe
- eb942b866a
@@ -318,6 +318,8 @@
      title: Semantic Guidance
    - local: api/pipelines/shap_e
      title: Shap-E
    - local: api/pipelines/stable_cascade
      title: Stable Cascade
    - sections:
      - local: api/pipelines/stable_diffusion/overview
        title: Overview

docs/source/en/api/pipelines/stable_cascade.md (new file, 88 lines)

@@ -0,0 +1,88 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Cascade

This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture, and its main
difference from other models like Stable Diffusion is that it works in a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, so a 1024x1024 image is
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that a
1024x1024 image can be encoded to 24x24 while maintaining crisp reconstructions. The text-conditional model is then trained in this
highly compressed latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable
Diffusion 1.5.

Therefore, this kind of model is well suited for use cases where efficiency is important. Furthermore, all known extensions
like finetuning, LoRA, ControlNet, IP-Adapter, LCM, etc. are possible with this method as well.

The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).

## Model Overview
Stable Cascade consists of three models: Stage A, Stage B and Stage C, forming a cascade that generates images,
hence the name "Stable Cascade".

Stages A and B are used to compress images, similar to the role of the VAE in Stable Diffusion.
However, this setup achieves a much higher compression of images. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24 while still being able to accurately decode the
image, which brings the great benefit of cheaper training and inference. Stage C is then responsible
for generating the small 24 x 24 latents given a text prompt.
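
The three stages map onto two pipelines in this port: the prior pipeline runs Stage C, and the decoder pipeline runs Stages B and A. Below is a minimal sketch of how they chain together, assuming the `stabilityai/stable-cascade-prior` and `stabilityai/stable-cascade` checkpoints and the illustrative dtype and step settings shown:

```python
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prompt = "an image of a shiba inu, donning a spacesuit and helmet"

# Stage C: turn the prompt into the highly compressed 24x24 image embeddings
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
prior_output = prior(prompt=prompt, height=1024, width=1024, num_inference_steps=20)

# Stages B & A: decode the embeddings back into a full 1024x1024 image
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.float16
).to("cuda")
image = decoder(
    image_embeddings=prior_output.image_embeddings.to(torch.float16),
    prompt=prompt,
    num_inference_steps=10,
    output_type="pil",
).images[0]
image.save("stable_cascade.png")
```

If you prefer a single object, `StableCascadeCombinedPipeline` (documented below) wraps both stages behind one call.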

## Uses

### Direct Use

The model is intended for research purposes for now. Possible research areas and tasks include

- Research on generative models.
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.

Excluded uses are described below.

### Out-of-Scope Use

The model was not trained to produce factual or true representations of people or events,
and therefore using the model to generate such content is out of scope for its abilities.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).

## Limitations and Bias

### Limitations
- Faces and people in general may not be generated properly.
- The autoencoding part of the model is lossy.

## StableCascadeCombinedPipeline

[[autodoc]] StableCascadeCombinedPipeline
	- all
	- __call__

## StableCascadePriorPipeline

[[autodoc]] StableCascadePriorPipeline
	- all
	- __call__

## StableCascadePriorPipelineOutput

[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput

## StableCascadeDecoderPipeline

[[autodoc]] StableCascadeDecoderPipeline
	- all
	- __call__
@@ -666,7 +666,6 @@ def parse_args(input_args=None):
    )
    parser.add_argument(
        "--use_dora",
        type=bool,
        action="store_true",
        default=False,
        help=(
@@ -15,18 +15,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.models.attention import Attention
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineOutput,
|
||||
from packaging import version
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.configuration_utils import FrozenDict, deprecate
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import (
|
||||
FromSingleFileMixin,
|
||||
IPAdapterMixin,
|
||||
LoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models.attention import Attention
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
@@ -43,34 +72,486 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class Prompt2PromptPipeline(StableDiffusionPipeline):
|
||||
class Prompt2PromptPipeline(
|
||||
DiffusionPipeline,
|
||||
TextualInversionLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
IPAdapterMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
Prompt-to-Prompt-Pipeline for text-to-image generation using Stable Diffusion. This model inherits from
|
||||
[`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for
|
||||
all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler
|
||||
([`SchedulerMixin`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
ip_adapter_image=None,
|
||||
ip_adapter_image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
|
||||
@@ -243,3 +243,29 @@ accelerate launch train_dreambooth_lora_sdxl.py \

> [!CAUTION]
> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended not to pass any "variant".

### DoRA training
The script now supports DoRA training too!
> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://arxiv.org/abs/2402.09353),
**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.
The authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.

> [!NOTE]
> 💡 DoRA training is still _experimental_
> and is likely to require different hyperparameter values to perform best compared to a LoRA.
> Specifically, we've noticed 2 differences to take into account in your training:
> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA)
> 2. **DoRA quality is superior to LoRA, especially at lower ranks**: the difference in quality between a rank 8 DoRA and a rank 8 LoRA appears to be more significant than at ranks of 32 or 64, for example.
> This is also aligned with some of the quantitative analysis shown in the paper.

**Usage**
1. To use DoRA you need to install `peft` from main:
```bash
pip install git+https://github.com/huggingface/peft.git
```
2. Enable DoRA training by adding this flag:
```bash
--use_dora
```
**Inference**
Inference is the same as if you had trained a regular LoRA 🤗
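
Under the hood, `--use_dora` simply forwards `use_dora=True` to `peft`'s `LoraConfig` before the adapter is attached to the UNet (and, optionally, the text encoders). A minimal sketch of the equivalent configuration, using the rank and target modules from this script:

```python
from peft import LoraConfig

# Equivalent of passing `--use_dora --rank 8` to the training script:
# decompose each pre-trained weight into magnitude and direction,
# and learn the low-rank update on the direction component only.
unet_lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    use_dora=True,  # requires peft installed from main
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# unet.add_adapter(unet_lora_config)
```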

@@ -647,6 +647,15 @@ def parse_args(input_args=None):
        default=4,
        help=("The dimension of the LoRA update matrices."),
    )
    parser.add_argument(
        "--use_dora",
        action="store_true",
        default=False,
        help=(
            "Whether to train a DoRA as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
            "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
@@ -868,6 +877,8 @@ def collate_fn(examples, with_prior_preservation=False):
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]
        original_sizes += [example["original_size"] for example in examples]
        crop_top_lefts += [example["crop_top_left"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

@@ -1147,6 +1158,7 @@
    # now we will add new LoRA weights to the attention layers
    unet_lora_config = LoraConfig(
        r=args.rank,
        use_dora=args.use_dora,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],

@@ -1158,6 +1170,7 @@
    if args.train_text_encoder:
        text_lora_config = LoraConfig(
            r=args.rank,
            use_dora=args.use_dora,
            lora_alpha=args.rank,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -61,6 +61,34 @@ accelerate launch train_diffusion_dpo_sdxl.py \
  --push_to_hub
```

## SDXL Turbo training command

```bash
accelerate launch train_diffusion_dpo_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/sdxl-turbo \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-turbo-dpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --is_turbo --resolution 512 \
  --push_to_hub
```
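
Once training finishes, the LoRA can be loaded on top of the base SDXL Turbo checkpoint for inference. A minimal sketch, assuming the adapter was saved or pushed as `diffusion-sdxl-turbo-dpo` per the command above; note that Turbo is sampled with `guidance_scale=0.0` and very few steps, mirroring the validation settings in the script:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("diffusion-sdxl-turbo-dpo")

# Turbo checkpoints are distilled for guidance-free, few-step sampling.
image = pipeline(
    "a portrait photo of a corgi wearing sunglasses",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("turbo_dpo_lora.png")
```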

## Acknowledgements

This is based on the amazing work done by [Bram](https://github.com/bram-w) here for Diffusion DPO: https://github.com/bram-w/trl/blob/dpo/.
@@ -118,9 +118,16 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
    images = []
    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()

    guidance_scale = 5.0
    num_inference_steps = 25
    if args.is_turbo:
        guidance_scale = 0.0
        num_inference_steps = 4
    for prompt in VALIDATION_PROMPTS:
        with context:
            image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
            image = pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
            images.append(image)

    tracker_key = "test" if is_final_validation else "validation"

@@ -141,7 +148,10 @@ def log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_v
    if is_final_validation:
        pipeline.disable_lora()
        no_lora_images = [
            pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in VALIDATION_PROMPTS
            pipeline(
                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images[0]
            for prompt in VALIDATION_PROMPTS
        ]

        for tracker in accelerator.trackers:
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_turbo",
|
||||
action="store_true",
|
||||
help=("Use if tuning SDXL Turbo instead of SDXL"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -444,6 +459,9 @@ def parse_args(input_args=None):
|
||||
if args.dataset_name is None:
|
||||
raise ValueError("Must provide a `dataset_name`.")
|
||||
|
||||
if args.is_turbo:
|
||||
assert "turbo" in args.pretrained_model_name_or_path
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -560,6 +578,36 @@ def main(args):

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    def enforce_zero_terminal_snr(scheduler):
        # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93
        # Original implementation https://arxiv.org/pdf/2305.08891.pdf
        # Turbo needs zero terminal SNR
        # Turbo: https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - scheduler.betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])

        alphas_cumprod = torch.cumprod(alphas, dim=0)
        scheduler.alphas_cumprod = alphas_cumprod
        return

    if args.is_turbo:
        enforce_zero_terminal_snr(noise_scheduler)

    text_encoder_one = text_encoder_cls_one.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )

@@ -909,6 +957,10 @@ def main(args):
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long
        ).repeat(2)
        if args.is_turbo:
            # Learn a 4 timestep schedule: 249, 499, 749, 999
            timesteps_0_to_3 = timesteps % 4
            timesteps = 250 * timesteps_0_to_3 + 249

        # Add noise to the model input according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
scripts/convert_stable_cascade.py (new file, 215 lines)

@@ -0,0 +1,215 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
|
||||
import argparse
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
CLIPConfig,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
DDPMWuerstchenScheduler,
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
)
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
|
||||
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
|
||||
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
|
||||
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
|
||||
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
|
||||
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
model_path = args.model_path
|
||||
|
||||
device = "cpu"
|
||||
|
||||
# set paths to model weights
|
||||
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
|
||||
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
|
||||
|
||||
# Clip Text encoder and tokenizer
|
||||
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
config.text_config.projection_dim = config.projection_dim
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
|
||||
# image processor
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
# Prior
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(prior_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
prior_model = StableCascadeUNet(
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=1,
|
||||
conditioning_dim=2048,
|
||||
block_out_channels=[2048, 2048],
|
||||
num_attention_heads=[32, 32],
|
||||
down_num_layers_per_block=[8, 24],
|
||||
up_num_layers_per_block=[24, 8],
|
||||
down_blocks_repeat_mappers=[1, 1],
|
||||
up_blocks_repeat_mappers=[1, 1],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_in_channels=1280,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels=768,
|
||||
clip_seq=4,
|
||||
kernel_size=3,
|
||||
dropout=[0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca", "crp"],
|
||||
switch_level=[False],
|
||||
)
|
||||
load_model_dict_into_meta(prior_model, state_dict)
|
||||
|
||||
# scheduler for prior and decoder
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
# Prior pipeline
|
||||
prior_pipeline = StableCascadePriorPipeline(
|
||||
prior=prior_model,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Decoder
|
||||
if args.use_safetensors:
|
||||
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
|
||||
else:
|
||||
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
|
||||
|
||||
state_dict = {}
|
||||
for key in orig_state_dict.keys():
|
||||
if key.endswith("in_proj_weight"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
|
||||
elif key.endswith("in_proj_bias"):
|
||||
weights = orig_state_dict[key].chunk(3, 0)
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
|
||||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
|
||||
elif key.endswith("out_proj.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
|
||||
elif key.endswith("out_proj.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
|
||||
# rename clip_mapper to clip_txt_pooled_mapper
|
||||
elif key.endswith("clip_mapper.weight"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
|
||||
elif key.endswith("clip_mapper.bias"):
|
||||
weights = orig_state_dict[key]
|
||||
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
|
||||
else:
|
||||
state_dict[key] = orig_state_dict[key]
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
decoder = StableCascadeUNet(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
timestep_ratio_embedding_dim=64,
|
||||
patch_size=2,
|
||||
conditioning_dim=1280,
|
||||
block_out_channels=[320, 640, 1280, 1280],
|
||||
down_num_layers_per_block=[2, 6, 28, 6],
|
||||
up_num_layers_per_block=[6, 28, 6, 2],
|
||||
down_blocks_repeat_mappers=[1, 1, 1, 1],
|
||||
up_blocks_repeat_mappers=[3, 3, 2, 2],
|
||||
num_attention_heads=[0, 0, 20, 20],
|
||||
block_types_per_layer=[
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_seq=4,
|
||||
effnet_in_channels=16,
|
||||
pixel_mapper_in_channels=3,
|
||||
kernel_size=3,
|
||||
dropout=[0, 0, 0.1, 0.1],
|
||||
self_attn=True,
|
||||
timestep_conditioning_type=["sca"],
|
||||
)
|
||||
load_model_dict_into_meta(decoder, state_dict)
|
||||
|
||||
# VQGAN from Wuerstchen-V2
|
||||
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
|
||||
|
||||
# Decoder pipeline
|
||||
decoder_pipeline = StableCascadeDecoderPipeline(
|
||||
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
|
||||
)
|
||||
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
|
||||
|
||||
# Stable Cascade combined pipeline
|
||||
stable_cascade_pipeline = StableCascadeCombinedPipeline(
|
||||
# Decoder
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqmodel,
|
||||
# Prior
|
||||
prior_text_encoder=text_encoder,
|
||||
prior_tokenizer=tokenizer,
|
||||
prior_prior=prior_model,
|
||||
prior_scheduler=scheduler,
|
||||
prior_image_encoder=image_encoder,
|
||||
prior_feature_extractor=feature_extractor,
|
||||
)
|
||||
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
|
||||
@@ -86,6 +86,7 @@ else:
        "MotionAdapter",
        "MultiAdapter",
        "PriorTransformer",
        "StableCascadeUNet",
        "T2IAdapter",
        "T5FilmDecoder",
        "Transformer2DModel",

@@ -259,6 +260,9 @@ else:
        "SemanticStableDiffusionPipeline",
        "ShapEImg2ImgPipeline",
        "ShapEPipeline",
        "StableCascadeCombinedPipeline",
        "StableCascadeDecoderPipeline",
        "StableCascadePriorPipeline",
        "StableDiffusionAdapterPipeline",
        "StableDiffusionAttendAndExcitePipeline",
        "StableDiffusionControlNetImg2ImgPipeline",

@@ -626,6 +630,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        SemanticStableDiffusionPipeline,
        ShapEImg2ImgPipeline,
        ShapEPipeline,
        StableCascadeCombinedPipeline,
        StableCascadeDecoderPipeline,
        StableCascadePriorPipeline,
        StableDiffusionAdapterPipeline,
        StableDiffusionAttendAndExcitePipeline,
        StableDiffusionControlNetImg2ImgPipeline,
@@ -143,5 +143,4 @@ class FromOriginalVAEMixin:
        if torch_dtype is not None:
            vae = vae.to(torch_dtype)

        vae.eval()
        return vae

@@ -133,5 +133,4 @@ class FromOriginalControlNetMixin:
        if torch_dtype is not None:
            controlnet = controlnet.to(torch_dtype)

        controlnet.eval()
        return controlnet
@@ -63,13 +63,20 @@ def build_sub_model_components(
|
||||
num_in_channels=num_in_channels,
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size,
|
||||
scaling_factor,
|
||||
torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return vae_components
|
||||
|
||||
@@ -124,11 +131,12 @@ def build_sub_model_components(
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint=None,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
|
||||
@@ -28,6 +28,7 @@ from ..schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EDMDPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
@@ -175,6 +176,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
@@ -305,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
    return original_config


def infer_model_type(original_config, model_type=None):
def infer_model_type(original_config, checkpoint=None, model_type=None):
    if model_type is not None:
        return model_type

@@ -323,7 +325,9 @@ def infer_model_type(original_config, model_type=None):

    elif has_network_config:
        context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
        if context_dim == 2048:
        if "edm_mean" in checkpoint and "edm_std" in checkpoint:
            model_type = "Playground"
        elif context_dim == 2048:
            model_type = "SDXL"
        else:
            model_type = "SDXL-Refiner"
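
The net effect of these single-file changes is that Playground-style checkpoints, which ship `edm_mean`/`edm_std` statistics and use the EDM schedulers, can be detected and loaded directly from one `.safetensors` file. A hedged sketch of what that enables, with an illustrative local checkpoint path:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# The filename is illustrative; any single-file Playground v2.5 checkpoint that
# contains the `edm_mean`/`edm_std` keys should be detected as model_type="Playground".
pipeline = StableDiffusionXLPipeline.from_single_file(
    "playground-v2.5-1024px-aesthetic.fp16.safetensors",
    torch_dtype=torch.float16,
).to("cuda")
image = pipeline("a cinematic photo of a lighthouse at dusk").images[0]
```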
@@ -344,13 +348,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
|
||||
return image_size
|
||||
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
model_type = infer_model_type(original_config, model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint, model_type)
|
||||
|
||||
if pipeline_class_name == "StableDiffusionUpscalePipeline":
|
||||
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
|
||||
return image_size
|
||||
|
||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
|
||||
image_size = 1024
|
||||
return image_size
|
||||
|
||||
@@ -506,12 +510,14 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
    return controlnet_config


def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
    if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
        scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
    elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
        scaling_factor = original_config["model"]["params"]["scale_factor"]
    elif scaling_factor is None:
        scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR

@@ -531,6 +537,8 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": scaling_factor,
    }
    if latents_mean is not None and latents_std is not None:
        config.update({"latents_mean": latents_mean, "latents_std": latents_std})

    return config

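As a sanity check on the branching above, here is a small standalone sketch of the same scaling-factor selection; the helper name `pick_vae_scaling_factor` is illustrative and not part of the library:

```python
# Illustrative sketch of the scaling-factor selection in the hunk above.
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215


def pick_vae_scaling_factor(original_config, scaling_factor=None, latents_mean=None, latents_std=None):
    # Playground checkpoints carry EDM latent statistics and use a fixed 0.5 factor.
    if scaling_factor is None and latents_mean is not None and latents_std is not None:
        return PLAYGROUND_VAE_SCALING_FACTOR
    # Otherwise prefer the value stored in the original LDM config.
    if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
        return original_config["model"]["params"]["scale_factor"]
    # Fall back to the historical SD 1.x default.
    return scaling_factor if scaling_factor is not None else LDM_VAE_DEFAULT_SCALING_FACTOR


# Example: a Playground-style call where latent statistics are provided.
print(pick_vae_scaling_factor({"model": {"params": {}}}, latents_mean=[0.0] * 4, latents_std=[1.0] * 4))  # 0.5
```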
@@ -1172,6 +1180,7 @@ def create_diffusers_unet_model_from_ldm(
    extract_ema=False,
    image_size=None,
    torch_dtype=None,
    model_type=None,
):
    from ..models import UNet2DConditionModel

@@ -1190,7 +1199,9 @@ def create_diffusers_unet_model_from_ldm(
    else:
        num_in_channels = 4

    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
    image_size = set_image_size(
        pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
    )
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet_config["in_channels"] = num_in_channels
    unet_config["upcast_attention"] = upcast_attention
@@ -1223,14 +1234,40 @@ def create_diffusers_unet_model_from_ldm(


def create_diffusers_vae_model_from_ldm(
    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
    pipeline_class_name,
    original_config,
    checkpoint,
    image_size=None,
    scaling_factor=None,
    torch_dtype=None,
    model_type=None,
):
    # import here to avoid circular imports
    from ..models import AutoencoderKL

    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
    image_size = set_image_size(
        pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
    )
    model_type = infer_model_type(original_config, checkpoint, model_type)

    vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
    if model_type == "Playground":
        edm_mean = (
            checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
        )
        edm_std = (
            checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
        )
    else:
        edm_mean = None
        edm_std = None

    vae_config = create_vae_diffusers_config(
        original_config,
        image_size=image_size,
        scaling_factor=scaling_factor,
        latents_mean=edm_mean,
        latents_std=edm_std,
    )
    diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext

@@ -1265,7 +1302,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
    local_files_only=False,
    torch_dtype=None,
):
    model_type = infer_model_type(original_config, model_type=model_type)
    model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)

    if model_type == "FrozenOpenCLIPEmbedder":
        config_name = "stabilityai/stable-diffusion-2"
@@ -1332,7 +1369,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
            "text_encoder_2": text_encoder_2,
        }

    elif model_type == "SDXL":
    elif model_type in ["SDXL", "Playground"]:
        try:
            config_name = "openai/clip-vit-large-patch14"
            tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
@@ -1383,7 +1420,7 @@ def create_scheduler_from_ldm(
    model_type=None,
):
    scheduler_config = get_default_scheduler_config()
    model_type = infer_model_type(original_config, model_type=model_type)
    model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)

    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

@@ -1406,7 +1443,8 @@ def create_scheduler_from_ldm(

    if model_type in ["SDXL", "SDXL-Refiner"]:
        scheduler_type = "euler"

    elif model_type == "Playground":
        scheduler_type = "edm_dpm_solver_multistep"
    else:
        beta_start = original_config["model"]["params"].get("linear_start", 0.02)
        beta_end = original_config["model"]["params"].get("linear_end", 0.085)
@@ -1438,6 +1476,26 @@ def create_scheduler_from_ldm(
    elif scheduler_type == "ddim":
        scheduler = DDIMScheduler.from_config(scheduler_config)

    elif scheduler_type == "edm_dpm_solver_multistep":
        scheduler_config = {
            "algorithm_type": "dpmsolver++",
            "dynamic_thresholding_ratio": 0.995,
            "euler_at_final": False,
            "final_sigmas_type": "zero",
            "lower_order_final": True,
            "num_train_timesteps": 1000,
            "prediction_type": "epsilon",
            "rho": 7.0,
            "sample_max_value": 1.0,
            "sigma_data": 0.5,
            "sigma_max": 80.0,
            "sigma_min": 0.002,
            "solver_order": 2,
            "solver_type": "midpoint",
            "thresholding": False,
        }
        scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)

    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

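For reference, the EDM scheduler built above can be constructed directly from the same keyword arguments; this sketch uses a subset of the config shown in the hunk (all values are taken from it):

```python
from diffusers import EDMDPMSolverMultistepScheduler

# A subset of the EDM config from the hunk above; the remaining keys are accepted the same way.
edm_config = {
    "algorithm_type": "dpmsolver++",
    "solver_order": 2,
    "sigma_min": 0.002,
    "sigma_max": 80.0,
    "sigma_data": 0.5,
    "num_train_timesteps": 1000,
    "prediction_type": "epsilon",
    "final_sigmas_type": "zero",
}
scheduler = EDMDPMSolverMultistepScheduler(**edm_config)
print(scheduler.config.sigma_data)  # 0.5
```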
@@ -47,6 +47,7 @@ if is_torch_available():
    _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
    _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
    _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
    _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
    _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
    _import_structure["vq_model"] = ["VQModel"]

@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        I2VGenXLUNet,
        Kandinsky3UNet,
        MotionAdapter,
        StableCascadeUNet,
        UNet1DModel,
        UNet2DConditionModel,
        UNet2DModel,
@@ -440,7 +440,6 @@ class TemporalBasicTransformerBlock(nn.Module):

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(
            dim,
            dim_out=time_mix_inner_dim,
@@ -39,7 +39,6 @@ from ..utils import (
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_single_file_checkpoint,
    is_torch_version,
    logging,
)
@@ -49,8 +48,6 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
logger = logging.get_logger(__name__)


SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}

if is_torch_version(">=", "1.9.0"):
    _LOW_CPU_MEM_USAGE_DEFAULT = True
else:
@@ -500,90 +497,102 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        if is_single_file_checkpoint(pretrained_model_name_or_path):
            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
                raise ValueError(
                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
                )
            logger.info("Single file checkpoint detected...")
            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
            model = model.eval()
            return model
        else:
cache_dir = kwargs.pop("cache_dir", None)
|
||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
offload_folder = kwargs.pop("offload_folder", None)
|
||||
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
||||
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
||||
)
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
||||
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
||||
)
|
||||
|
||||
# Check if we can handle device_map and dispatching the weights
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
# Check if we can handle device_map and dispatching the weights
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
# load config
|
||||
config, unused_kwargs, commit_hash = cls.load_config(
|
||||
config_path,
|
||||
# load config
|
||||
config, unused_kwargs, commit_hash = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
user_agent=user_agent,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# load model
|
||||
model_file = None
|
||||
if from_flax:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=FLAX_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
@@ -591,62 +600,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
user_agent=user_agent,
|
||||
**kwargs,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
# load model
|
||||
model_file = None
|
||||
if from_flax:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=FLAX_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
# Convert the weights
|
||||
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
||||
|
||||
# Convert the weights
|
||||
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||
else:
|
||||
if use_safetensors:
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
except IOError as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
pass
|
||||
if model_file is None:
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||
else:
|
||||
if use_safetensors:
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
@@ -658,48 +626,95 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
except IOError as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
pass
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
# move the params from meta device to cpu
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
||||
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
||||
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
device=param_device,
|
||||
dtype=torch_dtype,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
# move the params from meta device to cpu
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
if len(missing_keys) > 0:
|
||||
raise ValueError(
|
||||
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
||||
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
||||
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
||||
" those weights or else make sure your checkpoint file is correct."
|
||||
)
|
||||
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
state_dict,
|
||||
device=param_device,
|
||||
dtype=torch_dtype,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
)
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
if cls._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in cls._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by default the device_map is None and the weights are loaded on the CPU
|
||||
try:
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
except AttributeError as e:
|
||||
# When using accelerate loading, we do not have the ability to load the state
|
||||
# dict and rename the weight names manually. Additionally, accelerate skips
|
||||
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
||||
# (which look like they should be private variables?), so we can't use the standard hooks
|
||||
# to rename parameters on load. We need to mimic the original weight names so the correct
|
||||
# attributes are available. After we have loaded the weights, we convert the deprecated
|
||||
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
||||
# the weights so we don't have to do this again.
|
||||
|
||||
if "'Attention' object has no attribute" in str(e):
|
||||
logger.warn(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
||||
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
||||
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
||||
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
||||
" please also re-upload it or open a PR on the original repository."
|
||||
)
|
||||
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by default the device_map is None and the weights are loaded on the CPU
|
||||
try:
|
||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file,
|
||||
@@ -709,80 +724,52 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
except AttributeError as e:
|
||||
# When using accelerate loading, we do not have the ability to load the state
|
||||
# dict and rename the weight names manually. Additionally, accelerate skips
|
||||
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
||||
# (which look like they should be private variables?), so we can't use the standard hooks
|
||||
# to rename parameters on load. We need to mimic the original weight names so the correct
|
||||
# attributes are available. After we have loaded the weights, we convert the deprecated
|
||||
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
||||
# the weights so we don't have to do this again.
|
||||
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
||||
else:
|
||||
raise e
|
||||
|
||||
if "'Attention' object has no attribute" in str(e):
|
||||
logger.warn(
|
||||
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
||||
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
||||
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
||||
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
||||
" please also re-upload it or open a PR on the original repository."
|
||||
)
|
||||
model._temp_convert_self_to_deprecated_attention_blocks()
|
||||
accelerate.load_checkpoint_and_dispatch(
|
||||
model,
|
||||
model_file,
|
||||
device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
||||
else:
|
||||
raise e
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
"unexpected_keys": [],
|
||||
"mismatched_keys": [],
|
||||
"error_msgs": [],
|
||||
}
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
"unexpected_keys": [],
|
||||
"mismatched_keys": [],
|
||||
"error_msgs": [],
|
||||
}
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
elif torch_dtype is not None:
|
||||
model = model.to(torch_dtype)
|
||||
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
||||
)
|
||||
elif torch_dtype is not None:
|
||||
model = model.to(torch_dtype)
|
||||
|
||||
return model
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
|
||||
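The @@ -500 hunk above concerns routing single-file checkpoint paths in `ModelMixin.from_pretrained` to `from_single_file` for the classes in `SINGLE_FILE_LOADABLE_CLASSES` (`ControlNetModel`, `AutoencoderKL`). A minimal usage sketch, assuming a diffusers build where `AutoencoderKL.from_single_file` is available and using a hypothetical local checkpoint path:

```python
import torch
from diffusers import AutoencoderKL

# Hypothetical path to an original-format (single-file) VAE checkpoint.
ckpt_path = "./checkpoints/vae-ft-mse-840000-ema-pruned.safetensors"

# Loads the single-file checkpoint and converts it to the diffusers AutoencoderKL layout.
vae = AutoencoderKL.from_single_file(ckpt_path, torch_dtype=torch.float16)
vae = vae.eval()
```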
@@ -10,6 +10,7 @@ if is_torch_available():
    from .unet_kandinsky3 import Kandinsky3UNet
    from .unet_motion_model import MotionAdapter, UNetMotionModel
    from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
    from .unet_stable_cascade import StableCascadeUNet
    from .uvit_2d import UVit2DModel

609
src/diffusers/models/unets/unet_stable_cascade.py
Normal file
@@ -0,0 +1,609 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention_processor import Attention
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm
|
||||
class SDCascadeLayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = super().forward(x)
|
||||
return x.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
class SDCascadeTimestepBlock(nn.Module):
|
||||
def __init__(self, c, c_timestep, conds=[]):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
self.mapper = linear_cls(c_timestep, c * 2)
|
||||
self.conds = conds
|
||||
for cname in conds:
|
||||
setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
|
||||
|
||||
def forward(self, x, t):
|
||||
t = t.chunk(len(self.conds) + 1, dim=1)
|
||||
a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
|
||||
for i, c in enumerate(self.conds):
|
||||
ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
|
||||
a, b = a + ac, b + bc
|
||||
return x * (1 + a) + b
|
||||
|
||||
|
||||
class SDCascadeResBlock(nn.Module):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
||||
super().__init__()
|
||||
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
||||
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.channelwise = nn.Sequential(
|
||||
nn.Linear(c + c_skip, c * 4),
|
||||
nn.GELU(),
|
||||
GlobalResponseNorm(c * 4),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(c * 4, c),
|
||||
)
|
||||
|
||||
def forward(self, x, x_skip=None):
|
||||
x_res = x
|
||||
x = self.norm(self.depthwise(x))
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
return x + x_res
|
||||
|
||||
|
||||
# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * stand_div_norm) + self.beta + x
|
||||
|
||||
|
||||
class SDCascadeAttnBlock(nn.Module):
|
||||
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.self_attn = self_attn
|
||||
self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
|
||||
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
|
||||
|
||||
def forward(self, x, kv):
|
||||
kv = self.kv_mapper(kv)
|
||||
norm_x = self.norm(x)
|
||||
if self.self_attn:
|
||||
batch_size, channel, _, _ = x.shape
|
||||
kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
|
||||
x = x + self.attention(norm_x, encoder_hidden_states=kv)
|
||||
return x
|
||||
|
||||
|
||||
class UpDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, mode, enabled=True):
|
||||
super().__init__()
|
||||
if mode not in ["up", "down"]:
|
||||
raise ValueError(f"{mode} not supported")
|
||||
interpolation = (
|
||||
nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True)
|
||||
if enabled
|
||||
else nn.Identity()
|
||||
)
|
||||
mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableCascadeUNetOutput(BaseOutput):
|
||||
sample: torch.FloatTensor = None
|
||||
|
||||
|
||||
class StableCascadeUNet(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
timestep_ratio_embedding_dim: int = 64,
|
||||
patch_size: int = 1,
|
||||
conditioning_dim: int = 2048,
|
||||
block_out_channels: Tuple[int] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||
1,
|
||||
1,
|
||||
),
|
||||
up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1),
|
||||
block_types_per_layer: Tuple[Tuple[str]] = (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
),
|
||||
clip_text_in_channels: Optional[int] = None,
|
||||
clip_text_pooled_in_channels=1280,
|
||||
clip_image_in_channels: Optional[int] = None,
|
||||
clip_seq=4,
|
||||
effnet_in_channels: Optional[int] = None,
|
||||
pixel_mapper_in_channels: Optional[int] = None,
|
||||
kernel_size=3,
|
||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||
self_attn: Union[bool, Tuple[bool]] = True,
|
||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
||||
switch_level: Optional[Tuple[bool]] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`, defaults to 16):
|
||||
Number of channels in the input sample.
|
||||
out_channels (`int`, defaults to 16):
|
||||
Number of channels in the output sample.
|
||||
timestep_ratio_embedding_dim (`int`, defaults to 64):
|
||||
Dimension of the projected time embedding.
|
||||
patch_size (`int`, defaults to 1):
|
||||
Patch size to use for pixel unshuffling layer
|
||||
conditioning_dim (`int`, defaults to 2048):
|
||||
Dimension of the image and text conditional embedding.
|
||||
block_out_channels (Tuple[int], defaults to (2048, 2048)):
|
||||
Tuple of output channels for each block.
|
||||
num_attention_heads (Tuple[int], defaults to (32, 32)):
|
||||
Number of attention heads in each attention block. Set to -1 if the block types in a layer do not include attention.
|
||||
down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
|
||||
Number of layers in each down block.
|
||||
up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
|
||||
Number of layers in each up block.
|
||||
down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
|
||||
Number of 1x1 Convolutional layers to repeat in each down block.
|
||||
up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]):
|
||||
Number of 1x1 Convolutional layers to repeat in each up block.
|
||||
block_types_per_layer (Tuple[Tuple[str]], optional,
|
||||
defaults to (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
|
||||
):
|
||||
Block types used in each layer of the up/down blocks.
|
||||
clip_text_in_channels (`int`, *optional*, defaults to `None`):
|
||||
Number of input channels for CLIP based text conditioning.
|
||||
clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
|
||||
Number of input channels for pooled CLIP text embeddings.
|
||||
clip_image_in_channels (`int`, *optional*):
|
||||
Number of input channels for CLIP based image conditioning.
|
||||
clip_seq (`int`, *optional*, defaults to 4):
    Number of conditioning tokens each pooled CLIP embedding is expanded to by `clip_txt_pooled_mapper`.
|
||||
effnet_in_channels (`int`, *optional*, defaults to `None`):
|
||||
Number of input channels for effnet conditioning.
|
||||
pixel_mapper_in_channels (`int`, defaults to `None`):
|
||||
Number of input channels for pixel mapper conditioning.
|
||||
kernel_size (`int`, *optional*, defaults to 3):
|
||||
Kernel size to use in the block convolutional layers.
|
||||
dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)):
|
||||
Dropout to use per block.
|
||||
self_attn (Union[bool, Tuple[bool]]):
|
||||
Tuple of booleans that determine whether to use self attention in a block or not.
|
||||
timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")):
|
||||
Timestep conditioning type.
|
||||
switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`):
|
||||
Tuple that indicates whether upsampling or downsampling should be applied in a block
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
if len(block_out_channels) != len(down_num_layers_per_block):
|
||||
raise ValueError(
|
||||
f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(up_num_layers_per_block):
|
||||
raise ValueError(
|
||||
f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(down_blocks_repeat_mappers):
|
||||
raise ValueError(
|
||||
f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(up_blocks_repeat_mappers):
|
||||
raise ValueError(
|
||||
f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
elif len(block_out_channels) != len(block_types_per_layer):
|
||||
raise ValueError(
|
||||
f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}"
|
||||
)
|
||||
|
||||
if isinstance(dropout, float):
|
||||
dropout = (dropout,) * len(block_out_channels)
|
||||
if isinstance(self_attn, bool):
|
||||
self_attn = (self_attn,) * len(block_out_channels)
|
||||
|
||||
# CONDITIONING
|
||||
if effnet_in_channels is not None:
|
||||
self.effnet_mapper = nn.Sequential(
|
||||
nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
if pixel_mapper_in_channels is not None:
|
||||
self.pixels_mapper = nn.Sequential(
|
||||
nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
|
||||
self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq)
|
||||
if clip_text_in_channels is not None:
|
||||
self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim)
|
||||
if clip_image_in_channels is not None:
|
||||
self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq)
|
||||
self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.embedding = nn.Sequential(
|
||||
nn.PixelUnshuffle(patch_size),
|
||||
nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1),
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
)
|
||||
|
||||
def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True):
|
||||
if block_type == "SDCascadeResBlock":
|
||||
return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout)
|
||||
elif block_type == "SDCascadeAttnBlock":
|
||||
return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout)
|
||||
elif block_type == "SDCascadeTimestepBlock":
|
||||
return SDCascadeTimestepBlock(
|
||||
in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Block type {block_type} not supported")
|
||||
|
||||
# BLOCKS
|
||||
# -- down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
self.down_downscalers = nn.ModuleList()
|
||||
self.down_repeat_mappers = nn.ModuleList()
|
||||
for i in range(len(block_out_channels)):
|
||||
if i > 0:
|
||||
self.down_downscalers.append(
|
||||
nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(
|
||||
block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1]
|
||||
)
|
||||
if switch_level is not None
|
||||
else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_downscalers.append(nn.Identity())
|
||||
|
||||
down_block = nn.ModuleList()
|
||||
for _ in range(down_num_layers_per_block[i]):
|
||||
for block_type in block_types_per_layer[i]:
|
||||
block = get_block(
|
||||
block_type,
|
||||
block_out_channels[i],
|
||||
num_attention_heads[i],
|
||||
dropout=dropout[i],
|
||||
self_attn=self_attn[i],
|
||||
)
|
||||
down_block.append(block)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if down_blocks_repeat_mappers is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(down_blocks_repeat_mappers[i] - 1):
|
||||
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
|
||||
self.down_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# -- up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
self.up_upscalers = nn.ModuleList()
|
||||
self.up_repeat_mappers = nn.ModuleList()
|
||||
for i in reversed(range(len(block_out_channels))):
|
||||
if i > 0:
|
||||
self.up_upscalers.append(
|
||||
nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6),
|
||||
UpDownBlock2d(
|
||||
block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1]
|
||||
)
|
||||
if switch_level is not None
|
||||
else nn.ConvTranspose2d(
|
||||
block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_upscalers.append(nn.Identity())
|
||||
|
||||
up_block = nn.ModuleList()
|
||||
for j in range(up_num_layers_per_block[::-1][i]):
|
||||
for k, block_type in enumerate(block_types_per_layer[i]):
|
||||
c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0
|
||||
block = get_block(
|
||||
block_type,
|
||||
block_out_channels[i],
|
||||
num_attention_heads[i],
|
||||
c_skip=c_skip,
|
||||
dropout=dropout[i],
|
||||
self_attn=self_attn[i],
|
||||
)
|
||||
up_block.append(block)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
if up_blocks_repeat_mappers is not None:
|
||||
block_repeat_mappers = nn.ModuleList()
|
||||
for _ in range(up_blocks_repeat_mappers[::-1][i] - 1):
|
||||
block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1))
|
||||
self.up_repeat_mappers.append(block_repeat_mappers)
|
||||
|
||||
# OUTPUT
|
||||
self.clf = nn.Sequential(
|
||||
SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6),
|
||||
nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1),
|
||||
nn.PixelShuffle(patch_size),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)
|
||||
nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None
|
||||
nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None
|
||||
|
||||
if hasattr(self, "effnet_mapper"):
|
||||
nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
|
||||
nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
|
||||
|
||||
if hasattr(self, "pixels_mapper"):
|
||||
nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
|
||||
nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
|
||||
|
||||
torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
|
||||
nn.init.constant_(self.clf[1].weight, 0) # outputs
|
||||
|
||||
# blocks
|
||||
for level_block in self.down_blocks + self.up_blocks:
|
||||
for block in level_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0]))
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
nn.init.constant_(block.mapper.weight, 0)
|
||||
|
||||
def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000):
|
||||
r = timestep_ratio * max_positions
|
||||
half_dim = self.config.timestep_ratio_embedding_dim // 2
|
||||
|
||||
emb = math.log(max_positions) / (half_dim - 1)
|
||||
emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
|
||||
emb = r[:, None] * emb[None, :]
|
||||
emb = torch.cat([emb.sin(), emb.cos()], dim=1)
|
||||
|
||||
if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad
|
||||
emb = nn.functional.pad(emb, (0, 1), mode="constant")
|
||||
|
||||
return emb.to(dtype=r.dtype)
|
||||
|
||||
def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None):
|
||||
if len(clip_txt_pooled.shape) == 2:
|
||||
clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
|
||||
clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
|
||||
clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1
|
||||
)
|
||||
if clip_txt is not None and clip_img is not None:
|
||||
clip_txt = self.clip_txt_mapper(clip_txt)
|
||||
if len(clip_img.shape) == 2:
|
||||
clip_img = clip_img.unsqueeze(1)
|
||||
clip_img = self.clip_img_mapper(clip_img).view(
|
||||
clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1
|
||||
)
|
||||
clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
|
||||
else:
|
||||
clip = clip_txt_pool
|
||||
return self.clip_norm(clip)
|
||||
|
||||
def _down_encode(self, x, r_embed, clip):
|
||||
level_outputs = []
|
||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, use_reentrant=False
|
||||
)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
else:
|
||||
for down_block, downscaler, repmap in block_group:
|
||||
x = downscaler(x)
|
||||
for i in range(len(repmap) + 1):
|
||||
for block in down_block:
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
x = block(x)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if i < len(repmap):
|
||||
x = repmap[i](x)
|
||||
level_outputs.insert(0, x)
|
||||
return level_outputs
|
||||
|
||||
def _up_decode(self, level_outputs, r_embed, clip):
|
||||
x = level_outputs[0]
|
||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, skip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, clip, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
else:
|
||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||
for j in range(len(repmap) + 1):
|
||||
for k, block in enumerate(up_block):
|
||||
if isinstance(block, SDCascadeResBlock):
|
||||
skip = level_outputs[i] if k == 0 and i > 0 else None
|
||||
if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
x = block(x, skip)
|
||||
elif isinstance(block, SDCascadeAttnBlock):
|
||||
x = block(x, clip)
|
||||
elif isinstance(block, SDCascadeTimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
if j < len(repmap):
|
||||
x = repmap[j](x)
|
||||
x = upscaler(x)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample,
|
||||
timestep_ratio,
|
||||
clip_text_pooled,
|
||||
clip_text=None,
|
||||
clip_img=None,
|
||||
effnet=None,
|
||||
pixels=None,
|
||||
sca=None,
|
||||
crp=None,
|
||||
return_dict=True,
|
||||
):
|
||||
if pixels is None:
|
||||
pixels = sample.new_zeros(sample.size(0), 3, 8, 8)
|
||||
|
||||
# Process the conditioning embeddings
|
||||
timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio)
|
||||
for c in self.config.timestep_conditioning_type:
|
||||
if c == "sca":
|
||||
cond = sca
|
||||
elif c == "crp":
|
||||
cond = crp
|
||||
else:
|
||||
cond = None
|
||||
t_cond = cond or torch.zeros_like(timestep_ratio)
|
||||
timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1)
|
||||
clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img)
|
||||
|
||||
# Model Blocks
|
||||
x = self.embedding(sample)
|
||||
if hasattr(self, "effnet_mapper") and effnet is not None:
|
||||
x = x + self.effnet_mapper(
|
||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
|
||||
)
|
||||
if hasattr(self, "pixels_mapper"):
|
||||
x = x + nn.functional.interpolate(
|
||||
self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
|
||||
)
|
||||
level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
|
||||
x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
|
||||
sample = self.clf(x)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
return StableCascadeUNetOutput(sample=sample)
|
||||
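To see the new `StableCascadeUNet` wiring end to end, here is a small smoke-test sketch with a deliberately tiny configuration (the released prior/decoder checkpoints use the much larger defaults above; the sizes chosen here are arbitrary):

```python
import torch
from diffusers.models import StableCascadeUNet  # exposed by the models __init__ changes above

# Tiny, illustrative configuration; not the shapes of any released checkpoint.
unet = StableCascadeUNet(
    in_channels=4,
    out_channels=4,
    timestep_ratio_embedding_dim=8,
    conditioning_dim=32,
    block_out_channels=(32, 64),
    num_attention_heads=(4, 4),
    down_num_layers_per_block=(1, 1),
    up_num_layers_per_block=(1, 1),
    clip_text_pooled_in_channels=16,
    dropout=(0.0, 0.0),
)

sample = torch.randn(1, 4, 16, 16)        # latent input
timestep_ratio = torch.tensor([0.5])      # continuous timestep in [0, 1]
clip_text_pooled = torch.randn(1, 1, 16)  # pooled text embedding, [batch, seq, dim]

out = unet(sample, timestep_ratio, clip_text_pooled).sample
print(out.shape)  # torch.Size([1, 4, 16, 16])
```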
@@ -176,6 +176,11 @@ else:
    _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    _import_structure["stable_cascade"] = [
        "StableCascadeCombinedPipeline",
        "StableCascadeDecoderPipeline",
        "StableCascadePriorPipeline",
    ]
    _import_structure["stable_diffusion"].extend(
        [
            "CLIPImageProjection",
@@ -424,6 +429,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from .pixart_alpha import PixArtAlphaPipeline
        from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
        from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
        from .stable_cascade import (
            StableCascadeCombinedPipeline,
            StableCascadeDecoderPipeline,
            StableCascadePriorPipeline,
        )
        from .stable_diffusion import (
            CLIPImageProjection,
            StableDiffusionDepth2ImgPipeline,
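With these exports in place, the two-stage pipeline can be driven as below; a sketch assuming the reference checkpoints published as `stabilityai/stable-cascade-prior` and `stabilityai/stable-cascade` (repo ids, dtypes, and guidance values are assumptions, not part of this diff):

```python
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prompt = "an astronaut riding a horse, photorealistic"

# Stage C (prior) produces image embeddings in the highly compressed latent space.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
prior_output = prior(prompt=prompt, num_inference_steps=20, guidance_scale=4.0)

# Stage B (decoder) turns those embeddings into an image, decoded to pixels by Stage A.
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("cuda")
image = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    num_inference_steps=10,
    guidance_scale=0.0,
).images[0]
image.save("stable_cascade.png")
```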
508
src/diffusers/pipelines/pipeline_loading_utils.py
Normal file
@@ -0,0 +1,508 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
model_info,
|
||||
)
|
||||
from packaging import version
|
||||
|
||||
from ..utils import (
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
get_class_from_dynamic_module,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
||||
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
||||
CONNECTED_PIPES_KEYS = ["prior"]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"onnxruntime.training": {
|
||||
"ORTModule": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
ALL_IMPORTABLE_CLASSES = {}
|
||||
for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
||||
files to know which safetensors files are needed.
|
||||
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
||||
|
||||
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
||||
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
||||
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
pt_filenames = []
|
||||
|
||||
sf_filenames = set()
|
||||
|
||||
passed_components = passed_components or []
|
||||
|
||||
for filename in filenames:
|
||||
_, extension = os.path.splitext(filename)
|
||||
|
||||
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
|
||||
continue
|
||||
|
||||
if extension == ".bin":
|
||||
pt_filenames.append(os.path.normpath(filename))
|
||||
elif extension == ".safetensors":
|
||||
sf_filenames.add(os.path.normpath(filename))
|
||||
|
||||
for filename in pt_filenames:
|
||||
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
|
||||
path, filename = os.path.split(filename)
|
||||
filename, extension = os.path.splitext(filename)
|
||||
|
||||
if filename.startswith("pytorch_model"):
|
||||
filename = filename.replace("pytorch_model", "model")
|
||||
else:
|
||||
filename = filename
|
||||
|
||||
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
|
||||
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
||||
if expected_sf_filename not in sf_filenames:
|
||||
logger.warning(f"{expected_sf_filename} not found")
|
||||
return False
|
||||
|
||||
return True
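# Usage sketch (illustrative, not from the diff): `is_safetensors_compatible` returns True
# only when every ".bin" weight has a safetensors counterpart; transformers-style
# "pytorch_model.bin" maps to "model.safetensors".
#
# example_filenames = [
#     "unet/diffusion_pytorch_model.bin",
#     "unet/diffusion_pytorch_model.safetensors",
#     "text_encoder/pytorch_model.bin",
#     "text_encoder/model.safetensors",
# ]
# assert is_safetensors_compatible(example_filenames)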
|
||||
|
||||
|
||||
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
|
||||
# model_pytorch, diffusion_model_pytorch, ...
|
||||
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
# -00001-of-00002
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
|
||||
if variant is not None:
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
|
||||
variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.fp16.json`
|
||||
variant_index_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
||||
non_variant_file_re = re.compile(
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
variant_filenames = variant_weights | variant_indexes
|
||||
else:
|
||||
variant_filenames = set()
|
||||
|
||||
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
||||
non_variant_filenames = non_variant_weights | non_variant_indexes
|
||||
|
||||
# all variant filenames will be used by default
|
||||
usable_filenames = set(variant_filenames)
|
||||
|
||||
def convert_to_variant(filename):
|
||||
if "index" in filename:
|
||||
variant_filename = filename.replace("index", f"index.{variant}")
|
||||
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
||||
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
||||
else:
|
||||
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
||||
return variant_filename
|
||||
|
||||
for f in non_variant_filenames:
|
||||
variant_filename = convert_to_variant(f)
|
||||
if variant_filename not in usable_filenames:
|
||||
usable_filenames.add(f)
|
||||
|
||||
return usable_filenames, variant_filenames
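# Usage sketch (illustrative, not from the diff): selecting fp16 variant weights while
# falling back to non-variant files for components that have no fp16 copy.
#
# filenames = {
#     "unet/diffusion_pytorch_model.fp16.safetensors",
#     "unet/diffusion_pytorch_model.safetensors",
#     "vae/diffusion_pytorch_model.safetensors",
# }
# usable, variants = variant_compatible_siblings(filenames, variant="fp16")
# # usable   -> the fp16 unet file plus the non-variant vae file
# # variants -> only the fp16 unet file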
|
||||
|
||||
|
||||
@validate_hf_hub_args
|
||||
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
token=token,
|
||||
revision=None,
|
||||
)
|
||||
filenames = {sibling.rfilename for sibling in info.siblings}
|
||||
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
|
||||
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
|
||||
|
||||
if set(model_filenames).issubset(set(comp_model_filenames)):
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap_model(model):
|
||||
"""Unwraps a model."""
|
||||
if is_compiled_module(model):
|
||||
model = model._orig_mod
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.base_model.model
|
||||
|
||||
return model
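# Usage sketch (illustrative, assumes torch>=2.0 for torch.compile):
#
# base = torch.nn.Linear(4, 4)
# compiled = torch.compile(base)
# assert _unwrap_model(compiled) is base  # the torch.compile wrapper is peeled off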
|
||||
|
||||
|
||||
def maybe_raise_or_warn(
|
||||
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
||||
):
|
||||
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
# Dynamo wraps the original model in a private class.
|
||||
# I didn't find a public API to get the original class.
|
||||
sub_model = passed_class_obj[name]
|
||||
unwrapped_sub_model = _unwrap_model(sub_model)
|
||||
model_cls = unwrapped_sub_model.__class__
|
||||
|
||||
if not issubclass(model_cls, expected_class_obj):
|
||||
raise ValueError(
|
||||
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
)
|
||||
|
||||
|
||||
def get_class_obj_and_candidates(
|
||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
|
||||
):
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name)
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
||||
# load custom component
|
||||
class_obj = get_class_from_dynamic_module(
|
||||
component_folder, module_file=library_name + ".py", class_name=class_name
|
||||
)
|
||||
class_candidates = {c: class_obj for c in importable_classes.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
return class_obj, class_candidates
|
||||
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj,
|
||||
config=None,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
hub_revision=None,
|
||||
class_name=None,
|
||||
cache_dir=None,
|
||||
revision=None,
|
||||
):
|
||||
if custom_pipeline is not None:
|
||||
if custom_pipeline.endswith(".py"):
|
||||
path = Path(custom_pipeline)
|
||||
# decompose into folder & file
|
||||
file_name = path.name
|
||||
custom_pipeline = path.parent.absolute()
|
||||
elif repo_id is not None:
|
||||
file_name = f"{custom_pipeline}.py"
|
||||
custom_pipeline = repo_id
|
||||
else:
|
||||
file_name = CUSTOM_PIPELINE_FILE_NAME
|
||||
|
||||
if repo_id is not None and hub_revision is not None:
|
||||
# if we load the pipeline code from the Hub
|
||||
# make sure to overwrite the `revision`
|
||||
revision = hub_revision
|
||||
|
||||
return get_class_from_dynamic_module(
|
||||
custom_pipeline,
|
||||
module_file=file_name,
|
||||
class_name=class_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
if class_obj.__name__ != "DiffusionPipeline":
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
class_name = class_name or config["_class_name"]
|
||||
if not class_name:
|
||||
raise ValueError(
|
||||
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
||||
)
|
||||
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
pipeline_cls = getattr(diffusers_module, class_name)
|
||||
|
||||
if load_connected_pipeline:
|
||||
from .auto_pipeline import _get_connected_pipeline
|
||||
|
||||
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
|
||||
if connected_pipeline_cls is not None:
|
||||
logger.info(
|
||||
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
|
||||
)
|
||||
else:
|
||||
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
|
||||
|
||||
pipeline_cls = connected_pipeline_cls or pipeline_cls
|
||||
|
||||
return pipeline_cls
|
||||
|
||||
|
||||
def load_sub_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
importable_classes: List[Any],
|
||||
pipelines: Any,
|
||||
is_pipeline_module: bool,
|
||||
pipeline_class: Any,
|
||||
torch_dtype: torch.dtype,
|
||||
provider: Any,
|
||||
sess_options: Any,
|
||||
device_map: Optional[Union[Dict[str, torch.device], str]],
|
||||
max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
|
||||
offload_folder: Optional[Union[str, os.PathLike]],
|
||||
offload_state_dict: bool,
|
||||
model_variants: Dict[str, str],
|
||||
name: str,
|
||||
from_flax: bool,
|
||||
variant: str,
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
# retrieve class candidates
|
||||
class_obj, class_candidates = get_class_obj_and_candidates(
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
load_method_name = None
|
||||
# retrieve load method name
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
# if load method name is None, then we have a dummy module -> raise Error
|
||||
if load_method_name is None:
|
||||
none_module = class_obj.__module__
|
||||
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
|
||||
TRANSFORMERS_DUMMY_MODULES_FOLDER
|
||||
)
|
||||
if is_dummy_path and "dummy" in none_module:
|
||||
# call class_obj for nice error message of missing requirements
|
||||
class_obj()
|
||||
|
||||
raise ValueError(
|
||||
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
|
||||
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
# add kwargs to loading method
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
loading_kwargs = {}
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["max_memory"] = max_memory
|
||||
loading_kwargs["offload_folder"] = offload_folder
|
||||
loading_kwargs["offload_state_dict"] = offload_state_dict
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
# the following can be deleted once the minimum required `transformers` version
|
||||
# is higher than 4.27
|
||||
if (
|
||||
is_transformers_model
|
||||
and loading_kwargs["variant"] is not None
|
||||
and transformers_version < version.parse("4.27.0")
|
||||
):
|
||||
raise ImportError(
|
||||
f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
|
||||
)
|
||||
elif is_transformers_model and loading_kwargs["variant"] is None:
|
||||
loading_kwargs.pop("variant")
|
||||
|
||||
# if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
|
||||
if not (from_flax and is_transformers_model):
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
else:
|
||||
loading_kwargs["low_cpu_mem_usage"] = False
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
|
||||
else:
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _fetch_class_library_tuple(module):
|
||||
# import it here to avoid circular import
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
pipelines = getattr(diffusers_module, "pipelines")
|
||||
|
||||
# register the config from the original module, not the dynamo compiled one
|
||||
not_compiled_module = _unwrap_model(module)
|
||||
library = not_compiled_module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
module_path_items = not_compiled_module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = not_compiled_module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = not_compiled_module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = not_compiled_module.__class__.__name__
|
||||
|
||||
return (library, class_name)
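# Illustrative examples (assumptions, not from the diff): a diffusers model such as a
# UNet2DConditionModel resolves to ("diffusers", "UNet2DConditionModel"), while a
# transformers CLIPTextModel resolves to ("transformers", "CLIPTextModel").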
|
||||
File diff suppressed because it is too large
50
src/diffusers/pipelines/stable_cascade/__init__.py
Normal file
50
src/diffusers/pipelines/stable_cascade/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"]
|
||||
_import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"]
|
||||
_import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline
|
||||
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,465 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline
|
||||
|
||||
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
|
||||
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
|
||||
... ).to("cuda")
|
||||
>>> gen_pipe = StableCascadeDecoderPipeline.from_pretrained(
|
||||
... "stabilityai/stable-cascade", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> prior_output = prior_pipe(prompt)
|
||||
>>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt).images
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating images from the Stable Cascade model.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The CLIP tokenizer.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The CLIP text encoder.
|
||||
decoder ([`StableCascadeUNet`]):
|
||||
The Stable Cascade decoder unet.
|
||||
vqgan ([`PaellaVQModel`]):
|
||||
The VQGAN model.
|
||||
scheduler ([`DDPMWuerstchenScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
latent_dim_scale (float, `optional`, defaults to 10.67):
|
||||
Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
|
||||
height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
|
||||
width=int(24*10.67)=256 in order to match the training conditions.
|
||||
"""
|
||||
|
||||
unet_name = "decoder"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds_pooled",
|
||||
"negative_prompt_embeds",
|
||||
"image_embeddings",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: StableCascadeUNet,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
latent_dim_scale: float = 10.67,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
decoder=decoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
)
|
||||
self.register_to_config(latent_dim_scale=latent_dim_scale)
|
||||
|
||||
def prepare_latents(self, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler):
|
||||
batch_size, channels, height, width = image_embeddings.shape
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
4,
|
||||
int(height * self.config.latent_dim_scale),
|
||||
int(width * self.config.latent_dim_scale),
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
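# Shape sketch (illustrative, assuming the default latent_dim_scale of 10.67):
# image embeddings of shape (1, C, 24, 24) with num_images_per_prompt=1 give decoder
# latents of shape (1, 4, int(24 * 10.67), int(24 * 10.67)) = (1, 4, 256, 256).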
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
device,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = text_encoder_output.hidden_states[-1]
|
||||
if prompt_embeds_pooled is None:
|
||||
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=uncond_input.attention_mask.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
seq_len = negative_prompt_embeds_pooled.shape[1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
|
||||
dtype=self.text_encoder.dtype, device=device
|
||||
)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
# done duplicates
|
||||
|
||||
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 10,
|
||||
guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image_embeddings (`torch.FloatTensor` or `List[torch.FloatTensor]`):
Image embeddings either extracted from an image or generated by a prior model.
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 10):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
|
||||
`decoder_guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely
|
||||
linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `decoder_guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
# 0. Define commonly used variables
|
||||
device = self._execution_device
|
||||
dtype = self.decoder.dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
if isinstance(image_embeddings, list):
|
||||
image_embeddings = torch.cat(image_embeddings, dim=0)
|
||||
batch_size = image_embeddings.shape[0]
|
||||
|
||||
# 2. Encode caption
|
||||
if prompt_embeds is None and negative_prompt_embeds is None:
|
||||
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
prompt_embeds_pooled = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
|
||||
)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
if self.do_classifier_free_guidance
|
||||
else image_embeddings
|
||||
)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latents
|
||||
latents = self.prepare_latents(
|
||||
image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||
)
|
||||
|
||||
# 6. Run denoising loop
|
||||
self._num_timesteps = len(timesteps[:-1])
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
|
||||
# 7. Denoise latents
|
||||
predicted_latents = self.decoder(
|
||||
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
||||
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
|
||||
clip_text_pooled=prompt_embeds_pooled,
|
||||
effnet=effnet,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# 8. Check for classifier free guidance and apply it
|
||||
if self.do_classifier_free_guidance:
|
||||
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
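# torch.lerp(a, b, w) = a + w * (b - a), i.e. the usual classifier-free guidance
# update: uncond + guidance_scale * (text - uncond).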
|
||||
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
|
||||
|
||||
# 9. Renoise latents to next timestep
|
||||
latents = self.scheduler.step(
|
||||
model_output=predicted_latents,
|
||||
timestep=timestep_ratio,
|
||||
sample=latents,
|
||||
generator=generator,
|
||||
).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if output_type not in ["pt", "np", "pil", "latent"]:
|
||||
raise ValueError(
|
||||
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
|
||||
)
|
||||
|
||||
if not output_type == "latent":
|
||||
# 10. Scale and decode the image latents with vq-vae
|
||||
latents = self.vqgan.config.scale_factor * latents
|
||||
images = self.vqgan.decode(latents).sample.clamp(0, 1)
|
||||
if output_type == "np":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy()  # float() because bfloat16 -> numpy doesn't work
|
||||
elif output_type == "pil":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy()  # float() because bfloat16 -> numpy doesn't work
|
||||
images = self.numpy_to_pil(images)
|
||||
else:
|
||||
images = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return images
|
||||
return ImagePipelineOutput(images)
|
||||
@@ -0,0 +1,294 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
||||
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
|
||||
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
|
||||
|
||||
|
||||
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
>>> from diffusers import StableCascadeCombinedPipeline
|
||||
|
||||
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
|
||||
... "cuda"
|
||||
... )
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> images = pipe(prompt=prompt).images
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableCascadeCombinedPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Combined Pipeline for text-to-image generation using Stable Cascade.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
The decoder tokenizer to be used for text inputs.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The decoder text encoder to be used for text inputs.
|
||||
decoder (`StableCascadeUNet`):
|
||||
The decoder model to be used for decoder image generation pipeline.
|
||||
scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for decoder image generation pipeline.
|
||||
vqgan (`PaellaVQModel`):
|
||||
The VQGAN model to be used for decoder image generation pipeline.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
prior_prior (`StableCascadeUNet`):
|
||||
The prior model to be used for prior pipeline.
|
||||
prior_scheduler (`DDPMWuerstchenScheduler`):
|
||||
The scheduler to be used for prior pipeline.
|
||||
"""
|
||||
|
||||
_load_connected_pipes = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
decoder: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
vqgan: PaellaVQModel,
|
||||
prior_prior: StableCascadeUNet,
|
||||
prior_text_encoder: CLIPTextModel,
|
||||
prior_tokenizer: CLIPTokenizer,
|
||||
prior_scheduler: DDPMWuerstchenScheduler,
|
||||
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
prior_text_encoder=prior_text_encoder,
|
||||
prior_tokenizer=prior_tokenizer,
|
||||
prior_prior=prior_prior,
|
||||
prior_scheduler=prior_scheduler,
|
||||
prior_feature_extractor=prior_feature_extractor,
|
||||
prior_image_encoder=prior_image_encoder,
|
||||
)
|
||||
self.prior_pipe = StableCascadePriorPipeline(
|
||||
prior=prior_prior,
|
||||
text_encoder=prior_text_encoder,
|
||||
tokenizer=prior_tokenizer,
|
||||
scheduler=prior_scheduler,
|
||||
image_encoder=prior_image_encoder,
|
||||
feature_extractor=prior_feature_extractor,
|
||||
)
|
||||
self.decoder_pipe = StableCascadeDecoderPipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
decoder=decoder,
|
||||
scheduler=scheduler,
|
||||
vqgan=vqgan,
|
||||
)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
||||
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
||||
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
||||
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
|
||||
|
||||
def progress_bar(self, iterable=None, total=None):
|
||||
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
||||
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
self.prior_pipe.set_progress_bar_config(**kwargs)
|
||||
self.decoder_pipe.set_progress_bar_config(**kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
prior_num_inference_steps: int = 60,
|
||||
prior_timesteps: Optional[List[float]] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
num_inference_steps: int = 12,
|
||||
decoder_timesteps: Optional[List[float]] = None,
|
||||
decoder_guidance_scale: float = 0.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation for the prior and decoder.
|
||||
images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*):
|
||||
The images to guide the image generation for the prior.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
|
||||
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
|
||||
`prior_guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked
|
||||
to the text `prompt`, usually at the expense of lower image quality.
|
||||
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
|
||||
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`prior_timesteps`
|
||||
num_inference_steps (`int`, *optional*, defaults to 12):
|
||||
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
|
||||
the expense of slower inference. For more specific timestep spacing, you can pass customized
|
||||
`timesteps`
|
||||
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
prior_callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
|
||||
int, callback_kwargs: Dict)`.
|
||||
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
|
||||
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
|
||||
the `._callback_tensor_inputs` attribute of your pipeline class.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is `True`,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
prior_outputs = self.prior_pipe(
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
images=images,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=prior_num_inference_steps,
|
||||
guidance_scale=prior_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type="pt",
|
||||
return_dict=True,
|
||||
callback_on_step_end=prior_callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
|
||||
)
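# Note: output_type="pt" keeps the prior outputs as tensors, so the image embeddings (and any
# reusable text embeddings) can be handed directly to the decoder pipeline below.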
|
||||
image_embeddings = prior_outputs.image_embeddings
|
||||
prompt_embeds = prior_outputs.get("prompt_embeds", None)
|
||||
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
|
||||
|
||||
outputs = self.decoder_pipe(
|
||||
image_embeddings=image_embeddings,
|
||||
prompt=prompt if prompt_embeds is None else None,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=decoder_guidance_scale,
|
||||
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
generator=generator,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
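For orientation, a minimal usage sketch of the combined pipeline that mirrors the arguments documented above. This is a sketch only: the checkpoint id is hypothetical (it is not named anywhere in this diff), so substitute whichever combined Stable Cascade checkpoint you actually use.

```py
import torch
from diffusers import StableCascadeCombinedPipeline

# "stabilityai/stable-cascade" is an assumed repo id, not confirmed by this diff.
pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    prompt="an image of a shiba inu, donning a spacesuit and helmet",
    prior_guidance_scale=4.0,      # Stage C (prior) guidance
    prior_num_inference_steps=20,  # Stage C denoising steps
    decoder_guidance_scale=0.0,    # Stage B (decoder) guidance
    num_inference_steps=12,        # Stage B denoising steps
    height=1024,
    width=1024,
).images[0]
```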
@@ -0,0 +1,614 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import StableCascadeUNet
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
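# A sketch of what this evaluates to: 30 values in total, 20 spaced evenly from 1.0 down to
# 2/3 followed by 10 from 2/3 down to 0.0, i.e. finer spacing in the high-noise region. The
# name suggests it is intended as a custom timestep schedule for Stage C (the prior).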
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableCascadePriorPipeline
|
||||
|
||||
>>> prior_pipe = StableCascadePriorPipeline.from_pretrained(
|
||||
... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
|
||||
>>> prior_output = prior_pipe(prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableCascadePriorPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for StableCascadePriorPipeline.
|
||||
|
||||
Args:
|
||||
image_embeddings (`torch.FloatTensor` or `np.ndarray`):
|
||||
Prior image embeddings for the text prompt.
|
||||
prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings for the prompt.
|
||||
negative_prompt_embeds (`torch.FloatTensor`):
|
||||
Text embeddings for the negative prompt.
|
||||
"""
|
||||
|
||||
image_embeddings: Union[torch.FloatTensor, np.ndarray]
|
||||
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
|
||||
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
|
||||
|
||||
|
||||
class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating image prior for Stable Cascade.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
prior ([`StableCascadeUNet`]):
|
||||
The Stable Cascade prior to approximate the image embedding from the text and/or image embedding.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`DDPMWuerstchenScheduler`]):
|
||||
A scheduler to be used in combination with `prior` to generate image embedding.
|
||||
resolution_multiple (`float`, *optional*, defaults to 42.67):
|
||||
Spatial downscaling factor used to compute the latent resolution from the requested image resolution.
|
||||
"""
|
||||
|
||||
unet_name = "prior"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "image_encoder->text_encoder->prior"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
prior: StableCascadeUNet,
|
||||
scheduler: DDPMWuerstchenScheduler,
|
||||
resolution_multiple: float = 42.67,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
prior=prior,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(resolution_multiple=resolution_multiple)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler
|
||||
):
|
||||
latent_shape = (
|
||||
num_images_per_prompt * batch_size,
|
||||
self.prior.config.in_channels,
|
||||
ceil(height / self.config.resolution_multiple),
|
||||
ceil(width / self.config.resolution_multiple),
|
||||
)
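# Example of the shape math above (a sketch with the default resolution_multiple=42.67):
# height = width = 1024 gives ceil(1024 / 42.67) = 24, so each image gets an
# (in_channels, 24, 24) latent grid.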
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != latent_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
device,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = text_encoder_output.hidden_states[-1]
|
||||
if prompt_embeds_pooled is None:
|
||||
prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if negative_prompt_embeds is None and do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=uncond_input.attention_mask.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
seq_len = negative_prompt_embeds_pooled.shape[1]
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to(
|
||||
dtype=self.text_encoder.dtype, device=device
|
||||
)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
# done duplicates
|
||||
|
||||
return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled
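# Note: the four tensors returned here can be passed back into `__call__` via `prompt_embeds`,
# `prompt_embeds_pooled`, `negative_prompt_embeds` and `negative_prompt_embeds_pooled` to avoid
# re-encoding the same prompt on repeated calls.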
|
||||
|
||||
def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt):
|
||||
image_embeds = []
|
||||
for image in images:
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embed = self.image_encoder(image).image_embeds.unsqueeze(1)
|
||||
image_embeds.append(image_embed)
|
||||
image_embeds = torch.cat(image_embeds, dim=1)
|
||||
|
||||
image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
|
||||
negative_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, negative_image_embeds
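# The zero tensor returned as `negative_image_embeds` serves as the unconditional image
# embedding; it is concatenated with `image_embeds` further down when classifier-free
# guidance is enabled.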
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
images=None,
|
||||
image_embeds=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
prompt_embeds_pooled=None,
|
||||
negative_prompt_embeds=None,
|
||||
negative_prompt_embeds_pooled=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
|
||||
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed"
|
||||
f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !="
|
||||
f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}."
|
||||
)
|
||||
|
||||
if image_embeds is not None and images is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
|
||||
if images:
|
||||
for i, image in enumerate(images):
|
||||
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
||||
raise TypeError(
|
||||
f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image, but got"
|
||||
f"{type(image)} for image number {i}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
def get_t_condioning(self, t, alphas_cumprod):
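# Descriptive note: maps a discrete scheduler timestep `t` to a continuous conditioning
# ratio in [0, 1] for the prior by (effectively) inverting a cosine-style noise schedule
# with a small offset s=0.003, using the scheduler's alphas_cumprod.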
|
||||
s = torch.tensor([0.003])
|
||||
clamp_range = [0, 1]
|
||||
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
||||
var = alphas_cumprod[t]
|
||||
var = var.clamp(*clamp_range)
|
||||
s, min_var = s.to(var.device), min_var.to(var.device)
|
||||
ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return ratio
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
num_inference_steps: int = 20,
|
||||
timesteps: List[float] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
|
||||
image_embeds: Optional[torch.FloatTensor] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pt",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 20):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
|
||||
linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
|
||||
argument.
|
||||
image_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting.
|
||||
If not provided, image embeddings will be generated from the `images` input argument, if provided.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`StableCascadePriorPipelineOutput`] or `tuple`: [`StableCascadePriorPipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated image embeddings.
|
||||
"""
|
||||
|
||||
# 0. Define commonly used variables
|
||||
device = self._execution_device
|
||||
dtype = next(self.prior.parameters()).dtype
|
||||
self._guidance_scale = guidance_scale
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
images=images,
|
||||
image_embeds=image_embeds,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Encode caption + images
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_embeds_pooled,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_pooled=prompt_embeds_pooled,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
|
||||
)
|
||||
|
||||
if images is not None:
|
||||
image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image(
|
||||
images=images,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
elif image_embeds is not None:
|
||||
image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1)
|
||||
uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled)
|
||||
else:
|
||||
image_embeds_pooled = torch.zeros(
|
||||
batch_size * num_images_per_prompt,
|
||||
1,
|
||||
self.prior.config.clip_image_in_channels,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
uncond_image_embeds_pooled = torch.zeros(
|
||||
batch_size * num_images_per_prompt,
|
||||
1,
|
||||
self.prior.config.clip_image_in_channels,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0)
|
||||
else:
|
||||
image_embeds = image_embeds_pooled
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_encoder_hidden_states = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
|
||||
)
|
||||
text_encoder_pooled = (
|
||||
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
|
||||
if negative_prompt_embeds is not None
|
||||
else prompt_embeds_pooled
|
||||
)
|
||||
|
||||
# 4. Prepare and set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latents
|
||||
latents = self.prepare_latents(
|
||||
batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||
)
|
||||
|
||||
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
timesteps = timesteps[:-1]
|
||||
else:
|
||||
if self.scheduler.config.clip_sample:
|
||||
self.scheduler.config.clip_sample = False # disable sample clipping
|
||||
logger.warning(" set `clip_sample` to be False")
|
||||
# 6. Run denoising loop
|
||||
if hasattr(self.scheduler, "betas"):
|
||||
alphas = 1.0 - self.scheduler.betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
else:
|
||||
alphas_cumprod = []
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
if len(alphas_cumprod) > 0:
|
||||
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
|
||||
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
|
||||
else:
|
||||
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
|
||||
else:
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
# 7. Denoise image embeddings
|
||||
predicted_image_embedding = self.prior(
|
||||
sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
||||
timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio,
|
||||
clip_text_pooled=text_encoder_pooled,
|
||||
clip_text=text_encoder_hidden_states,
|
||||
clip_img=image_embeds,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# 8. Check for classifier free guidance and apply it
|
||||
if self.do_classifier_free_guidance:
|
||||
predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
|
||||
predicted_image_embedding = torch.lerp(
|
||||
predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
|
||||
)
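# torch.lerp(a, b, w) computes a + w * (b - a); with w = guidance_scale this is the usual
# classifier-free guidance update: uncond + guidance_scale * (text - uncond).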
|
||||
|
||||
# 9. Renoise latents to next timestep
|
||||
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
timestep_ratio = t
|
||||
latents = self.scheduler.step(
|
||||
model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator
|
||||
).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if output_type == "np":
|
||||
latents = latents.cpu().float().numpy()  # .float() because numpy does not support bfloat16
|
||||
prompt_embeds = prompt_embeds.cpu().float().numpy()  # .float() because numpy does not support bfloat16
|
||||
negative_prompt_embeds = (
|
||||
negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None
|
||||
)  # .float() because numpy does not support bfloat16
|
||||
|
||||
if not return_dict:
|
||||
return (latents, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
|
||||
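For reference, a minimal two-stage sketch that wires the prior output into the decoder. This is a sketch only: the repo ids are taken from the example docstring above and from the integration test further down in this diff, and may not match the final released checkpoints.

```py
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "an image of a shiba inu, donning a spacesuit and helmet"

# Stage C: denoise the highly compressed image embeddings.
prior_output = prior(prompt=prompt, guidance_scale=4.0, num_inference_steps=20)

# Stage B: decode those embeddings into an image.
image = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=10,
    output_type="pil",
).images[0]
```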
@@ -322,7 +322,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
if isinstance(self.guidance_scale, (int, float)):
|
||||
return self.guidance_scale
|
||||
return self.guidance_scale > 1
|
||||
return self.guidance_scale.max() > 1
|
||||
|
||||
@property
|
||||
|
||||
@@ -1,18 +1,3 @@
|
||||
# Copyright (c) 2023 Dominic Rampas MIT License
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@@ -233,7 +233,7 @@ class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
|
||||
|
||||
|
||||
class ResBlockStageB(nn.Module):
|
||||
def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
|
||||
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
||||
super().__init__()
|
||||
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
||||
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
@@ -349,6 +349,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
text_encoder_hidden_states = (
|
||||
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
|
||||
)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
if self.do_classifier_free_guidance
|
||||
else image_embeddings
|
||||
)
|
||||
|
||||
# 3. Determine latent shape of latents
|
||||
latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale)
|
||||
@@ -371,11 +376,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
self._num_timesteps = len(timesteps[:-1])
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
ratio = t.expand(latents.size(0)).to(dtype)
|
||||
effnet = (
|
||||
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
|
||||
if self.do_classifier_free_guidance
|
||||
else image_embeddings
|
||||
)
|
||||
# 7. Denoise latents
|
||||
predicted_latents = self.decoder(
|
||||
torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
|
||||
@@ -423,9 +423,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
latents = self.vqgan.config.scale_factor * latents
|
||||
images = self.vqgan.decode(latents).sample.clamp(0, 1)
|
||||
if output_type == "np":
|
||||
images = images.permute(0, 2, 3, 1).cpu().numpy()
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy()
|
||||
elif output_type == "pil":
|
||||
images = images.permute(0, 2, 3, 1).cpu().numpy()
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy()
|
||||
images = self.numpy_to_pil(images)
|
||||
else:
|
||||
images = latents
|
||||
|
||||
@@ -508,7 +508,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if output_type == "np":
|
||||
latents = latents.cpu().numpy()
|
||||
latents = latents.cpu().float().numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (latents,)
|
||||
|
||||
@@ -19,7 +19,6 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from .constants import (
|
||||
_ACCEPTED_SINGLE_FILE_FORMATS,
|
||||
CONFIG_NAME,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
@@ -84,7 +83,7 @@ from .import_utils import (
|
||||
is_xformers_available,
|
||||
requires_backends,
|
||||
)
|
||||
from .loading_utils import is_single_file_checkpoint, load_image
|
||||
from .loading_utils import load_image
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
|
||||
@@ -37,7 +37,6 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
|
||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
@@ -752,6 +752,51 @@ class ShapEPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableCascadeCombinedPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableCascadeDecoderPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableCascadePriorPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionAdapterPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -1,28 +1,10 @@
|
||||
import os
|
||||
from typing import Callable, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import requests
|
||||
|
||||
from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
|
||||
|
||||
|
||||
def is_single_file_checkpoint(filepath):
|
||||
def is_valid_url(url):
|
||||
result = urlparse(url)
|
||||
if result.scheme and result.netloc:
|
||||
return True
|
||||
|
||||
filepath = str(filepath)
|
||||
if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
|
||||
if is_valid_url(filepath):
|
||||
return True
|
||||
elif os.path.isfile(filepath):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load_image(
|
||||
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
|
||||
|
||||
tests/pipelines/stable_cascade/__init__.py (new file, 0 lines)
tests/pipelines/stable_cascade/test_stable_cascade_combined.py (new file, 246 lines)
@@ -0,0 +1,246 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableCascadeCombinedPipeline
|
||||
params = ["prompt"]
|
||||
batch_params = ["prompt", "negative_prompt"]
|
||||
required_optional_params = [
|
||||
"generator",
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"prior_guidance_scale",
|
||||
"decoder_guidance_scale",
|
||||
"negative_prompt",
|
||||
"num_inference_steps",
|
||||
"return_dict",
|
||||
"prior_num_inference_steps",
|
||||
"output_type",
|
||||
]
|
||||
test_xformers_attention = True
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def dummy_prior(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": (128, 128),
|
||||
"num_attention_heads": (2, 2),
|
||||
"down_num_layers_per_block": (1, 1),
|
||||
"up_num_layers_per_block": (1, 1),
|
||||
"clip_image_in_channels": 768,
|
||||
"switch_level": (False,),
|
||||
"clip_text_in_channels": self.text_embedder_hidden_size,
|
||||
"clip_text_pooled_in_channels": self.text_embedder_hidden_size,
|
||||
}
|
||||
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModelWithProjection(config).eval()
|
||||
|
||||
@property
|
||||
def dummy_vqgan(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"bottleneck_blocks": 1,
|
||||
"num_vq_embeddings": 2,
|
||||
}
|
||||
model = PaellaVQModel(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
@property
|
||||
def dummy_decoder(self):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": (16, 32, 64, 128),
|
||||
"num_attention_heads": (-1, -1, 1, 2),
|
||||
"down_num_layers_per_block": (1, 1, 1, 1),
|
||||
"up_num_layers_per_block": (1, 1, 1, 1),
|
||||
"down_blocks_repeat_mappers": (1, 1, 1, 1),
|
||||
"up_blocks_repeat_mappers": (3, 3, 2, 2),
|
||||
"block_types_per_layer": (
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
|
||||
),
|
||||
"switch_level": None,
|
||||
"clip_text_pooled_in_channels": 32,
|
||||
"dropout": (0.1, 0.1, 0.1, 0.1),
|
||||
}
|
||||
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
def get_dummy_components(self):
|
||||
prior = self.dummy_prior
|
||||
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
tokenizer = self.dummy_tokenizer
|
||||
text_encoder = self.dummy_text_encoder
|
||||
decoder = self.dummy_decoder
|
||||
vqgan = self.dummy_vqgan
|
||||
prior_text_encoder = self.dummy_text_encoder
|
||||
prior_tokenizer = self.dummy_tokenizer
|
||||
|
||||
components = {
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"decoder": decoder,
|
||||
"scheduler": scheduler,
|
||||
"vqgan": vqgan,
|
||||
"prior_text_encoder": prior_text_encoder,
|
||||
"prior_tokenizer": prior_tokenizer,
|
||||
"prior_prior": prior,
|
||||
"prior_scheduler": scheduler,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"prior_guidance_scale": 4.0,
|
||||
"decoder_guidance_scale": 4.0,
|
||||
"num_inference_steps": 2,
|
||||
"prior_num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
"height": 128,
|
||||
"width": 128,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_cascade(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(device))
|
||||
image = output.images
|
||||
|
||||
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
assert (
|
||||
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
|
||||
|
||||
@require_torch_gpu
|
||||
def test_offloads(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_sequential_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_model_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=2e-2)
|
||||
|
||||
@unittest.skip(reason="fp16 not supported")
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference()
|
||||
|
||||
@unittest.skip(reason="no callback test for combined pipeline")
|
||||
def test_callback_inputs(self):
|
||||
super().test_callback_inputs()
|
||||
|
||||
# def test_callback_cfg(self):
|
||||
# pass
|
||||
# pass
|
||||
tests/pipelines/stable_cascade/test_stable_cascade_decoder.py (new file, 249 lines)
@@ -0,0 +1,249 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
load_pt,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableCascadeDecoderPipeline
|
||||
params = ["prompt"]
|
||||
batch_params = ["image_embeddings", "prompt", "negative_prompt"]
|
||||
required_optional_params = [
|
||||
"num_images_per_prompt",
|
||||
"num_inference_steps",
|
||||
"latents",
|
||||
"negative_prompt",
|
||||
"guidance_scale",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
]
|
||||
test_xformers_attention = False
|
||||
callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
|
||||
|
||||
@property
|
||||
def text_embedder_hidden_size(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def time_input_dim(self):
|
||||
return 32
|
||||
|
||||
@property
|
||||
def block_out_channels_0(self):
|
||||
return self.time_input_dim
|
||||
|
||||
@property
|
||||
def time_embed_dim(self):
|
||||
return self.time_input_dim * 4
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
projection_dim=self.text_embedder_hidden_size,
|
||||
hidden_size=self.text_embedder_hidden_size,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModelWithProjection(config).eval()
|
||||
|
||||
@property
|
||||
def dummy_vqgan(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
model_kwargs = {
|
||||
"bottleneck_blocks": 1,
|
||||
"num_vq_embeddings": 2,
|
||||
}
|
||||
model = PaellaVQModel(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
@property
|
||||
def dummy_decoder(self):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"conditioning_dim": 128,
|
||||
"block_out_channels": [16, 32, 64, 128],
|
||||
"num_attention_heads": [-1, -1, 1, 2],
|
||||
"down_num_layers_per_block": [1, 1, 1, 1],
|
||||
"up_num_layers_per_block": [1, 1, 1, 1],
|
||||
"down_blocks_repeat_mappers": [1, 1, 1, 1],
|
||||
"up_blocks_repeat_mappers": [3, 3, 2, 2],
|
||||
"block_types_per_layer": [
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
|
||||
],
|
||||
"switch_level": None,
|
||||
"clip_text_pooled_in_channels": 32,
|
||||
"dropout": [0.1, 0.1, 0.1, 0.1],
|
||||
}
|
||||
model = StableCascadeUNet(**model_kwargs)
|
||||
return model.eval()
|
||||
|
||||
def get_dummy_components(self):
|
||||
decoder = self.dummy_decoder
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
vqgan = self.dummy_vqgan
|
||||
|
||||
scheduler = DDPMWuerstchenScheduler()
|
||||
|
||||
components = {
|
||||
"decoder": decoder,
|
||||
"vqgan": vqgan,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"latent_dim_scale": 4.0,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"image_embeddings": torch.ones((1, 4, 4, 4), device=device),
|
||||
"prompt": "horse",
|
||||
"generator": generator,
|
||||
"guidance_scale": 2.0,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_wuerstchen_decoder(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
output = pipe(**self.get_dummy_inputs(device))
|
||||
image = output.images
|
||||
|
||||
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@skip_mps
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
test_max_difference = torch_device == "cpu"
|
||||
test_mean_pixel_difference = False
|
||||
|
||||
self._test_attention_slicing_forward_pass(
|
||||
test_max_difference=test_max_difference,
|
||||
test_mean_pixel_difference=test_mean_pixel_difference,
|
||||
)
|
||||
|
||||
@unittest.skip(reason="fp16 not supported")
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_cascade_decoder(self):
|
||||
pipe = StableCascadeDecoderPipeline.from_pretrained(
|
||||
"diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_embedding = load_pt(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator
|
||||
).images[0]
|
||||
|
||||
assert image.size == (1024, 1024)
|
||||
|
||||
expected_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png"
|
||||
)
|
||||
|
||||
image_processor = VaeImageProcessor()
|
||||
|
||||
image_np = image_processor.pil_to_numpy(image)
|
||||
expected_image_np = image_processor.pil_to_numpy(expected_image)
|
||||
|
||||
self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2))
|
||||
tests/pipelines/stable_cascade/test_stable_cascade_prior.py (new file, 308 lines)
@@ -0,0 +1,308 @@
|
||||
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models import StableCascadeUNet
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    load_pt,
    require_peft_backend,
    require_torch_gpu,
    skip_mps,
    slow,
    torch_device,
)


if is_peft_available():
    from peft import LoraConfig
    from peft.tuners.tuners_utils import BaseTunerLayer

from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()
def create_prior_lora_layers(unet: nn.Module):
    # build one LoRA attention processor per attention layer of the prior,
    # picking the PyTorch 2.0 (SDPA) variant when it is available
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        lora_attn_processor_class = (
            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        )
        lora_attn_procs[name] = lora_attn_processor_class(
            hidden_size=unet.config.c,
        )
    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
    return lora_attn_procs, unet_lora_layers
class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableCascadePriorPipeline
    params = ["prompt"]
    batch_params = ["prompt", "negative_prompt"]
    required_optional_params = [
        "num_images_per_prompt",
        "generator",
        "num_inference_steps",
        "latents",
        "negative_prompt",
        "guidance_scale",
        "output_type",
        "return_dict",
    ]
    test_xformers_attention = False
    callback_cfg_params = ["text_encoder_hidden_states"]

    @property
    def text_embedder_hidden_size(self):
        return 32

    @property
    def time_input_dim(self):
        return 32

    @property
    def block_out_channels_0(self):
        return self.time_input_dim

    @property
    def time_embed_dim(self):
        return self.time_input_dim * 4

    @property
    def dummy_tokenizer(self):
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        return tokenizer

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=self.text_embedder_hidden_size,
            projection_dim=self.text_embedder_hidden_size,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModelWithProjection(config).eval()

    @property
    def dummy_prior(self):
        torch.manual_seed(0)

        model_kwargs = {
            "conditioning_dim": 128,
            "block_out_channels": (128, 128),
            "num_attention_heads": (2, 2),
            "down_num_layers_per_block": (1, 1),
            "up_num_layers_per_block": (1, 1),
            "switch_level": (False,),
            "clip_image_in_channels": 768,
            "clip_text_in_channels": self.text_embedder_hidden_size,
            "clip_text_pooled_in_channels": self.text_embedder_hidden_size,
            "dropout": (0.1, 0.1),
        }

        model = StableCascadeUNet(**model_kwargs)
        return model.eval()

    def get_dummy_components(self):
        prior = self.dummy_prior
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer

        scheduler = DDPMWuerstchenScheduler()

        components = {
            "prior": prior,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "scheduler": scheduler,
            "feature_extractor": None,
            "image_encoder": None,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "horse",
            "generator": generator,
            "guidance_scale": 4.0,
            "num_inference_steps": 2,
            "output_type": "np",
        }
        return inputs
    def test_wuerstchen_prior(self):
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device))
        image = output.image_embeddings

        image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]

        image_slice = image[0, 0, 0, -10:]
        image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
        assert image.shape == (1, 16, 24, 24)

        expected_slice = np.array(
            [
                96.139565,
                -20.213179,
                -116.40341,
                -191.57129,
                39.350136,
                74.80767,
                39.782352,
                -184.67352,
                -46.426907,
                168.41783,
            ]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2
    @skip_mps
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(expected_max_diff=2e-1)

    @skip_mps
    def test_attention_slicing_forward_pass(self):
        test_max_difference = torch_device == "cpu"
        test_mean_pixel_difference = False

        self._test_attention_slicing_forward_pass(
            test_max_difference=test_max_difference,
            test_mean_pixel_difference=test_mean_pixel_difference,
        )

    @unittest.skip(reason="fp16 not supported")
    def test_float16_inference(self):
        super().test_float16_inference()

    def check_if_lora_correctly_set(self, model) -> bool:
        """
        Checks if the LoRA layers are correctly set with peft
        """
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                return True
        return False

    def get_lora_components(self):
        prior = self.dummy_prior

        prior_lora_config = LoraConfig(
            r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
        )

        prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)

        lora_components = {
            "prior_lora_layers": prior_lora_layers,
            "prior_lora_attn_procs": prior_lora_attn_procs,
        }

        return prior, prior_lora_config, lora_components

    @require_peft_backend
    @unittest.skip(reason="no lora support for now")
    def test_inference_with_prior_lora(self):
        _, prior_lora_config, _ = self.get_lora_components()
        device = "cpu"

        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        output_no_lora = pipe(**self.get_dummy_inputs(device))
        image_embed = output_no_lora.image_embeddings
        self.assertTrue(image_embed.shape == (1, 16, 24, 24))

        pipe.prior.add_adapter(prior_lora_config)
        self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")

        output_lora = pipe(**self.get_dummy_inputs(device))
        lora_image_embed = output_lora.image_embeddings

        self.assertTrue(image_embed.shape == lora_image_embed.shape)
@slow
@require_torch_gpu
class StableCascadePriorPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_stable_cascade_prior(self):
        pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."

        generator = torch.Generator(device="cpu").manual_seed(0)

        output = pipe(prompt, num_inference_steps=10, generator=generator)
        image_embedding = output.image_embeddings

        expected_image_embedding = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
        )

        assert image_embedding.shape == (1, 16, 24, 24)

        self.assertTrue(
            np.allclose(
                image_embedding.cpu().float().numpy(), expected_image_embedding.cpu().float().numpy(), atol=5e-2
            )
        )
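The reference `image_embedding.pt` compared against here (and reused by the decoder integration test above) appears to be nothing more than a saved prior output; a minimal sketch of how such a regression fixture could be regenerated, with the local file name being purely illustrative:

```python
import torch

from diffusers import StableCascadePriorPipeline

pipe = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
    "A photograph of the inside of a subway train. There are raccoons sitting on the seats. "
    "One of them is reading a newspaper. The window shows the city in the background.",
    num_inference_steps=10,
    generator=generator,
)

# persist the embeddings so they can serve as a fixed reference for future runs
torch.save(output.image_embeddings, "image_embedding.pt")  # illustrative local path
```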
@@ -496,6 +496,22 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
        max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
        self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")

    def test_disable_cfg(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()

        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        inputs["max_guidance_scale"] = 1.0
        output = pipe(**inputs).frames
        self.assertEqual(len(output.shape), 5)


@slow
@require_torch_gpu
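The new `test_disable_cfg` above turns classifier-free guidance off by forcing `max_guidance_scale` to 1.0. A hedged usage sketch of the same idea against the public checkpoint (the conditioning-image URL is a placeholder):

```python
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

image = load_image("https://example.com/conditioning.png")  # placeholder conditioning image

# keeping both guidance bounds at 1.0 effectively disables classifier-free guidance
frames = pipe(image, min_guidance_scale=1.0, max_guidance_scale=1.0).frames[0]
```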
@@ -99,14 +99,13 @@ class SDFunctionTesterMixin:
         assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
 
     def test_vae_tiling(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
 
         # make sure here that pndm scheduler skips prk
         if "safety_checker" in components:
             components["safety_checker"] = None
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(device)
+        pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
@@ -126,7 +125,7 @@ class SDFunctionTesterMixin:
         # test that tiled decode works with various shapes
         shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
         for shape in shapes:
-            zeros = torch.zeros(shape).to(device)
+            zeros = torch.zeros(shape).to(torch_device)
             pipe.vae.decode(zeros)
 
     def test_freeu_enabled(self):
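For reference, the tiled decode path these shapes exercise is the same one users switch on through the pipeline API; a minimal sketch, assuming a standard Stable Diffusion checkpoint:

```python
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# decode latents tile by tile so large resolutions fit into memory
pipe.enable_vae_tiling()
image = pipe("a panoramic photo of a mountain range", height=1024, width=2048).images[0]
```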
@@ -45,7 +45,6 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
         "return_dict",
         "prior_num_inference_steps",
         "output_type",
-        "return_dict",
     ]
     test_xformers_attention = True