Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 05:24:20 +08:00)
Compare commits: fast-gpu-t...fix-cog-vi (6 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 64bf37f767 |  |
|  | 8ba90aa706 |  |
|  | 9d49b45b19 |  |
|  | 81da2e1c95 |  |
|  | 24053832b5 |  |
|  | f6f16a0c11 |  |
.github/workflows/push_tests.yml (vendored)
@@ -1,6 +1,7 @@
 name: Fast GPU Tests on main
 
 on:
+  workflow_dispatch:
   push:
     branches:
       - main
@@ -15,7 +15,6 @@
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
+    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
 )
@@ -210,9 +210,7 @@ def log_validation(
             }
         )
 
-    del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clear_objs_and_retain_memory(objs=[pipeline])
 
     return images
@@ -1107,9 +1105,7 @@ def main(args):
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)
 
-        del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(objs=[pipeline])
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1455,12 +1451,10 @@ def main(args):
 
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        del tokenizers, text_encoders
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del text_encoder_one, text_encoder_two, text_encoder_three
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(
+            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
+        )
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,11 +1789,11 @@ def main(args):
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
+                objs = []
                 if not args.train_text_encoder:
-                    del text_encoder_one, text_encoder_two, text_encoder_three
+                    objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
 
-                torch.cuda.empty_cache()
-                gc.collect()
+                clear_objs_and_retain_memory(objs=objs)
 
     # Save the lora layers
     accelerator.wait_for_everyone()
@@ -342,15 +342,58 @@ class CogVideoXPatchEmbed(nn.Module):
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
         bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_positional_embeddings: bool = True,
     ) -> None:
         super().__init__()
 
         self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.sample_height = sample_height
+        self.sample_width = sample_width
+        self.sample_frames = sample_frames
+        self.temporal_compression_ratio = temporal_compression_ratio
+        self.max_text_seq_length = max_text_seq_length
+        self.spatial_interpolation_scale = spatial_interpolation_scale
+        self.temporal_interpolation_scale = temporal_interpolation_scale
+        self.use_positional_embeddings = use_positional_embeddings
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
+        if use_positional_embeddings:
+            pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+            self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+        post_patch_height = sample_height // self.patch_size
+        post_patch_width = sample_width // self.patch_size
+        post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+        num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+        pos_embedding = get_3d_sincos_pos_embed(
+            self.embed_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            self.spatial_interpolation_scale,
+            self.temporal_interpolation_scale,
+        )
+        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        joint_pos_embedding = torch.zeros(
+            1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+        )
+        joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+        return joint_pos_embedding
+
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
         Args:
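Note: a quick size check of the buffer this hunk registers (editor's sketch; the numbers follow directly from the defaults in the signature above):

```python
# Editor's sketch: the patch count implied by the CogVideoX defaults shown above.
patch_size = 2
sample_width, sample_height, sample_frames = 90, 60, 49
temporal_compression_ratio, max_text_seq_length = 4, 226

post_patch_height = sample_height // patch_size  # 30
post_patch_width = sample_width // patch_size  # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13
num_patches = post_patch_height * post_patch_width * post_time_compression_frames

print(num_patches)  # 17550 -> pos_embedding buffer shape is [1, 226 + 17550, embed_dim]
```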
@@ -371,6 +414,21 @@ class CogVideoXPatchEmbed(nn.Module):
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
 
+        if self.use_positional_embeddings:
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+            if (
+                self.sample_height != height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            embeds = embeds + pos_embedding
+
         return embeds
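Note: the recomputation branch above is what makes non-default resolutions work; a minimal sketch of the mismatch check, assuming `height`, `width`, and `num_frames` are the post-compression latent sizes passed to `forward`:

```python
# Editor's sketch of the check above (standalone; values are the CogVideoX defaults).
temporal_compression_ratio = 4
sample_height, sample_width, sample_frames = 60, 90, 49

def needs_recompute(height: int, width: int, num_frames: int) -> bool:
    # Invert the `(sample_frames - 1) // ratio + 1` mapping used at init time.
    pre_time_compression_frames = (num_frames - 1) * temporal_compression_ratio + 1
    return (
        sample_height != height
        or sample_width != width
        or sample_frames != pre_time_compression_frames
    )

print(needs_recompute(60, 90, 13))  # False: the default 49-frame setup
print(needs_recompute(96, 96, 13))  # True: positional embeddings are rebuilt on the fly
```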
@@ -23,7 +23,7 @@ from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-        post_patch_height = sample_height // patch_size
-        post_patch_width = sample_width // patch_size
-        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
-
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+        )
         self.embedding_dropout = nn.Dropout(dropout)
 
-        # 2. 3D positional embeddings
-        spatial_pos_embedding = get_3d_sincos_pos_embed(
-            inner_dim,
-            (post_patch_width, post_patch_height),
-            post_time_compression_frames,
-            spatial_interpolation_scale,
-            temporal_interpolation_scale,
-        )
-        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
-
-        # 3. Time embeddings
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        # 4. Define spatio-temporal transformers blocks
+        # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
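Note: the call site switches from positional to keyword arguments because the `CogVideoXPatchEmbed` signature grew by eight parameters; with keywords, the signature can keep growing without silently re-binding values. A toy illustration (hypothetical function, not the diffusers code):

```python
# Editor's sketch (hypothetical toy signature).
def make_embed(patch_size=2, in_channels=16, embed_dim=1920, text_embed_dim=4096, bias=True):
    return (patch_size, in_channels, embed_dim, text_embed_dim, bias)

# Positional: meaning depends on parameter order, which the refactor above just changed.
risky = make_embed(2, 16, 3072, 4096, True)
# Keyword: robust to signature growth such as the new sample_* parameters.
safe = make_embed(patch_size=2, in_channels=16, embed_dim=3072, text_embed_dim=4096, bias=True)
assert risky == safe
```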
@@ -284,7 +280,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
             )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        # 5. Output blocks
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 4. Transformer blocks
+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
@@ -471,11 +460,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        # 5. Final block
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        # 6. Unpatchify
+        # 5. Unpatchify
         p = self.config.patch_size
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
@@ -1,5 +1,6 @@
 import contextlib
 import copy
+import gc
 import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting
 
 
+def clear_objs_and_retain_memory(objs: List[Any]):
+    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
+    if len(objs) >= 1:
+        for obj in objs:
+            del obj
+
+    gc.collect()
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available():
+        torch.mps.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.empty_cache()
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
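Note: this helper centralizes the `del` / `gc.collect()` / cache-clear boilerplate the training scripts above used to inline, and it picks the cache backend (CUDA, MPS, or NPU) at call time. A minimal usage sketch (assumes a diffusers build that includes this function; the pipeline id is illustrative):

```python
# Editor's sketch, assuming the helper lands in diffusers.training_utils as shown above.
from diffusers import DiffusionPipeline
from diffusers.training_utils import clear_objs_and_retain_memory

pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
# ... run validation or class-image generation with `pipe` ...

# Replaces: del pipe; gc.collect(); torch.cuda.empty_cache() (or the MPS/NPU equivalent).
# Note that `del obj` inside the helper only drops the helper's local reference;
# the main effect is gc.collect() plus clearing the accelerator cache.
clear_objs_and_retain_memory(objs=[pipe])
```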
@@ -157,11 +157,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
             if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
                 self.assertTrue(m.weight.device != torch.device("cpu"))
 
+    @slow
     @require_torch_gpu
     def test_integration_move_lora_dora_cpu(self):
         from peft import LoraConfig
 
-        path = "runwayml/stable-diffusion-v1-5"
+        path = "Lykon/dreamshaper-8"
         unet_lora_config = LoraConfig(
             init_lora_weights="gaussian",
             target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -528,6 +528,10 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
     def test_forward_with_norm_groups(self):
         pass
 
+    @unittest.skip("No attention module used in this model")
+    def test_set_attn_processor_for_determinism(self):
+        return
+
 
 @slow
 class AutoencoderTinyIntegrationTests(unittest.TestCase):
@@ -55,17 +55,6 @@ class EmbeddingsTests(unittest.TestCase):
             assert grad > prev_grad
             prev_grad = grad
 
-    def test_timestep_defaults(self):
-        embedding_dim = 16
-        timesteps = torch.arange(10)
-
-        t1 = get_timestep_embedding(timesteps, embedding_dim)
-        t2 = get_timestep_embedding(
-            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
-        )
-
-        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
-
     def test_timestep_flip_sin_cos(self):
         embedding_dim = 16
         timesteps = torch.arange(10)
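Note: the removed test only asserted that the keyword defaults of `get_timestep_embedding` match `flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000`. The equivalence it checked, as a standalone sketch (reconstructed from the deleted body above):

```python
# Editor's sketch, reconstructed from the deleted test body above.
import torch
from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.arange(10)
t1 = get_timestep_embedding(timesteps, 16)
t2 = get_timestep_embedding(timesteps, 16, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000)
assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)  # defaults match the explicit values
```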
@@ -183,17 +183,6 @@ class ModelUtilsTest(unittest.TestCase):
 
 
 class UNetTesterMixin:
-    def test_forward_signature(self):
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-
-        model = self.model_class(**init_dict)
-        signature = inspect.signature(model.forward)
-        # signature.parameters is an OrderedDict => so arg_names order is deterministic
-        arg_names = [*signature.parameters.keys()]
-
-        expected_arg_names = ["sample", "timestep"]
-        self.assertListEqual(arg_names[:2], expected_arg_names)
-
     def test_forward_with_norm_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -220,6 +209,7 @@ class ModelTesterMixin:
     base_precision = 1e-3
     forward_requires_fresh_args = False
     model_split_percents = [0.5, 0.7, 0.9]
+    uses_custom_attn_processor = False
 
     def check_device_map_is_respected(self, model, device_map):
         for param_name, param in model.named_parameters():
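Note: `uses_custom_attn_processor` is a new class-level flag on `ModelTesterMixin` (default `False`); model suites that ship their own attention processor, such as the CogVideoX and Lumina testers below, set it to `True` so shared tests can adapt. A sketch of how a shared check might consult it (hypothetical test body, not the actual mixin code):

```python
# Editor's sketch (hypothetical): gate processor-swapping checks on the new flag.
class ExampleModelTests:  # intended to be mixed into a unittest.TestCase subclass
    uses_custom_attn_processor = False

    def test_set_attn_processor_for_determinism(self):
        if self.uses_custom_attn_processor:
            self.skipTest("model relies on its own attention processor")
        # ... default determinism checks over interchangeable processors ...
```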
@@ -32,6 +32,7 @@ enable_full_determinism()
 class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = CogVideoXTransformer3DModel
     main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
 
     @property
     def dummy_input(self):
@@ -32,6 +32,7 @@ enable_full_determinism()
 class LuminaNextDiT2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = LuminaNextDiT2DModel
     main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
 
     @property
     def dummy_input(self):
@@ -175,7 +175,7 @@ class AnimateDiffPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         pass
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array(
@@ -209,7 +209,7 @@ class AnimateDiffPipelineFastTests(
                     0.5620,
                 ]
             )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_dict_tuple_outputs_equivalent(self):
         expected_slice = None
@@ -193,7 +193,7 @@ class AnimateDiffControlNetPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         pass
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array(
@@ -218,7 +218,7 @@ class AnimateDiffControlNetPipelineFastTests(
                     0.5155,
                 ]
             )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_dict_tuple_outputs_equivalent(self):
         expected_slice = None
@@ -195,7 +195,7 @@ class AnimateDiffSparseControlNetPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         pass
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array(
@@ -220,7 +220,7 @@ class AnimateDiffSparseControlNetPipelineFastTests(
                     0.5155,
                 ]
             )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_dict_tuple_outputs_equivalent(self):
         expected_slice = None
@@ -175,7 +175,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         pass
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
 
         if torch_device == "cpu":
@@ -201,7 +201,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
                     0.5378,
                 ]
             )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_inference_batch_single_identical(
         self,
@@ -220,11 +220,11 @@ class ControlNetPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.5234, 0.3333, 0.1745, 0.7605, 0.6224, 0.4637, 0.6989, 0.7526, 0.4665])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
@@ -460,11 +460,11 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.2422, 0.3425, 0.4048, 0.5351, 0.3503, 0.2419, 0.4645, 0.4570, 0.3804])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_save_pretrained_raise_not_implemented_exception(self):
         components = self.get_dummy_components()
@@ -679,11 +679,11 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.5264, 0.3203, 0.1602, 0.8235, 0.6332, 0.4593, 0.7226, 0.7777, 0.4780])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_save_pretrained_raise_not_implemented_exception(self):
         components = self.get_dummy_components()
@@ -173,11 +173,11 @@ class ControlNetImg2ImgPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.7096, 0.5149, 0.3571, 0.5897, 0.4715, 0.4052, 0.6098, 0.6886, 0.4213])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
@@ -371,11 +371,11 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.5293, 0.7339, 0.6642, 0.3950, 0.5212, 0.5175, 0.7002, 0.5907, 0.5182])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_save_pretrained_raise_not_implemented_exception(self):
         components = self.get_dummy_components()
@@ -190,14 +190,14 @@ class StableDiffusionXLControlNetPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
 
-    def test_ip_adapter_single(self, from_ssd1b=False, expected_pipe_slice=None):
+    def test_ip_adapter(self, from_ssd1b=False, expected_pipe_slice=None):
         if not from_ssd1b:
             expected_pipe_slice = None
             if torch_device == "cpu":
                 expected_pipe_slice = np.array(
                     [0.7335, 0.5866, 0.5623, 0.6242, 0.5751, 0.5999, 0.4091, 0.4590, 0.5054]
                 )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
@@ -970,12 +970,12 @@ class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNe
         # make sure that it's equal
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.7212, 0.5890, 0.5491, 0.6425, 0.5970, 0.6091, 0.4418, 0.4556, 0.5032])
 
-        return super().test_ip_adapter_single(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice)
 
     def test_controlnet_sdxl_lcm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -175,12 +175,12 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
 
         return inputs
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.6276, 0.5271, 0.5205, 0.5393, 0.5774, 0.5872, 0.5456, 0.5415, 0.5354])
             # TODO: update after slices.p
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_stable_diffusion_xl_controlnet_img2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -550,7 +550,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
 
-    def test_ip_adapter_single_mask(self):
+    def test_ip_adapter_mask(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
         pipeline = StableDiffusionXLPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0",
@@ -108,11 +108,11 @@ class LatentConsistencyModelPipelineFastTests(
         }
         return inputs
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_lcm_onestep(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -119,11 +119,11 @@ class LatentConsistencyModelImg2ImgPipelineFastTests(
         }
         return inputs
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.4003, 0.3718, 0.2863, 0.5500, 0.5587, 0.3772, 0.4617, 0.4961, 0.4417])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_lcm_onestep(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -175,7 +175,7 @@ class AnimateDiffPAGPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         pass
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
 
         if torch_device == "cpu":
@@ -210,7 +210,7 @@ class AnimateDiffPAGPipelineFastTests(
                     0.5538,
                 ]
             )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_dict_tuple_outputs_equivalent(self):
         expected_slice = None
@@ -176,7 +176,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
 
         assert isinstance(pipe.unet, UNetMotionModel)
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
 
         if torch_device == "cpu":
@@ -211,7 +211,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
                     0.5538,
                 ]
             )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_dict_tuple_outputs_equivalent(self):
         expected_slice = None
@@ -253,11 +253,11 @@ class StableDiffusionImg2ImgPipelineFastTests(
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.4932, 0.5092, 0.5135, 0.5517, 0.5626, 0.6621, 0.6490, 0.5021, 0.5441])
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_stable_diffusion_img2img_multiple_init_images(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -385,14 +385,14 @@ class StableDiffusionInpaintPipelineFastTests(
         # they should be the same
         assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
 
-    def test_ip_adapter_single(self, from_simple=False, expected_pipe_slice=None):
+    def test_ip_adapter(self, from_simple=False, expected_pipe_slice=None):
         if not from_simple:
             expected_pipe_slice = None
             if torch_device == "cpu":
                 expected_pipe_slice = np.array(
                     [0.4390, 0.5452, 0.3772, 0.5448, 0.6031, 0.4480, 0.5194, 0.4687, 0.4640]
                 )
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
 
 class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
@@ -481,11 +481,11 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
         }
         return inputs
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.6345, 0.5395, 0.5611, 0.5403, 0.5830, 0.5855, 0.5193, 0.5443, 0.5211])
-        return super().test_ip_adapter_single(from_simple=True, expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(from_simple=True, expected_pipe_slice=expected_pipe_slice)
 
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -281,9 +281,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(
         max_diff = np.abs(output - output_tuple).max()
         self.assertLess(max_diff, 1e-4)
 
-    def test_progress_bar(self):
-        super().test_progress_bar()
-
     def test_stable_diffusion_depth2img_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -330,12 +330,12 @@ class StableDiffusionXLPipelineFastTests(
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.5388, 0.5452, 0.4694, 0.4583, 0.5253, 0.4832, 0.5288, 0.5035, 0.4766])
 
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
@@ -290,7 +290,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
         }
         return inputs
 
-    def test_ip_adapter_single(self, from_multi=False, expected_pipe_slice=None):
+    def test_ip_adapter(self, from_multi=False, expected_pipe_slice=None):
         if not from_multi:
             expected_pipe_slice = None
             if torch_device == "cpu":
@@ -298,7 +298,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
                 expected_pipe_slice = np.array(
                     [0.5752, 0.6155, 0.4826, 0.5111, 0.5741, 0.4678, 0.5199, 0.5231, 0.4794]
                 )
 
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_stable_diffusion_adapter_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -448,12 +448,12 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
         expected_slice = np.array([0.5617, 0.6081, 0.4807, 0.5071, 0.5665, 0.4614, 0.5165, 0.5164, 0.4786])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.5617, 0.6081, 0.4807, 0.5071, 0.5665, 0.4614, 0.5165, 0.5164, 0.4786])
 
-        return super().test_ip_adapter_single(from_multi=True, expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(from_multi=True, expected_pipe_slice=expected_pipe_slice)
 
     def test_inference_batch_consistent(
         self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
@@ -310,12 +310,12 @@ class StableDiffusionXLImg2ImgPipelineFastTests(
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.5133, 0.4626, 0.4970, 0.6273, 0.5160, 0.6891, 0.6639, 0.5892, 0.5709])
 
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_stable_diffusion_xl_img2img_tiny_autoencoder(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -223,12 +223,12 @@ class StableDiffusionXLInpaintPipelineFastTests(
         }
         return inputs
 
-    def test_ip_adapter_single(self):
+    def test_ip_adapter(self):
         expected_pipe_slice = None
         if torch_device == "cpu":
             expected_pipe_slice = np.array([0.8274, 0.5538, 0.6141, 0.5843, 0.6865, 0.7082, 0.5861, 0.6123, 0.5344])
 
-        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+        return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
 
     def test_components_function(self):
         init_components = self.get_dummy_components()
@@ -1,6 +1,25 @@
+import contextlib
+import io
+import re
 import unittest
 
+import torch
+from PIL import Image
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+    AnimateDiffPipeline,
+    AnimateDiffVideoToVideoPipeline,
+    AutoencoderKL,
+    DDIMScheduler,
+    MotionAdapter,
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionInpaintPipeline,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
 from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
+from diffusers.utils.testing_utils import torch_device
 
 
 class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -177,3 +196,251 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
         self.assertTrue(is_safetensors_compatible(filenames))
+
+
+class ProgressBarTests(unittest.TestCase):
+    def get_dummy_components_image_generation(self):
+        cross_attention_dim = 8
+
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(4, 8),
+            layers_per_block=1,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=cross_attention_dim,
+            norm_num_groups=2,
+        )
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[4, 8],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            norm_num_groups=2,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=cross_attention_dim,
+            intermediate_size=16,
+            layer_norm_eps=1e-05,
+            num_attention_heads=2,
+            num_hidden_layers=2,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+            "image_encoder": None,
+        }
+        return components
+
+    def get_dummy_components_video_generation(self):
+        cross_attention_dim = 8
+        block_out_channels = (8, 8)
+
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=block_out_channels,
+            layers_per_block=2,
+            sample_size=8,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=cross_attention_dim,
+            norm_num_groups=2,
+        )
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="linear",
+            clip_sample=False,
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=block_out_channels,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            norm_num_groups=2,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=cross_attention_dim,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        torch.manual_seed(0)
+        motion_adapter = MotionAdapter(
+            block_out_channels=block_out_channels,
+            motion_layers_per_block=2,
+            motion_norm_num_groups=2,
+            motion_num_attention_heads=4,
+        )
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "motion_adapter": motion_adapter,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "feature_extractor": None,
+            "image_encoder": None,
+        }
+        return components
+
+    def test_text_to_image(self):
+        components = self.get_dummy_components_image_generation()
+        pipe = StableDiffusionPipeline(**components)
+        pipe.to(torch_device)
+
+        inputs = {"prompt": "a cute cat", "num_inference_steps": 2}
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+            stderr = stderr.getvalue()
+            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+            max_steps = re.search("/(.*?) ", stderr).group(1)
+            self.assertTrue(max_steps is not None and len(max_steps) > 0)
+            self.assertTrue(
+                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+            )
+
+        pipe.set_progress_bar_config(disable=True)
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+    def test_image_to_image(self):
+        components = self.get_dummy_components_image_generation()
+        pipe = StableDiffusionImg2ImgPipeline(**components)
+        pipe.to(torch_device)
+
+        image = Image.new("RGB", (32, 32))
+        inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "strength": 0.5, "image": image}
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+            stderr = stderr.getvalue()
+            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+            max_steps = re.search("/(.*?) ", stderr).group(1)
+            self.assertTrue(max_steps is not None and len(max_steps) > 0)
+            self.assertTrue(
+                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+            )
+
+        pipe.set_progress_bar_config(disable=True)
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+    def test_inpainting(self):
+        components = self.get_dummy_components_image_generation()
+        pipe = StableDiffusionInpaintPipeline(**components)
+        pipe.to(torch_device)
+
+        image = Image.new("RGB", (32, 32))
+        mask = Image.new("RGB", (32, 32))
+        inputs = {
+            "prompt": "a cute cat",
+            "num_inference_steps": 2,
+            "strength": 0.5,
+            "image": image,
+            "mask_image": mask,
+        }
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+            stderr = stderr.getvalue()
+            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+            max_steps = re.search("/(.*?) ", stderr).group(1)
+            self.assertTrue(max_steps is not None and len(max_steps) > 0)
+            self.assertTrue(
+                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+            )
+
+        pipe.set_progress_bar_config(disable=True)
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+    def test_text_to_video(self):
+        components = self.get_dummy_components_video_generation()
+        pipe = AnimateDiffPipeline(**components)
+        pipe.to(torch_device)
+
+        inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "num_frames": 2}
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+            stderr = stderr.getvalue()
+            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+            max_steps = re.search("/(.*?) ", stderr).group(1)
+            self.assertTrue(max_steps is not None and len(max_steps) > 0)
+            self.assertTrue(
+                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+            )
+
+        pipe.set_progress_bar_config(disable=True)
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+    def test_video_to_video(self):
+        components = self.get_dummy_components_video_generation()
+        pipe = AnimateDiffVideoToVideoPipeline(**components)
+        pipe.to(torch_device)
+
+        num_frames = 2
+        video = [Image.new("RGB", (32, 32))] * num_frames
+        inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "video": video}
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+            stderr = stderr.getvalue()
+            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
+            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
+            max_steps = re.search("/(.*?) ", stderr).group(1)
+            self.assertTrue(max_steps is not None and len(max_steps) > 0)
+            self.assertTrue(
+                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
+            )
+
+        pipe.set_progress_bar_config(disable=True)
+        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
+            _ = pipe(**inputs)
+        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
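Note: every progress-bar test above uses the same trick: tqdm renders to stderr, so the test captures stderr and recovers the total step count from the bar text. What the regex extracts:

```python
# Editor's sketch: the parsing used by the tests above, on a typical tqdm line.
import re

captured = "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", captured).group(1)
print(max_steps)  # "5"
# A finished bar renders "5/5 [...]", which is what the f"{max_steps}/{max_steps}" check matches.
print(f"{max_steps}/{max_steps}" in "#####| 5/5 [00:02<00:00]")  # True
```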
@@ -1,10 +1,7 @@
-import contextlib
 import gc
 import inspect
-import io
 import json
 import os
-import re
 import tempfile
 import unittest
 import uuid
@@ -141,52 +138,35 @@ class SDFunctionTesterMixin:
         assert np.abs(to_np(output_2) - to_np(output_1)).max() < 5e-1
 
         # test that tiled decode works with various shapes
-        shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+        shapes = [(1, 4, 73, 97), (1, 4, 65, 49)]
         with torch.no_grad():
             for shape in shapes:
                 zeros = torch.zeros(shape).to(torch_device)
                 pipe.vae.decode(zeros)
 
-    # MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569.
+    # MPS currently doesn't support ComplexFloats, which are required for FreeU - see https://github.com/huggingface/diffusers/issues/7569.
     @skip_mps
-    def test_freeu_enabled(self):
+    def test_freeu(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
         # Normal inference
         inputs = self.get_dummy_inputs(torch_device)
         inputs["return_dict"] = False
         inputs["output_type"] = "np"
         output = pipe(**inputs)[0]
 
         # FreeU-enabled inference
         pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
         inputs = self.get_dummy_inputs(torch_device)
         inputs["return_dict"] = False
         inputs["output_type"] = "np"
         output_freeu = pipe(**inputs)[0]
 
-        assert not np.allclose(
-            output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
-        ), "Enabling of FreeU should lead to different results."
-
-    def test_freeu_disabled(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["return_dict"] = False
-        inputs["output_type"] = "np"
-
-        output = pipe(**inputs)[0]
-
-        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+        # FreeU-disabled inference
         pipe.disable_freeu()
-
         freeu_keys = {"s1", "s2", "b1", "b2"}
         for upsample_block in pipe.unet.up_blocks:
             for key in freeu_keys:
@@ -195,8 +175,11 @@ class SDFunctionTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         inputs["return_dict"] = False
         inputs["output_type"] = "np"
-
         output_no_freeu = pipe(**inputs)[0]
+
+        assert not np.allclose(
+            output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
+        ), "Enabling of FreeU should lead to different results."
         assert np.allclose(
             output, output_no_freeu, atol=1e-2
         ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
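Note: `test_freeu_enabled` and `test_freeu_disabled` are merged into a single `test_freeu` that checks both directions in one run: enabling FreeU must change the output, and disabling it must restore the baseline. The public API being exercised (model id illustrative):

```python
# Editor's sketch of the FreeU toggle the merged test exercises.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # skip/backbone scaling factors
# ... generate and compare against a baseline image ...
pipe.disable_freeu()  # removes the FreeU hooks; output should match the default pipeline again
```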
@@ -290,7 +273,15 @@ class IPAdapterTesterMixin:
         inputs["return_dict"] = False
         return inputs
 
-    def test_ip_adapter_single(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+    def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+        r"""Tests for IP-Adapter.
+
+        The following scenarios are tested:
+          - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+          - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+          - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+          - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+        """
         # Raising the tolerance for this test when it's run on a CPU because we
         # compare against static slices and that can be shaky (with a VVVV low probability).
         expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
@@ -307,6 +298,7 @@ class IPAdapterTesterMixin:
         else:
             output_without_adapter = expected_pipe_slice
 
+        # 1. Single IP-Adapter test cases
         adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
         pipe.unet._load_ip_adapter_weights(adapter_state_dict)
@@ -338,16 +330,7 @@ class IPAdapterTesterMixin:
             max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
         )
 
-    def test_ip_adapter_multi(self, expected_max_diff: float = 1e-4):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components).to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
-
-        # forward pass without ip adapter
-        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
-        output_without_adapter = pipe(**inputs)[0]
-
+        # 2. Multi IP-Adapter test cases
         adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
         adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
         pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
@@ -357,12 +340,16 @@ class IPAdapterTesterMixin:
         inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
         pipe.set_ip_adapter_scale([0.0, 0.0])
         output_without_multi_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
 
         # forward pass with multi ip adapter, but with scale of adapter weights
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
         inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
         pipe.set_ip_adapter_scale([42.0, 42.0])
         output_with_multi_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
 
         max_diff_without_multi_adapter_scale = np.abs(
             output_without_multi_adapter_scale - output_without_adapter
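Note: the merged `test_ip_adapter` now also absorbs the old `test_ip_adapter_multi`; in both the single and multi paths the invariant is the same: a zero adapter scale must reproduce the adapter-free output, while a large scale must move away from it. The assertion pattern in outline (editor's sketch):

```python
# Editor's sketch of the invariant checked for both single and multi IP-Adapter paths.
import numpy as np

def check_adapter_effect(without_adapter, scale_zero, scale_large, tol=9e-4):
    # scale == 0.0 must be a no-op ...
    assert np.abs(scale_zero - without_adapter).max() < tol
    # ... while scale == 42.0 must visibly change the output.
    assert np.abs(scale_large - without_adapter).max() > 1e-2
```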
@@ -1689,28 +1676,6 @@ class PipelineTesterMixin:
         if test_mean_pixel_difference:
             assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
 
-    def test_progress_bar(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
-
-        inputs = self.get_dummy_inputs(torch_device)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-            stderr = stderr.getvalue()
-            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
-            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
-            max_steps = re.search("/(.*?) ", stderr).group(1)
-            self.assertTrue(max_steps is not None and len(max_steps) > 0)
-            self.assertTrue(
-                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
-            )
-
-        pipe.set_progress_bar_config(disable=True)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
-
     def test_num_images_per_prompt(self):
         sig = inspect.signature(self.pipeline_class.__call__)
@@ -173,9 +173,6 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin,
     def test_num_images_per_prompt(self):
        pass
 
-    def test_progress_bar(self):
-        return super().test_progress_bar()
-
 
 @slow
 @skip_mps
@@ -13,11 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import contextlib
 import gc
 import inspect
-import io
-import re
 import tempfile
 import unittest
@@ -282,28 +279,6 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
     def test_pipeline_call_signature(self):
         pass
 
-    def test_progress_bar(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
-
-        inputs = self.get_dummy_inputs(self.generator_device)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-            stderr = stderr.getvalue()
-            # we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
-            # so we just match "5" in "#####| 1/5 [00:01<00:00]"
-            max_steps = re.search("/(.*?) ", stderr).group(1)
-            self.assertTrue(max_steps is not None and len(max_steps) > 0)
-            self.assertTrue(
-                f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
-            )
-
-        pipe.set_progress_bar_config(disable=True)
-        with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
-            _ = pipe(**inputs)
-        self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
-
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
@@ -197,9 +197,6 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_num_images_per_prompt(self):
         pass
 
-    def test_progress_bar(self):
-        return super().test_progress_bar()
-
 
 @nightly
 @skip_mps