Compare commits

...

9 Commits

Author SHA1 Message Date
Pedro Cuenca
d989ce8f40 Merge branch 'latent-upscaler' of github.com:huggingface/diffusers into latent-upscaler 2022-11-22 21:44:13 +01:00
Pedro Cuenca
70eb8970cc Create "loglinear_sigmas" schedule.
Currently implemented in EulerDiscreteScheduler. An alternative would
have been to initialize the scheduler with an array of `trained_betas`.
However, that is currently not possible because of #1367.
2022-11-22 21:41:03 +01:00
Pedro Cuenca
b35a75a7f7 Remove additional helper class.
And make denoising loop similar to the other pipelines.
2022-11-22 17:03:32 +01:00
Pedro Cuenca
b12f7d7cf7 Remove clip helper classes. 2022-11-22 16:43:29 +01:00
Pedro Cuenca
4c128d0f37 Add GELU as a non-linearity. 2022-11-21 10:33:38 +01:00
Pedro Cuenca
1947c2d39d Remove deprecation warnings. 2022-11-19 20:46:39 +01:00
Pedro Cuenca
1673c91fb1 Rough upscaler pipeline 2022-11-16 21:19:35 +01:00
Pedro Cuenca
25d9e54ae8 StableDiffusionPipeline can return undecoded latents. 2022-11-16 21:18:55 +01:00
Pedro Cuenca
b0a829d65e EulerDiscreteScheduler supports predict_epsilon 2022-11-16 21:18:16 +01:00
8 changed files with 522 additions and 10 deletions

View File

@@ -73,6 +73,7 @@ if is_torch_available() and is_transformers_available():
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionUpscalerPipeline,
VQDiffusionPipeline,
)
else:

View File

@@ -419,6 +419,8 @@ class ResnetBlock2D(nn.Module):
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
self.upsample = self.downsample = None
if self.up:
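
For reference, a small sketch exercising the new non-linearity option on `ResnetBlock2D` (an internal module); the channel sizes are illustrative, and only `non_linearity="gelu"` depends on the change above:

```python
import torch
from diffusers.models.resnet import ResnetBlock2D

# Illustrative channel sizes; only non_linearity="gelu" relies on the new branch in the diff.
block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=128, non_linearity="gelu")
hidden_states = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 128)
out = block(hidden_states, temb)
print(out.shape)  # torch.Size([1, 64, 32, 32])
```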

View File

@@ -23,6 +23,7 @@ if is_torch_available() and is_transformers_available():
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionUpscalerPipeline,
)
from .vq_diffusion import VQDiffusionPipeline

View File

@@ -33,6 +33,7 @@ if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_upscaler import StableDiffusionUpscalerPipeline
from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_onnx_available():

View File

@@ -516,15 +516,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
-# 8. Post-processing
-image = self.decode_latents(latents)
+if output_type == "latents":
+    # Skip safety checking if we are returning latents
+    image = latents
+    has_nsfw_concept = None
+else:
+    # 8. Post-processing
+    image = self.decode_latents(latents)
-# 9. Run safety checker
-image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+    # 9. Run safety checker
+    image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
-# 10. Convert to PIL
-if output_type == "pil":
-    image = self.numpy_to_pil(image)
+    # 10. Convert to PIL
+    if output_type == "pil":
+        image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
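
A short sketch of what this change enables from the caller's side; the checkpoint ID and prompt are placeholders, not part of the diff:

```python
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint; any Stable Diffusion checkpoint would do.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

# With output_type="latents", decoding and the safety checker are skipped
out = pipe("a photo of an astronaut riding a horse", output_type="latents")
low_res_latents = out.images          # raw latents, shape (1, 4, 64, 64) for a 512x512 image
print(out.nsfw_content_detected)      # None: the safety checker did not run
```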

View File

@@ -0,0 +1,443 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
# TODO: Remove when we migrate the upscaler model to diffusers >>>>>>>
import k_diffusion as K
import huggingface_hub
from torch import nn
import torch.nn.functional as F
UPSCALER_REPO = "pcuenq/k-upscaler"
class NoiseLevelAndTextConditionedUpscaler(nn.Module):
def __init__(self, inner_model, sigma_data=1., embed_dim=256):
super().__init__()
self.inner_model = inner_model
self.sigma_data = sigma_data
self.low_res_noise_embed = K.layers.FourierFeatures(1, embed_dim, std=2)
def forward(self, input, sigma, low_res, low_res_sigma, c, **kwargs):
cross_cond, cross_cond_padding, pooler = c
c_in = 1 / (low_res_sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_noise = low_res_sigma.log1p()[:, None]
c_in = K.utils.append_dims(c_in, low_res.ndim)
low_res_noise_embed = self.low_res_noise_embed(c_noise)
low_res_in = F.interpolate(low_res, scale_factor=2, mode='nearest') * c_in
mapping_cond = torch.cat([low_res_noise_embed, pooler], dim=1)
return self.inner_model(input, sigma, unet_cond=low_res_in, mapping_cond=mapping_cond, cross_cond=cross_cond, cross_cond_padding=cross_cond_padding, **kwargs)
def make_upscaler_model(config_path, model_path, pooler_dim=768, train=False, device='cpu'):
config = K.config.load_config(open(config_path))
model = K.config.make_model(config)
model = NoiseLevelAndTextConditionedUpscaler(
model,
sigma_data=config['model']['sigma_data'],
embed_dim=config['model']['mapping_cond_dim'] - pooler_dim,
)
ckpt = torch.load(model_path, map_location='cpu')
model.load_state_dict(ckpt['model_ema'])
model = K.config.make_denoiser_wrapper(config)(model)
if not train:
model = model.eval().requires_grad_(False)
return model.to(device)
# <<<<<< To be removed when we migrate upscaler model to diffusers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionUpscalerPipeline(DiffusionPipeline):
r"""
Pipeline for 2x upscaling of Stable Diffusion latents, guided by a text prompt.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with the upscaler model. Currently restricted to `EulerDiscreteScheduler`.
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
scheduler: EulerDiscreteScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
" it only for use cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
# Download upscaler
config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
weights_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth")
self.upscaler = make_upscaler_model(config_path, weights_path)
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
self.upscaler.to(torch_device)
return super().to(torch_device)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
# TODO: enable in the upscaler
pass
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
# TODO: disable in the upscaler
pass
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
pass
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the text_encoder,
vae, safety checker and upscaler have their state dicts saved to CPU and then are moved to
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.text_encoder, self.vae, self.safety_checker, self.upscaler]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
return self.device
for module in self.vae.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _get_text_conditioning(self, prompt, device, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1
def get_conditioning(text):
text_inputs = self.tokenizer(
text,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
attention_mask = text_inputs.attention_mask.to(device)
cross_cond_padding = 1 - attention_mask
# I believe the attention mask should be provided here, but the original notebook does not do it
# TODO: test it out
# clip_output = self.text_encoder(input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True)
clip_output = self.text_encoder(input_ids=text_input_ids, output_hidden_states=True)
hidden_states = clip_output.hidden_states[-1]
pooler_output = clip_output.pooler_output
return hidden_states, cross_cond_padding.to(dtype=hidden_states.dtype), pooler_output
prompt_conditioning = get_conditioning(prompt)
conditioning = prompt_conditioning
if do_classifier_free_guidance:
uncond_conditioning = get_conditioning(batch_size * [""])
conditioning = [torch.cat([uc_item, c_item]) for uc_item, c_item in zip(uncond_conditioning, prompt_conditioning)]
return conditioning
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, latents, prompt):
batch, _, height, width = latents.shape
if height != width:
raise ValueError(f"Latents should be square, got {height}x{width} instead")
if prompt is not None:
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if isinstance(prompt, list) and len(prompt) != batch:
raise ValueError(f"`prompt` length has to be equal to the latents batch_size ({batch}), but is {len(prompt)}")
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
prompt: Optional[Union[str, List[str]]] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
latents (`torch.FloatTensor`):
Latents to be upscaled. Generated from a Stable Diffusion Pipeline using `output_type="latents"`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image upscaling process.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(latents, prompt)
# 2. Define call parameters
batch_size, channels, height, width = latents.shape
if isinstance(prompt, str):
prompt = [prompt] * batch_size
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Prepare timesteps
# We take log-linear steps in noise-level from sigma_max to sigma_min
# TODO(Pedro) Fix: create the scheduler with the betas instead
sigma_min = self.scheduler.sigmas[-2] # Last one is zero
sigma_max = self.scheduler.sigmas[0]
# The +1 comes from k-diffusion
sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps+1).exp().to(device)
# scheduler.sigmas = torch.cat((sigmas, torch.tensor([0.]).to(device)))
self.scheduler.set_timesteps(num_inference_steps, device=device)
self.scheduler.sigmas = sigmas
# # 4. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
# extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 5. Prepare conditioning
conditioning = self._get_text_conditioning(prompt, device, do_classifier_free_guidance)
# 6. Create initial noise
x_shape = [batch_size, channels, 2*height, 2*width]
noisy_latents = torch.randn(x_shape, generator=generator, device=device, dtype=sigmas.dtype)
# Noise augmentation of the low-res conditioning is disabled (low_res_sigma = 0);
# according to the notebook it doesn't seem to work well
# TODO: remove in final implementation
low_res_sigma = torch.full([batch_size], 0, device=device, dtype=sigmas.dtype)
# 7. Prepare inputs for CFG
low_res = latents
if do_classifier_free_guidance:
low_res = torch.cat([low_res] * 2)
low_res_sigma = torch.cat([low_res_sigma] * 2)
# 8. Denoising loop
noisy_latents = noisy_latents * sigma_max
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
sigma = sigmas[i]
sigma = sigma[None]
latent_model_input = noisy_latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if do_classifier_free_guidance:
latent_model_input = torch.cat([latent_model_input] * 2)
sigma = torch.cat([sigma] * 2)
# predict the next denoised latent
denoised = self.upscaler(
latent_model_input,
sigma,
low_res=low_res,
low_res_sigma=low_res_sigma,
c=conditioning,
)
# perform guidance
if do_classifier_free_guidance:
uncond, cond = denoised.chunk(2)
denoised = uncond + guidance_scale * (cond - uncond)
# compute the previous noisy sample x_t -> x_t-1
noisy_latents = self.scheduler.step(denoised, t, noisy_latents).prev_sample
# 9. Post-processing
image = self.decode_latents(noisy_latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, sigmas.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
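
A sketch of how the new pipeline could be wired up by hand, reusing the components of the base pipeline from the sketch above (`pipe`, `low_res_latents`); this is an assumed usage pattern rather than something the PR documents, and it requires `k-diffusion` to be installed since `__init__` downloads the upscaler weights from `pcuenq/k-upscaler`:

```python
from diffusers import EulerDiscreteScheduler, StableDiffusionUpscalerPipeline

# Scheduler configured with the options added in this branch (beta values are illustrative, not tuned).
scheduler = EulerDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="loglinear_sigmas", predict_epsilon=False
)

upscaler = StableDiffusionUpscalerPipeline(
    vae=pipe.vae,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    scheduler=scheduler,
    safety_checker=pipe.safety_checker,
    feature_extractor=pipe.feature_extractor,
).to("cuda")  # the custom `to` also moves the k-diffusion upscaler model

upscaled = upscaler(low_res_latents, prompt="a photo of an astronaut riding a horse", num_inference_steps=50)
upscaled.images[0].save("astronaut_2x.png")
```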

View File

@@ -20,7 +20,7 @@ import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, betas_from_loglinear_sigmas
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -65,6 +65,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
predict_epsilon (`bool`):
optional flag indicating whether the model predicts the noise (epsilon) or the denoised samples directly.
"""
@@ -78,6 +80,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
predict_epsilon: bool = True,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
@@ -88,6 +91,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "loglinear_sigmas":
# This schedule is specific to the k-diffusion latent upscaler.
# We use a helper function because the computation is a bit involved.
# Alternative: create from a list of `trained_betas` (but see https://github.com/huggingface/diffusers/issues/1367)
self.betas = betas_from_loglinear_sigmas(beta_start, beta_end, num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
@@ -120,12 +128,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Returns:
`torch.FloatTensor`: scaled input sample
"""
+self.is_scale_input_called = True
+if not self.config.predict_epsilon:
+    return sample
if isinstance(timestep, torch.Tensor):
    timestep = timestep.to(self.timesteps.device)
step_index = (self.timesteps == timestep).nonzero().item()
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
-self.is_scale_input_called = True
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -229,7 +240,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-pred_original_sample = sample - sigma_hat * model_output
+if self.config.predict_epsilon:
+    pred_original_sample = sample - sigma_hat * model_output
+else:
+    pred_original_sample = model_output
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
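
A toy sketch of what the new flag changes at the scheduler level (arbitrary tensors and default betas; this is not how the upscaler pipeline drives the scheduler):

```python
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(beta_schedule="loglinear_sigmas", predict_epsilon=False)
scheduler.set_timesteps(50)

sample = torch.randn(1, 4, 128, 128) * scheduler.sigmas[0]
t = scheduler.timesteps[0]

# With predict_epsilon=False the input is passed to the model unscaled...
assert torch.equal(scheduler.scale_model_input(sample, t), sample)

# ...and step() treats the model output as the predicted denoised sample (x_0)
denoised = torch.zeros_like(sample)  # stand-in for the upscaler's prediction
prev = scheduler.step(denoised, t, sample).prev_sample
print(prev.shape)
```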

View File

@@ -16,6 +16,7 @@ import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from ..utils import BaseOutput
@@ -152,3 +153,47 @@ class SchedulerMixin:
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes
def betas_from_loglinear_sigmas(beta_start, beta_end, num_timesteps):
"""
Computes the beta values that produce a log-linear schedule of sigmas, as used in the k-diffusion latent upscaler.
Concretely, these are the betas that create a sigma schedule like the following:
```
torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps).exp()
```
Args:
beta_start (`float`): The starting beta of the reference `scaled_linear` schedule from which `sigma_min`/`sigma_max` are derived.
beta_end (`float`): The final beta of the reference `scaled_linear` schedule.
num_timesteps (`int`): The number of training timesteps.
Returns:
`torch.FloatTensor`: The beta values.
"""
# First, compute sigma_max and sigma_min considering a "scaled_linear" schedule
# as used in Stable Diffusion. We just need sigma_min and sigma_max.
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
sigma_min, sigma_max = sigmas[0], sigmas[-1]
# Then, compute the actual loglinear sigmas from sigma_min, sigma_max
sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps)
sigmas = sigmas.exp()
sigmas = np.array(sigmas)[::-1]
alpha_cumprod = 1./(1+sigmas**2)
# Recover the per-step alpha values from the cumulative product
alphas = []
prev_prod = 1.
for a in alpha_cumprod:
current_alpha = a / prev_prod
alphas.append(current_alpha)
prev_prod = a
# Get the betas from the alphas
betas = 1 - np.array(alphas)
return torch.Tensor(betas)
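
A quick sanity check one could run against this helper: it re-derives the sigmas from the returned betas and prints the log-space spacing at both ends (the beta values are illustrative):

```python
import torch
from diffusers.schedulers.scheduling_utils import betas_from_loglinear_sigmas

# Re-derive the sigma schedule from the returned betas
betas = betas_from_loglinear_sigmas(0.00085, 0.012, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

# The recovered sigmas should be (approximately) evenly spaced in log-space
log_sigmas = sigmas.log()
print(log_sigmas[1] - log_sigmas[0], log_sigmas[-1] - log_sigmas[-2])
```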