mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
9 Commits
update-sty
...
latent-ups
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d989ce8f40 | ||
|
|
70eb8970cc | ||
|
|
b35a75a7f7 | ||
|
|
b12f7d7cf7 | ||
|
|
4c128d0f37 | ||
|
|
1947c2d39d | ||
|
|
1673c91fb1 | ||
|
|
25d9e54ae8 | ||
|
|
b0a829d65e |
@@ -73,6 +73,7 @@ if is_torch_available() and is_transformers_available():
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalerPipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -419,6 +419,8 @@ class ResnetBlock2D(nn.Module):
|
||||
self.nonlinearity = Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "gelu":
|
||||
self.nonlinearity = nn.GELU()
|
||||
|
||||
self.upsample = self.downsample = None
|
||||
if self.up:
|
||||
|
||||
@@ -23,6 +23,7 @@ if is_torch_available() and is_transformers_available():
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalerPipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_transformers_available() and is_torch_available():
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
|
||||
from .pipeline_stable_diffusion_upscaler import StableDiffusionUpscalerPipeline
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
|
||||
@@ -516,15 +516,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
if output_type == "latents":
|
||||
# Skip safety checking if we are returning latents
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
@@ -0,0 +1,443 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
# TODO: Remove when we migrate the upscaler model to diffusers >>>>>>>
|
||||
import k_diffusion as K
|
||||
import huggingface_hub
|
||||
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
UPSCALER_REPO = "pcuenq/k-upscaler"
|
||||
|
||||
class NoiseLevelAndTextConditionedUpscaler(nn.Module):
|
||||
def __init__(self, inner_model, sigma_data=1., embed_dim=256):
|
||||
super().__init__()
|
||||
self.inner_model = inner_model
|
||||
self.sigma_data = sigma_data
|
||||
self.low_res_noise_embed = K.layers.FourierFeatures(1, embed_dim, std=2)
|
||||
|
||||
def forward(self, input, sigma, low_res, low_res_sigma, c, **kwargs):
|
||||
cross_cond, cross_cond_padding, pooler = c
|
||||
c_in = 1 / (low_res_sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_noise = low_res_sigma.log1p()[:, None]
|
||||
c_in = K.utils.append_dims(c_in, low_res.ndim)
|
||||
low_res_noise_embed = self.low_res_noise_embed(c_noise)
|
||||
low_res_in = F.interpolate(low_res, scale_factor=2, mode='nearest') * c_in
|
||||
mapping_cond = torch.cat([low_res_noise_embed, pooler], dim=1)
|
||||
return self.inner_model(input, sigma, unet_cond=low_res_in, mapping_cond=mapping_cond, cross_cond=cross_cond, cross_cond_padding=cross_cond_padding, **kwargs)
|
||||
|
||||
def make_upscaler_model(config_path, model_path, pooler_dim=768, train=False, device='cpu'):
|
||||
config = K.config.load_config(open(config_path))
|
||||
model = K.config.make_model(config)
|
||||
model = NoiseLevelAndTextConditionedUpscaler(
|
||||
model,
|
||||
sigma_data=config['model']['sigma_data'],
|
||||
embed_dim=config['model']['mapping_cond_dim'] - pooler_dim,
|
||||
)
|
||||
ckpt = torch.load(model_path, map_location='cpu')
|
||||
model.load_state_dict(ckpt['model_ema'])
|
||||
model = K.config.make_denoiser_wrapper(config)(model)
|
||||
if not train:
|
||||
model = model.eval().requires_grad_(False)
|
||||
return model.to(device)
|
||||
|
||||
# <<<<<< To be removed when we migrate upscaler model to diffusers
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionUpscalerPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with the upscaler model. Currently restricted to `EulerDiscreteScheduler`.nn
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler: EulerDiscreteScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
# Download upscaler
|
||||
config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
|
||||
weights_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth")
|
||||
self.upscaler = make_upscaler_model(config_path, weights_path)
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
self.upscaler.to(torch_device)
|
||||
return super().to(torch_device)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
# TODO: enable in the upscaler
|
||||
pass
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
# TODO: disable in the upscaler
|
||||
pass
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.text_encoder, self.vae, self.safety_checker, self.upscaler]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.vae.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _get_text_conditioning(self, prompt, device, do_classifier_free_guidance):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
def get_conditioning(text):
|
||||
text_inputs = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
cross_cond_padding = 1 - attention_mask
|
||||
|
||||
# I believe the attention mask should be provided here, but the original notebook does not do it
|
||||
# TODO: test it out
|
||||
# clip_output = self.text_encoder(input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
clip_output = self.text_encoder(input_ids=text_input_ids, output_hidden_states=True)
|
||||
hidden_states = clip_output.hidden_states[-1]
|
||||
pooler_output = clip_output.pooler_output
|
||||
|
||||
return hidden_states, cross_cond_padding.to(dtype=hidden_states.dtype), pooler_output
|
||||
|
||||
prompt_conditioning = get_conditioning(prompt)
|
||||
conditioning = prompt_conditioning
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_conditioning = get_conditioning(batch_size * [""])
|
||||
conditioning = [torch.cat([uc_item, c_item]) for uc_item, c_item in zip(uncond_conditioning, prompt_conditioning)]
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, latents, prompt):
|
||||
batch, _, height, width = latents.shape
|
||||
if height != width:
|
||||
raise ValueError(f"Latents should be square, got {height}x{width} instead")
|
||||
|
||||
if prompt is not None:
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if isinstance(prompt, list) and len(prompt) != batch:
|
||||
raise ValueError(f"`prompt` length has to be equal to the latents batch_size ({batch}), but is {len(prompt)}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
latents (`torch.FloatTensor`):
|
||||
Latents to be upscaled. Generated from a Stable Diffusion Pipeline using `output_type="latents"`.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image upscaling process.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(latents, prompt)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size, channels, height, width = latents.shape
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt] * batch_size
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Prepare timesteps
|
||||
# We take log-linear steps in noise-level from sigma_max to sigma_min
|
||||
# TODO(Pedro) Fix: create the scheduler with the betas instead
|
||||
sigma_min = self.scheduler.sigmas[-2] # Last one is zero
|
||||
sigma_max = self.scheduler.sigmas[0]
|
||||
# The +1 comes from k-diffusion
|
||||
sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps+1).exp().to(device)
|
||||
# scheduler.sigmas = torch.cat((sigmas, torch.tensor([0.]).to(device)))
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
self.scheduler.sigmas = sigmas
|
||||
|
||||
# # 4. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
# extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 5. Prepare conditioning
|
||||
conditioning = self._get_text_conditioning(prompt, device, do_classifier_free_guidance)
|
||||
|
||||
# 6. Create initial noise
|
||||
x_shape = [batch_size, channels, 2*height, 2*width]
|
||||
noisy_latents = torch.randn(x_shape, generator=generator, device=device, dtype=sigmas.dtype)
|
||||
|
||||
# Disabled; according to the notebook it doesn't seem to work well
|
||||
# TODO: remove in final implementation
|
||||
low_res_sigma = torch.full([batch_size], 0, device=device, dtype=sigmas.dtype)
|
||||
|
||||
# 7. Prepare inputs for CFG
|
||||
low_res = latents
|
||||
if do_classifier_free_guidance:
|
||||
low_res = torch.cat([low_res] * 2)
|
||||
low_res_sigma = torch.cat([low_res_sigma] * 2)
|
||||
|
||||
# 8. Denoising loop
|
||||
noisy_latents = noisy_latents * sigma_max
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
sigma = sigmas[i]
|
||||
sigma = sigma[None]
|
||||
|
||||
latent_model_input = noisy_latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
latent_model_input = torch.cat([latent_model_input] * 2)
|
||||
sigma = torch.cat([sigma] * 2)
|
||||
|
||||
# predict the next denoised latent
|
||||
denoised = self.upscaler(
|
||||
latent_model_input,
|
||||
sigma,
|
||||
low_res=low_res,
|
||||
low_res_sigma=low_res_sigma,
|
||||
c=conditioning,
|
||||
)
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond, cond = denoised.chunk(2)
|
||||
denoised = uncond + guidance_scale * (cond - uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noisy_latents = self.scheduler.step(denoised, t, noisy_latents).prev_sample
|
||||
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(noisy_latents)
|
||||
|
||||
# 10. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, sigmas.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
from .scheduling_utils import SchedulerMixin, betas_from_loglinear_sigmas
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -65,6 +65,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
|
||||
|
||||
"""
|
||||
|
||||
@@ -78,6 +80,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
predict_epsilon: bool = True,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
@@ -88,6 +91,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = (
|
||||
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
)
|
||||
elif beta_schedule == "loglinear_sigmas":
|
||||
# This scheduler is specific to k-diffusion latent upscaler
|
||||
# We use a helper function because the computation is a bit involved
|
||||
# Alternative: create from a list of `trained_betas` (but see https://github.com/huggingface/diffusers/issues/1367)
|
||||
self.betas = betas_from_loglinear_sigmas(beta_start, beta_end, num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
@@ -120,12 +128,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
Returns:
|
||||
`torch.FloatTensor`: scaled input sample
|
||||
"""
|
||||
self.is_scale_input_called = True
|
||||
if not self.config.predict_epsilon:
|
||||
return sample
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
@@ -229,7 +240,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
if self.config.predict_epsilon:
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
else:
|
||||
pred_original_sample = model_output
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
|
||||
@@ -16,6 +16,7 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..utils import BaseOutput
|
||||
@@ -152,3 +153,47 @@ class SchedulerMixin:
|
||||
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
||||
]
|
||||
return compatible_classes
|
||||
|
||||
def betas_from_loglinear_sigmas(beta_start, beta_end, num_timesteps):
|
||||
"""
|
||||
Computes the beta values suitable to create a loglinear schedule of sigmas,
|
||||
as used in k-diffusion latent upscaler.
|
||||
|
||||
Concretely, these are the betas the create a sigma schedule like the following:
|
||||
```
|
||||
torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps).exp()
|
||||
```
|
||||
|
||||
Args:
|
||||
beta_start (`float`): The start sigma value.
|
||||
beta_end (`float`): The end sigma value.
|
||||
num_timesteps (`int`): The number of training timesteps.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: The beta values.
|
||||
"""
|
||||
# First, compute sigma_max and sigma_min considering a "scaled_linear" schedule
|
||||
# as used in Stable Diffusion. We just need sigma_min and sigma_max.
|
||||
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
|
||||
alphas = 1 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
sigma_min, sigma_max = sigmas[0], sigmas[-1]
|
||||
|
||||
# Then, compute the actual loglinear sigmas from sigma_min, sigma_max
|
||||
sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps)
|
||||
sigmas = sigmas.exp()
|
||||
sigmas = np.array(sigmas)[::-1]
|
||||
alpha_cumprod = 1./(1+sigmas**2)
|
||||
|
||||
# Compute the alpha values reversing alpha_cumprod
|
||||
alphas = []
|
||||
prev_prod = 1.
|
||||
for a in alpha_cumprod:
|
||||
current_alpha = a / prev_prod
|
||||
alphas.append(current_alpha)
|
||||
prev_prod = a
|
||||
|
||||
# Get the betas from the alphas
|
||||
betas = 1 - np.array(alphas)
|
||||
return torch.Tensor(betas)
|
||||
|
||||
Reference in New Issue
Block a user