Compare commits

...

5 Commits

Author SHA1 Message Date
Patrick von Platen
92bf23aee5 Patch release: v0.15.1 2023-04-17 18:27:40 +02:00
Patrick von Platen
84a89cc90c Fix config deprecation (#3129)
* Better deprecation message

* Better deprecation message

* Better doc string

* Fixes

* fix more

* fix more

* Improve __getattr__

* correct more

* fix more

* fix

* Improve more

* more improvements

* fix more

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* make style

* Fix all rest & add tests & remove old deprecation fns

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2023-04-17 18:23:01 +02:00
Patrick von Platen
f81017300b [Bug fix] Make sure correct timesteps are chosen for img2img (#3128)
Make sure correct timesteps are chosen for img2img
2023-04-17 18:22:50 +02:00
Patrick von Platen
f0ab5e9da8 [Bug fix] Fix img2img processor with safety checker (#3127)
Fix img2img processor with safety checker
2023-04-17 18:22:41 +02:00
Patrick von Platen
d12119e74c Add global pooling to controlnet (#3121) 2023-04-17 18:22:26 +02:00
32 changed files with 269 additions and 154 deletions

View File

@@ -372,9 +372,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
@@ -425,9 +425,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
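
Both hunks above show the change this release propagates everywhere: architectural values such as `in_channels` and `sample_size` are now read from a model's frozen `config` instead of as direct attributes. A minimal sketch of the preferred access pattern, assuming a diffusers 0.15.x environment (the model instantiation is only illustrative):

import torch
from diffusers import UNet2DModel

# Any ModelMixin/ConfigMixin subclass behaves the same way.
unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)

channels = unet.config.in_channels  # preferred: read from the frozen config
size = unet.config.sample_size
latents = torch.zeros((1, channels, size, size))  # e.g. when preparing latents

# unet.in_channels still resolves in 0.15.x but emits a FutureWarning
# (scheduled for removal in 1.0.0); that warning is what these hunks avoid.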

View File

@@ -452,9 +452,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
@@ -505,9 +505,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),

View File

@@ -226,7 +226,7 @@ install_requires = [
setup(
name="diffusers",
version="0.15.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.15.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.15.0"
__version__ = "0.15.1"
from .configuration_utils import ConfigMixin
from .utils import (

View File

@@ -118,6 +118,24 @@ class ConfigMixin:
self._internal_dict = FrozenDict(internal_dict)
+    def __getattr__(self, name: str) -> Any:
+        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
+
+        This function is mostly copied from PyTorch's __getattr__ overwrite:
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
+            return self._internal_dict[name]
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
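
The `__getattr__` added above is only invoked when normal attribute lookup fails, so existing config keys keep resolving while emitting a deprecation warning, and genuinely unknown names still raise `AttributeError`. A self-contained sketch of the same pattern, with a plain dict and `warnings.warn` standing in for diffusers' `FrozenDict` and `deprecate` helpers:

import warnings
from typing import Any


class ConfigShim:
    def __init__(self, **config):
        self._internal_dict = dict(config)  # lives in self.__dict__

    def __getattr__(self, name: str) -> Any:
        # Reached only when normal lookup fails, i.e. `name` is not a real attribute.
        if "_internal_dict" in self.__dict__ and name in self.__dict__["_internal_dict"]:
            warnings.warn(f"Access `{name}` via `.config` instead.", FutureWarning, stacklevel=2)
            return self._internal_dict[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


shim = ConfigShim(in_channels=4)
assert shim.in_channels == 4  # still works, but warns
# shim.does_not_exist         # would raise AttributeError, as before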

View File

@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, apply_forward_hook, deprecate
+from ..utils import BaseOutput, apply_forward_hook
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -123,16 +123,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
-    @property
-    def block_out_channels(self):
-        deprecate(
-            "block_out_channels",
-            "1.0.0",
-            "Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
-            standard_warn=False,
-        )
-        return self.config.block_out_channels
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value

View File

@@ -119,6 +119,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+        global_pool_conditions: bool = False,
):
super().__init__()
@@ -559,6 +560,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample *= conditioning_scale
+        if self.config.global_pool_conditions:
+            down_block_res_samples = [
+                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+            ]
+            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
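
The new `global_pool_conditions` flag collapses every ControlNet residual to its spatial mean before it is added to the UNet features, which suits variants conditioned on global image statistics rather than spatial layout. The tensor-level effect, illustratively:

import torch

# One down-block residual: (batch, channels, height, width)
sample = torch.randn(2, 320, 64, 64)

pooled = torch.mean(sample, dim=(2, 3), keepdim=True)
print(pooled.shape)  # torch.Size([2, 320, 1, 1])

# keepdim=True preserves the rank, so the pooled residual still broadcasts
# against the (2, 320, 64, 64) UNet feature map it is summed with.
print((torch.randn(2, 320, 64, 64) + pooled).shape)  # torch.Size([2, 320, 64, 64])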

View File

@@ -17,7 +17,7 @@
import inspect
import os
from functools import partial
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device
@@ -32,6 +32,7 @@ from ..utils import (
WEIGHTS_NAME,
_add_variant,
_get_model_file,
+    deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
@@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module):
def __init__(self):
super().__init__()
+    def __getattr__(self, name: str) -> Any:
+        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
+        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
+        __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
+        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        """
+        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
+        is_attribute = name in self.__dict__
+
+        if is_in_config and not is_attribute:
+            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
+            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
+            return self._internal_dict[name]
+
+        # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
+        return super().__getattr__(name)
@property
def is_gradient_checkpointing(self) -> bool:
"""

View File

@@ -19,7 +19,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@@ -190,16 +190,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
fc_dim=block_out_channels[-1] // 4,
)
-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -216,16 +216,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -21,7 +21,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, deprecate, logging
+from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
@@ -447,16 +447,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
-            standard_warn=False,
-        )
-        return self.config.in_channels
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""

View File

@@ -503,7 +503,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start
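
This one-line fix, repeated in every img2img-style pipeline below, accounts for higher-order schedulers: a second-order scheduler such as `HeunDiscreteScheduler` keeps roughly `order` entries in `scheduler.timesteps` per denoising step, so skipping `t_start` steps means skipping `t_start * order` tensor entries (first-order schedulers have `order == 1`, so their behavior is unchanged). A sketch of the arithmetic:

num_inference_steps = 10
strength = 0.75
order = 2  # e.g. HeunDiscreteScheduler; DDIM, PNDM, etc. have order == 1

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 7
t_start = max(num_inference_steps - init_timestep, 0)                          # 3 steps to skip

entries_to_skip = t_start * order  # 6 timesteps-tensor entries; the old slice skipped only 3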

View File

@@ -507,7 +507,7 @@ class DiffusionPipeline(ConfigMixin):
setattr(self, name, module)
def __setattr__(self, name: str, value: Any):
-        if hasattr(self, name) and hasattr(self.config, name):
+        if name in self.__dict__ and hasattr(self.config, name):
# We need to overwrite the config if name exists in config
if isinstance(getattr(self.config, name), (tuple, list)):
if value is not None and self.config[name][0] is not None:
@@ -635,26 +635,25 @@ class DiffusionPipeline(ConfigMixin):
)
module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                module.to(torch_device, torch_dtype)
-                if (
-                    module.dtype == torch.float16
-                    and str(torch_device) in ["cpu"]
-                    and not silence_dtype_warnings
-                    and not is_offloaded
-                ):
-                    logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
-                        " is not recommended to move them to `cpu` as running them will fail. Please make"
-                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
-                        " support for`float16` operations on this device in PyTorch. Please, remove the"
-                        " `torch_dtype=torch.float16` argument, or use another device for inference."
-                    )
+        for module in modules:
+            module.to(torch_device, torch_dtype)
+            if (
+                module.dtype == torch.float16
+                and str(torch_device) in ["cpu"]
+                and not silence_dtype_warnings
+                and not is_offloaded
+            ):
+                logger.warning(
+                    "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+                    " is not recommended to move them to `cpu` as running them will fail. Please make"
+                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+                    " support for`float16` operations on this device in PyTorch. Please, remove the"
+                    " `torch_dtype=torch.float16` argument, or use another device for inference."
+                )
return self
@property
@@ -664,12 +663,12 @@ class DiffusionPipeline(ConfigMixin):
`torch.device`: The torch device on which the pipeline is located.
"""
module_names, _ = self._get_signature_keys(self)
-        module_names = [m for m in module_names if hasattr(self, m)]
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

-        for name in module_names:
-            module = getattr(self, name)
-            if isinstance(module, torch.nn.Module):
-                return module.device
+        for module in modules:
+            return module.device
return torch.device("cpu")
@classmethod
@@ -1438,13 +1437,12 @@ class DiffusionPipeline(ConfigMixin):
for child in module.children():
fn_recursive_set_mem_eff(child)
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module):
-                fn_recursive_set_mem_eff(module)
+        for module in modules:
+            fn_recursive_set_mem_eff(module)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
@@ -1471,10 +1469,9 @@ class DiffusionPipeline(ConfigMixin):
self.enable_attention_slicing(None)
def set_attention_slice(self, slice_size: Optional[int]):
-        module_names, _, _ = self.extract_init_dict(dict(self.config))
-        module_names = [m for m in module_names if hasattr(self, m)]
+        module_names, _ = self._get_signature_keys(self)
+        modules = [getattr(self, n, None) for n in module_names]
+        modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]

-        for module_name in module_names:
-            module = getattr(self, module_name)
-            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
-                module.set_attention_slice(slice_size)
+        for module in modules:
+            module.set_attention_slice(slice_size)
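
All four call sites above converge on the same idiom: take the component names from the pipeline's `__init__` signature, fetch each with a `None` default, and filter by `isinstance`, rather than guarding every access with `hasattr` inside the loop. A compact sketch of the idiom (class and attribute names hypothetical):

import torch


class TinyPipeline:
    def __init__(self):
        self.unet = torch.nn.Linear(4, 4)  # stands in for a real UNet
        self.scheduler = object()          # not an nn.Module, so it is filtered out

    def _torch_modules(self):
        module_names = ["unet", "vae", "scheduler"]  # "vae" is deliberately missing
        modules = [getattr(self, n, None) for n in module_names]
        return [m for m in modules if isinstance(m, torch.nn.Module)]


print(TinyPipeline()._torch_modules())  # [Linear(in_features=4, out_features=4, bias=True)]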

View File

@@ -528,7 +528,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start

View File

@@ -390,7 +390,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start

View File

@@ -511,7 +511,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start

View File

@@ -507,7 +507,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start

View File

@@ -85,7 +85,10 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
-                images[idx] = np.zeros(images[idx].shape)  # black image
+                if torch.is_tensor(images) or torch.is_tensor(images[0]):
+                    images[idx] = torch.zeros_like(images[idx])  # black image
+                else:
+                    images[idx] = np.zeros(images[idx].shape)  # black image
if any(has_nsfw_concepts):
logger.warning(
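
The branch added above keeps the blacked-out image in the same container type as the rest of the batch; the old code always assigned a NumPy array, which is wrong when postprocessing has kept the images as torch tensors. Reduced to its essentials (function name hypothetical):

import numpy as np
import torch


def black_out(images, idx):
    # Zero the flagged image without changing whether the batch
    # holds torch tensors or NumPy arrays.
    if torch.is_tensor(images) or torch.is_tensor(images[0]):
        images[idx] = torch.zeros_like(images[idx])
    else:
        images[idx] = np.zeros(images[idx].shape)
    return images


batch = torch.rand(2, 3, 8, 8)
assert black_out(batch, 0)[0].abs().sum() == 0  # tensor batch stays tensor

arrays = [np.random.rand(8, 8, 3) for _ in range(2)]
assert black_out(arrays, 1)[1].sum() == 0       # ndarray batch stays ndarray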

View File

@@ -441,7 +441,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
timesteps = self.scheduler.timesteps
# Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,

View File

@@ -413,9 +413,9 @@ class UnCLIPPipeline(DiffusionPipeline):
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
@@ -466,9 +466,9 @@ class UnCLIPPipeline(DiffusionPipeline):
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),

View File

@@ -339,9 +339,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
-        num_channels_latents = self.decoder.in_channels
-        height = self.decoder.sample_size
-        width = self.decoder.sample_size
+        num_channels_latents = self.decoder.config.in_channels
+        height = self.decoder.config.sample_size
+        width = self.decoder.config.sample_size
if decoder_latents is None:
decoder_latents = self.prepare_latents(
@@ -393,9 +393,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
-        channels = self.super_res_first.in_channels // 2
-        height = self.super_res_first.sample_size
-        width = self.super_res_first.sample_size
+        channels = self.super_res_first.config.in_channels // 2
+        height = self.super_res_first.config.sample_size
+        width = self.super_res_first.config.sample_size
if super_res_latents is None:
super_res_latents = self.prepare_latents(

View File

@@ -18,7 +18,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import deprecate, logging
+from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -544,19 +544,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
-    @property
-    def in_channels(self):
-        deprecate(
-            "in_channels",
-            "1.0.0",
-            (
-                "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
-                " `unet.config.in_channels` instead"
-            ),
-            standard_warn=False,
-        )
-        return self.config.in_channels
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""

View File

@@ -533,7 +533,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -378,7 +378,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -452,7 +452,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
-        num_channels_latents = self.image_unet.in_channels
+        num_channels_latents = self.image_unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,

View File

@@ -22,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, deprecate, randn_tensor
+from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -167,16 +167,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type
-    @property
-    def num_train_timesteps(self):
-        deprecate(
-            "num_train_timesteps",
-            "1.0.0",
-            "Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`",
-            standard_warn=False,
-        )
-        return self.config.num_train_timesteps
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the

View File

@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Union
from packaging import version
-def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
+def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
from .. import __version__
deprecated_kwargs = take_from
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
if warning is not None:
warning = warning + " " if standard_warn else ""
-        warnings.warn(warning + message, FutureWarning, stacklevel=2)
+        warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
call_frame = inspect.getouterframes(inspect.currentframe())[1]
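
`stacklevel` controls which stack frame a `FutureWarning` is attributed to: the default of 2 blames `deprecate`'s caller, while `ModelMixin.__getattr__` above passes 3 so the warning points one frame further out, at the user code that actually touched the attribute. An illustrative standalone version (names hypothetical):

import warnings


def deprecate_demo(message, stacklevel=2):
    # stacklevel=2 attributes the warning to deprecate_demo's caller.
    warnings.warn(message, FutureWarning, stacklevel=stacklevel)


def library_getattr():
    # Pass 3 to skip this library frame as well, blaming *its* caller.
    deprecate_demo("use `.config` instead", stacklevel=3)


library_getattr()  # the warning is reported against this line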

View File

@@ -25,6 +25,7 @@ from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
+    HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
@@ -416,6 +417,33 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
for module in pipe.text_encoder, pipe.unet, pipe.vae:
assert module.device == torch.device("cpu")
+    def test_img2img_2nd_order(self):
+        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        sd_pipe.scheduler = HeunDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 10
+        inputs["strength"] = 0.75
+        image = sd_pipe(**inputs).images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/img2img_heun.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 5e-2
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 11
+        inputs["strength"] = 0.75
+        image_other = sd_pipe(**inputs).images[0]
+
+        mean_diff = np.abs(image - image_other).mean()
+
+        # images should be very similar
+        assert mean_diff < 5e-2
def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
@@ -453,6 +481,20 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
+    def test_img2img_safety_checker_works(self):
+        sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 20
+        # make sure the safety checker is activated
+        inputs["prompt"] = "naked, sex, porn"
+        out = sd_pipe(**inputs)
+
+        assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
+        assert np.abs(out.images[0]).sum() < 1e-5  # should be all zeros
@nightly
@require_torch_gpu

View File

@@ -293,16 +293,16 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
prior_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
-        shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
+        shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
-            super_res_first.in_channels // 2,
-            super_res_first.sample_size,
-            super_res_first.sample_size,
+            super_res_first.config.in_channels // 2,
+            super_res_first.config.sample_size,
+            super_res_first.config.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()

View File

@@ -379,16 +379,21 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
dtype = pipe.decoder.dtype
batch_size = 1
-        shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
+        shape = (
+            batch_size,
+            pipe.decoder.config.in_channels,
+            pipe.decoder.config.sample_size,
+            pipe.decoder.config.sample_size,
+        )
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
-            pipe.super_res_first.in_channels // 2,
-            pipe.super_res_first.sample_size,
-            pipe.super_res_first.sample_size,
+            pipe.super_res_first.config.in_channels // 2,
+            pipe.super_res_first.config.sample_size,
+            pipe.super_res_first.config.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()

View File

@@ -596,3 +596,47 @@ class SchedulerCommonTest(unittest.TestCase):
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+    def test_getattr_is_correct(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            # save some things to test
+            scheduler.dummy_attribute = 5
+            scheduler.register_to_config(test_attribute=5)
+
+            logger = logging.get_logger("diffusers.configuration_utils")
+            # 30 for warning
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                assert hasattr(scheduler, "dummy_attribute")
+                assert getattr(scheduler, "dummy_attribute") == 5
+                assert scheduler.dummy_attribute == 5
+
+            # no warning should be thrown
+            assert cap_logger.out == ""
+
+            logger = logging.get_logger("diffusers.schedulers.schedulering_utils")
+            # 30 for warning
+            logger.setLevel(30)
+            with CaptureLogger(logger) as cap_logger:
+                assert hasattr(scheduler, "save_pretrained")
+                fn = scheduler.save_pretrained
+                fn_1 = getattr(scheduler, "save_pretrained")
+
+                assert fn == fn_1
+            # no warning should be thrown
+            assert cap_logger.out == ""
+
+            # warning should be thrown
+            with self.assertWarns(FutureWarning):
+                assert scheduler.test_attribute == 5
+
+            with self.assertWarns(FutureWarning):
+                assert getattr(scheduler, "test_attribute") == 5
+
+            with self.assertRaises(AttributeError) as error:
+                scheduler.does_not_exist
+
+            assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"

View File

@@ -26,8 +26,8 @@ from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel
from diffusers.training_utils import EMAModel
-from diffusers.utils import torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils import logging, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
class ModelUtilsTest(unittest.TestCase):
@@ -155,6 +155,49 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
+    def test_getattr_is_correct(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        # save some things to test
+        model.dummy_attribute = 5
+        model.register_to_config(test_attribute=5)
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "dummy_attribute")
+            assert getattr(model, "dummy_attribute") == 5
+            assert model.dummy_attribute == 5
+
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        logger = logging.get_logger("diffusers.models.modeling_utils")
+        # 30 for warning
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            assert hasattr(model, "save_pretrained")
+            fn = model.save_pretrained
+            fn_1 = getattr(model, "save_pretrained")
+
+            assert fn == fn_1
+        # no warning should be thrown
+        assert cap_logger.out == ""
+
+        # warning should be thrown
+        with self.assertWarns(FutureWarning):
+            assert model.test_attribute == 5
+
+        with self.assertWarns(FutureWarning):
+            assert getattr(model, "test_attribute") == 5
+
+        with self.assertRaises(AttributeError) as error:
+            model.does_not_exist
+
+        assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
def test_from_save_pretrained_variant(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()