mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
7 Commits
pinned-con
...
test-clean
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05df661f57 | ||
|
|
731dd5f1d1 | ||
|
|
493757b52a | ||
|
|
3901942d1b | ||
|
|
b36081acf6 | ||
|
|
4a6a03a98b | ||
|
|
9a393a8b10 |
@@ -1,26 +1,34 @@
|
||||
from logging import getLogger
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, logging
|
||||
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from ..pipeline_utils import ImagePipelineOutput
|
||||
from . import StableDiffusionUpscalePipeline
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
NUM_LATENT_CHANNELS = 4
|
||||
NUM_UNET_INPUT_CHANNELS = 7
|
||||
|
||||
ORT_TO_PT_TYPE = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
@@ -45,7 +53,17 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
vae: OnnxRuntimeModel
|
||||
text_encoder: OnnxRuntimeModel
|
||||
tokenizer: CLIPTokenizer
|
||||
unet: OnnxRuntimeModel
|
||||
low_res_scheduler: DDPMScheduler
|
||||
scheduler: KarrasDiffusionSchedulers
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPImageProcessor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_is_onnx = True
|
||||
|
||||
def __init__(
|
||||
@@ -55,39 +73,296 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
tokenizer: Any,
|
||||
unet: OnnxRuntimeModel,
|
||||
low_res_scheduler: DDPMScheduler,
|
||||
scheduler: Any,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: Optional[OnnxRuntimeModel] = None,
|
||||
feature_extractor: Optional[CLIPImageProcessor] = None,
|
||||
max_noise_level: int = 350,
|
||||
num_latent_channels=4,
|
||||
num_unet_input_channels=7,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
watermarker=None,
|
||||
max_noise_level=max_noise_level,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(
|
||||
max_noise_level=max_noise_level,
|
||||
num_latent_channels=num_latent_channels,
|
||||
num_unet_input_channels=num_unet_input_channels,
|
||||
)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, np.ndarray)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
|
||||
if isinstance(image, list) or isinstance(image, np.ndarray):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
# check noise level
|
||||
if noise_level > self.config.max_noise_level:
|
||||
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
latents = generator.randn(*shape).astype(dtype)
|
||||
elif latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = self.vae(latent_sample=latents)[0]
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: Optional[int],
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[str],
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
prompt to be encoded
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`np.ndarray`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`np.ndarray`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
||||
|
||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
||||
|
||||
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
|
||||
image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
r"""
|
||||
@@ -108,7 +383,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
noise_level TODO
|
||||
noise_level (`float`, defaults to 0.2):
|
||||
Deteremines the amount of noise to add to the initial image before performing upscaling.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
@@ -152,7 +428,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
"""
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
noise_level,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -162,16 +446,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
if generator is None:
|
||||
generator = np.random
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
@@ -179,51 +463,55 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
image = preprocess(image).cpu().numpy()
|
||||
height, width = image.shape[2:]
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.cpu()
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.num_latent_channels,
|
||||
height,
|
||||
width,
|
||||
latents_dtype,
|
||||
generator,
|
||||
)
|
||||
image = image.astype(latents_dtype)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * np.float64(self.scheduler.init_noise_sigma)
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
noise_level = np.array([noise_level]).astype(np.int64)
|
||||
noise = generator.randn(*image.shape).astype(latents_dtype)
|
||||
|
||||
image = self.low_res_scheduler.add_noise(
|
||||
torch.from_numpy(image), torch.from_numpy(noise), torch.from_numpy(noise_level)
|
||||
)
|
||||
image = image.numpy()
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = np.concatenate([noise_level] * image.shape[0])
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
NUM_LATENT_CHANNELS,
|
||||
height,
|
||||
width,
|
||||
latents_dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
|
||||
if self.num_latent_channels + num_channels_image != self.num_unet_input_channels:
|
||||
raise ValueError(
|
||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
|
||||
f" {self.num_unet_input_channels} but received `num_channels_latents`: {self.num_latent_channels} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
|
||||
f" = {self.num_latent_channels + num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||
@@ -248,8 +536,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
class_labels=noise_level.astype(np.int64),
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
class_labels=noise_level,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
@@ -259,8 +547,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
|
||||
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
|
||||
).prev_sample
|
||||
latents = latents.numpy()
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
@@ -269,125 +558,28 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
image = self.decode_latents(latents.float())
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="np"
|
||||
).pixel_values.astype(image.dtype)
|
||||
|
||||
images, has_nsfw_concept = [], []
|
||||
for i in range(image.shape[0]):
|
||||
image_i, has_nsfw_concept_i = self.safety_checker(
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = self.vae(latent_sample=latents)[0]
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device,
|
||||
num_images_per_prompt: Optional[int],
|
||||
do_classifier_free_guidance: bool,
|
||||
negative_prompt: Optional[str],
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
# no positional arguments to text_encoder
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# if hasattr(uncond_input, "attention_mask"):
|
||||
# attention_mask = uncond_input.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
||||
uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
@@ -221,7 +221,7 @@ class ModelTesterMixin:
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
def test_getattr_is_correct(self):
|
||||
@@ -351,7 +351,7 @@ class ModelTesterMixin:
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
max_diff = (image - new_image).abs().max().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
@require_torch_2
|
||||
|
||||
@@ -137,7 +137,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_accelerate.config.in_channels,
|
||||
model_accelerate.config.sample_size,
|
||||
model_accelerate.config.sample_size,
|
||||
generator=torch.manual_seed(0),
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
)
|
||||
noise = noise.to(torch_device)
|
||||
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
|
||||
@@ -263,7 +263,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-4842.8691, -6499.6631, -3800.1953, -7978.2686, -10980.7129, -20028.8535, 8148.2822, 2342.2905, 567.7608])
|
||||
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
@@ -726,8 +726,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
model.disable_xformers_memory_efficient_attention()
|
||||
off_sample = model(**inputs_dict).sample
|
||||
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
assert (sample - on_sample).abs().max() <= 5e-4
|
||||
assert (sample - off_sample).abs().max() <= 5e-4
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
|
||||
@@ -1008,4 +1008,4 @@ class StableDiffusionMultiControlNetPipelineNightlyTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
|
||||
@@ -455,4 +455,4 @@ class ControlNetImg2ImgPipelineNightlyTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
|
||||
@@ -602,4 +602,4 @@ class ControlNetInpaintPipelineNightlyTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
|
||||
@@ -17,7 +17,6 @@ import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
@@ -49,7 +48,7 @@ class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
image = floats_tensor((1, 3, 128, 128), rng=random.Random(seed))
|
||||
generator = torch.manual_seed(seed)
|
||||
generator = np.random.RandomState(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"image": image,
|
||||
@@ -70,9 +69,7 @@ class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||
|
||||
# started as 128, should now be 512
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.6974782, 0.68902093, 0.70135885, 0.7583618, 0.7804545, 0.7854912, 0.78667426, 0.78743863, 0.78070223]
|
||||
)
|
||||
expected_slice = np.array([0.6957, 0.7002, 0.7186, 0.6881, 0.6693, 0.6910, 0.7445, 0.7274, 0.7056])
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_pndm(self):
|
||||
@@ -85,9 +82,7 @@ class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unitt
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.6898892, 0.59240556, 0.52499527, 0.58866215, 0.52258235, 0.52572715, 0.62414473, 0.6174387, 0.6214964]
|
||||
)
|
||||
expected_slice = np.array([0.7349, 0.7347, 0.7034, 0.7696, 0.7876, 0.7597, 0.7916, 0.8085, 0.8036])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_dpm_multistep(self):
|
||||
@@ -174,7 +169,7 @@ class OnnxStableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
generator = np.random.RandomState(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
@@ -211,7 +206,7 @@ class OnnxStableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
generator = np.random.RandomState(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
|
||||
@@ -122,7 +122,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
@@ -1543,7 +1543,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@require_torch_2
|
||||
def test_from_save_pretrained_dynamo(self):
|
||||
@@ -1568,7 +1568,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
def test_from_pretrained_hub_pass_model(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
@@ -1591,7 +1591,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
def test_output_format(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
@@ -1625,7 +1625,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
from diffusers import FlaxStableDiffusionPipeline
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe_pt.save_pretrained(tmpdirname)
|
||||
pipe_pt.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
|
||||
pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname, safety_checker=None, from_pt=True
|
||||
|
||||
@@ -76,7 +76,7 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
|
||||
|
||||
def test_inference_dual_guided(self):
|
||||
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
|
||||
|
||||
@@ -77,7 +77,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
|
||||
|
||||
def test_inference_dual_guided_then_text_to_image(self):
|
||||
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
|
||||
|
||||
@@ -64,7 +64,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy"
|
||||
).images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
|
||||
assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
|
||||
|
||||
def test_inference_text2img(self):
|
||||
pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user