Compare commits

...

7 Commits

Author SHA1 Message Date
Dhruv Nair
05df661f57 test differences in controlnet with max insteda of sum 2023-08-24 09:32:21 +00:00
Dhruv Nair
731dd5f1d1 use max to test difference 2023-08-24 09:31:00 +00:00
Dhruv Nair
493757b52a use max to test difference instead of sum 2023-08-24 09:26:49 +00:00
Dhruv Nair
3901942d1b fix precision tolerance anchange max diff from sum to max 2023-08-24 09:24:57 +00:00
Dhruv Nair
b36081acf6 clean up 2023-08-23 06:32:18 +00:00
Dhruv Nair
4a6a03a98b clean up sd onnx upscale 2023-08-22 14:21:32 +00:00
Dhruv Nair
9a393a8b10 initial commit to fix inheritance issue 2023-08-17 12:49:27 +00:00
12 changed files with 391 additions and 204 deletions

View File

@@ -1,26 +1,34 @@
from logging import getLogger
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTokenizer
from ...schedulers import DDPMScheduler
from ...configuration_utils import FrozenDict
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, logging
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ..pipeline_utils import ImagePipelineOutput
from . import StableDiffusionUpscalePipeline
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
logger = getLogger(__name__)
NUM_LATENT_CHANNELS = 4
NUM_UNET_INPUT_CHANNELS = 7
ORT_TO_PT_TYPE = {
"float16": torch.float16,
"float32": torch.float32,
}
logger = logging.get_logger(__name__)
def preprocess(image):
@@ -45,7 +53,17 @@ def preprocess(image):
return image
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
vae: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
unet: OnnxRuntimeModel
low_res_scheduler: DDPMScheduler
scheduler: KarrasDiffusionSchedulers
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPImageProcessor
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
def __init__(
@@ -55,39 +73,296 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
tokenizer: Any,
unet: OnnxRuntimeModel,
low_res_scheduler: DDPMScheduler,
scheduler: Any,
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[OnnxRuntimeModel] = None,
feature_extractor: Optional[CLIPImageProcessor] = None,
max_noise_level: int = 350,
num_latent_channels=4,
num_unet_input_channels=7,
requires_safety_checker: bool = True,
):
super().__init__(
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
low_res_scheduler=low_res_scheduler,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
watermarker=None,
max_noise_level=max_noise_level,
low_res_scheduler=low_res_scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(
max_noise_level=max_noise_level,
num_latent_channels=num_latent_channels,
num_unet_input_channels=num_unet_input_channels,
)
def check_inputs(
self,
prompt: Union[str, List[str]],
image,
noise_level,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, np.ndarray)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
if isinstance(image, list) or isinstance(image, np.ndarray):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0]
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
# check noise level
if noise_level > self.config.max_noise_level:
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = generator.randn(*shape).astype(dtype)
elif latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
return latents
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = self.vae(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
return image
def _encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: Optional[int],
do_classifier_free_guidance: bool,
negative_prompt: Optional[str],
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
if do_classifier_free_guidance:
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
image: Union[np.ndarray, PIL.Image.Image, List[PIL.Image.Image]],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
generator: Optional[Union[np.random.RandomState, List[np.random.RandomState]]] = None,
latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
@@ -108,7 +383,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
noise_level TODO
noise_level (`float`, defaults to 0.2):
Deteremines the amount of noise to add to the initial image before performing upscaling.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
@@ -152,7 +428,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
"""
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
self.check_inputs(
prompt,
image,
noise_level,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -162,16 +446,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if generator is None:
generator = np.random
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
@@ -179,51 +463,55 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
latents_dtype = prompt_embeds.dtype
image = preprocess(image).cpu().numpy()
height, width = image.shape[2:]
# 4. Preprocess image
image = preprocess(image)
image = image.cpu()
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
self.num_latent_channels,
height,
width,
latents_dtype,
generator,
)
image = image.astype(latents_dtype)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Scale the initial noise by the standard deviation required by the scheduler
latents = latents * np.float64(self.scheduler.init_noise_sigma)
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
noise_level = np.array([noise_level]).astype(np.int64)
noise = generator.randn(*image.shape).astype(latents_dtype)
image = self.low_res_scheduler.add_noise(
torch.from_numpy(image), torch.from_numpy(noise), torch.from_numpy(noise_level)
)
image = image.numpy()
batch_multiplier = 2 if do_classifier_free_guidance else 1
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
noise_level = np.concatenate([noise_level] * image.shape[0])
# 6. Prepare latent variables
height, width = image.shape[2:]
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
NUM_LATENT_CHANNELS,
height,
width,
latents_dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
if self.num_latent_channels + num_channels_image != self.num_unet_input_channels:
raise ValueError(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
f" {self.num_unet_input_channels} but received `num_channels_latents`: {self.num_latent_channels} +"
f" `num_channels_image`: {num_channels_image} "
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
f" = {self.num_latent_channels + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
timestep_dtype = next(
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
@@ -248,8 +536,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
noise_pred = self.unet(
sample=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings,
class_labels=noise_level.astype(np.int64),
encoder_hidden_states=prompt_embeds,
class_labels=noise_level,
)[0]
# perform guidance
@@ -259,8 +547,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
).prev_sample
latents = latents.numpy()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -269,125 +558,28 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents.float())
image = self.decode_latents(latents)
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return (image, has_nsfw_concept)
return ImagePipelineOutput(images=image)
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = self.vae(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
return image
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device,
num_images_per_prompt: Optional[int],
do_classifier_free_guidance: bool,
negative_prompt: Optional[str],
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
# no positional arguments to text_encoder
prompt_embeds = self.text_encoder(
input_ids=text_input_ids.int().to(device),
# attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# if hasattr(uncond_input, "attention_mask"):
# attention_mask = uncond_input.attention_mask.to(device)
# else:
# attention_mask = None
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.int().to(device),
# attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
if do_classifier_free_guidance:
seq_len = uncond_embeddings.shape[1]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])
return prompt_embeds
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -221,7 +221,7 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().sum().item()
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_getattr_is_correct(self):
@@ -351,7 +351,7 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().sum().item()
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@require_torch_2

View File

@@ -137,7 +137,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_accelerate.config.in_channels,
model_accelerate.config.sample_size,
model_accelerate.config.sample_size,
generator=torch.manual_seed(0),
generator=torch.Generator("cpu").manual_seed(0),
)
noise = noise.to(torch_device)
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
@@ -263,7 +263,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-4842.8691, -6499.6631, -3800.1953, -7978.2686, -10980.7129, -20028.8535, 8148.2822, 2342.2905, 567.7608])
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))

View File

@@ -726,8 +726,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model.disable_xformers_memory_efficient_attention()
off_sample = model(**inputs_dict).sample
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4
assert (sample - on_sample).abs().max() <= 5e-4
assert (sample - off_sample).abs().max() <= 5e-4
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing

View File

@@ -1008,4 +1008,4 @@ class StableDiffusionMultiControlNetPipelineNightlyTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
assert np.abs(images[0] - images[1]).max() < 1e-3

View File

@@ -455,4 +455,4 @@ class ControlNetImg2ImgPipelineNightlyTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
assert np.abs(images[0] - images[1]).max() < 1e-3

View File

@@ -602,4 +602,4 @@ class ControlNetInpaintPipelineNightlyTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
assert np.abs(images[0] - images[1]).sum() < 1e-3
assert np.abs(images[0] - images[1]).max() < 1e-3

View File

@@ -17,7 +17,6 @@ import random
import unittest
import numpy as np
import torch
from diffusers import (
DPMSolverMultistepScheduler,
@@ -49,7 +48,7 @@ class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unitt
def get_dummy_inputs(self, seed=0):
image = floats_tensor((1, 3, 128, 128), rng=random.Random(seed))
generator = torch.manual_seed(seed)
generator = np.random.RandomState(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
@@ -70,9 +69,7 @@ class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unitt
# started as 128, should now be 512
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.6974782, 0.68902093, 0.70135885, 0.7583618, 0.7804545, 0.7854912, 0.78667426, 0.78743863, 0.78070223]
)
expected_slice = np.array([0.6957, 0.7002, 0.7186, 0.6881, 0.6693, 0.6910, 0.7445, 0.7274, 0.7056])
assert np.abs(image_slice - expected_slice).max() < 1e-1
def test_pipeline_pndm(self):
@@ -85,9 +82,7 @@ class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unitt
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array(
[0.6898892, 0.59240556, 0.52499527, 0.58866215, 0.52258235, 0.52572715, 0.62414473, 0.6174387, 0.6214964]
)
expected_slice = np.array([0.7349, 0.7347, 0.7034, 0.7696, 0.7876, 0.7597, 0.7916, 0.8085, 0.8036])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
def test_pipeline_dpm_multistep(self):
@@ -174,7 +169,7 @@ class OnnxStableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
prompt = "A fantasy landscape, trending on artstation"
generator = torch.manual_seed(0)
generator = np.random.RandomState(0)
output = pipe(
prompt=prompt,
image=init_image,
@@ -211,7 +206,7 @@ class OnnxStableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
prompt = "A fantasy landscape, trending on artstation"
generator = torch.manual_seed(0)
generator = np.random.RandomState(0)
output = pipe(
prompt=prompt,
image=init_image,

View File

@@ -122,7 +122,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
except Exception:
error = f"{traceback.format_exc()}"
@@ -1543,7 +1543,7 @@ class PipelineSlowTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
@require_torch_2
def test_from_save_pretrained_dynamo(self):
@@ -1568,7 +1568,7 @@ class PipelineSlowTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
def test_from_pretrained_hub_pass_model(self):
model_path = "google/ddpm-cifar10-32"
@@ -1591,7 +1591,7 @@ class PipelineSlowTests(unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = ddpm_from_hub(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
@@ -1625,7 +1625,7 @@ class PipelineSlowTests(unittest.TestCase):
from diffusers import FlaxStableDiffusionPipeline
with tempfile.TemporaryDirectory() as tmpdirname:
pipe_pt.save_pretrained(tmpdirname)
pipe_pt.save_pretrained(tmpdirname, safe_serialization=False)
pipe_flax, params = FlaxStableDiffusionPipeline.from_pretrained(
tmpdirname, safety_checker=None, from_pt=True

View File

@@ -76,7 +76,7 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
output_type="numpy",
).images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
def test_inference_dual_guided(self):
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")

View File

@@ -77,7 +77,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
output_type="numpy",
).images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
def test_inference_dual_guided_then_text_to_image(self):
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)

View File

@@ -64,7 +64,7 @@ class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase):
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy"
).images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
assert np.abs(image - new_image).max() < 1e-5, "Models don't have the same forward pass"
def test_inference_text2img(self):
pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(