Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-19 02:44:53 +08:00)

Compare commits: posix-test...unet-refac (20 commits)
| SHA1 |
|---|
| 43f72b2940 |
| 36e52ef801 |
| 95fcd61904 |
| 8f824bf0ab |
| 89bcec9a0d |
| f404b6926e |
| 69f4b8ff5a |
| f0ec02350a |
| 8bf046b7fb |
| bb99623d09 |
| 9fdd6de30f |
| aa3b85bdd6 |
| fdf55b1f1c |
| c6f8c310c3 |
| 0953fed52b |
| bd375a8034 |
| 17105d973c |
| 32e04da6cf |
| c1e812b8fd |
| 784f4e9646 |
@@ -1031,16 +1031,10 @@ class DownBlockMotion(nn.Module):
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, scale
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(motion_module),
                    hidden_states.requires_grad_(),
                    temb,
                    num_frames,
                )

            else:
                hidden_states = resnet(hidden_states, temb, scale=scale)
                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]

            output_states = output_states + (hidden_states,)
@@ -1221,10 +1215,10 @@ class CrossAttnDownBlockMotion(nn.Module):
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                hidden_states = motion_module(
                    hidden_states,
                    num_frames=num_frames,
                )[0]
            hidden_states = motion_module(
                hidden_states,
                num_frames=num_frames,
            )[0]

            # apply additional residuals to the output of the last pair of resnet and attention blocks
            if i == len(blocks) - 1 and additional_residuals is not None:
@@ -1425,10 +1419,10 @@ class CrossAttnUpBlockMotion(nn.Module):
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                hidden_states = motion_module(
                    hidden_states,
                    num_frames=num_frames,
                )[0]
            hidden_states = motion_module(
                hidden_states,
                num_frames=num_frames,
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
@@ -1563,15 +1557,10 @@ class UpBlockMotion(nn.Module):
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                )

            else:
                hidden_states = resnet(hidden_states, temb, scale=scale)
                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
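The hunks above all manipulate the same gradient-checkpointing pattern, in which a block's forward call is wrapped in a closure before being handed to torch.utils.checkpoint.checkpoint. A minimal standalone sketch of that pattern, with an illustrative stand-in module rather than the actual diffusers blocks:

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    def create_custom_forward(module):
        # Wrap the module so checkpoint() can call it with positional args only.
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    resnet = nn.Linear(64, 64)  # stand-in for a resnet block
    hidden_states = torch.randn(2, 64, requires_grad=True)

    # Recompute the forward pass during backward instead of storing activations.
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(resnet), hidden_states, use_reentrant=False
    )
    hidden_states.sum().backward()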
1181  src/diffusers/models/unets/unet_stable_diffusion.py (new file)
File diff suppressed because it is too large
205  src/diffusers/models/unets/unet_utils.py (new file)
@@ -0,0 +1,205 @@
from typing import Dict, Optional, Tuple, Union

import torch

from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)


class UNet2DConditionModelUtilsMixin:
    def _check_config(
        self,
        down_block_types: Tuple[str],
        up_block_types: Tuple[str],
        only_cross_attention: Union[bool, Tuple[bool]],
        block_out_channels: Tuple[int],
        layers_per_block: Union[int, Tuple[int]],
        cross_attention_dim: Union[int, Tuple[int]],
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
        reverse_transformer_layers_per_block: bool,
        num_attention_heads: Optional[Union[int, Tuple[int]]],
    ):
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )
        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
            for layer_number_per_block in transformer_layers_per_block:
                if isinstance(layer_number_per_block, list):
                    raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.")

    @property
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_freeu(self, s1, s2, b1, b2):
        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        for i, upsample_block in enumerate(self.up_blocks):
            setattr(upsample_block, "s1", s1)
            setattr(upsample_block, "s2", s2)
            setattr(upsample_block, "b1", b1)
            setattr(upsample_block, "b2", b2)

    def disable_freeu(self):
        """Disables the FreeU mechanism."""
        freeu_keys = {"s1", "s2", "b1", "b2"}
        for i, upsample_block in enumerate(self.up_blocks):
            for k in freeu_keys:
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)

    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
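Taken together, these helpers mirror the attention-processor and FreeU utilities already exposed on UNet2DConditionModel. A minimal usage sketch, assuming the mixin ends up mixed into the refactored UNet class and keeps the same behavior as those existing methods (the model id and FreeU values below are just illustrative defaults):

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor2_0

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    )

    # Swap every attention layer to a single processor class, then restore defaults.
    unet.set_attn_processor(AttnProcessor2_0())
    unet.set_default_attn_processor()

    # FreeU scaling factors; these particular values are commonly suggested for SD 1.5.
    unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
    unet.disable_freeu()

    # Experimental fused QKV projections.
    unet.fuse_qkv_projections()
    unet.unfuse_qkv_projections()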
@@ -24,7 +24,7 @@ import torch.fft as fft
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
@@ -209,7 +209,9 @@ class PIAPipelineOutput(BaseOutput):
    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]


class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
class PIAPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-to-video generation.

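Adding FromSingleFileMixin gives the pipeline a `from_single_file` constructor for monolithic checkpoints. A sketch of the intended call, assuming the loader passes the motion adapter through like other pipeline components (the checkpoint path is a placeholder):

    import torch
    from diffusers import MotionAdapter, PIAPipeline

    adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
    pipe = PIAPipeline.from_single_file(
        "./pia_sd15.safetensors",  # placeholder path to a single-file SD 1.5 checkpoint
        motion_adapter=adapter,    # assumption: accepted as an extra component
        torch_dtype=torch.float16,
    )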
@@ -685,6 +687,35 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

            if self.do_classifier_free_guidance:
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)

        return image_embeds

    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
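The stacking and classifier-free-guidance concatenation above can be illustrated in isolation with plain tensors (the shapes below are placeholders, not the output of the pipeline's encode_image):

    import torch

    num_images_per_prompt = 2
    do_classifier_free_guidance = True

    # Pretend embeddings for one IP-Adapter image.
    single_image_embeds = torch.randn(1, 4, 768)
    single_negative_image_embeds = torch.zeros(1, 4, 768)

    # Repeat per prompt, then prepend the negative embeds for CFG.
    single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
    single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
    if do_classifier_free_guidance:
        single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

    print(single_image_embeds.shape)  # torch.Size([4, 1, 4, 768]); the negative half comes first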
@@ -1107,12 +1138,9 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image, device, batch_size * num_videos_per_prompt
            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

@@ -522,7 +522,7 @@ def load_hf_numpy(path) -> np.ndarray:
    base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"

    if not path.startswith("http://") and not path.startswith("https://"):
        path = Path(base_url, urllib.parse.quote(path)).as_posix()
        path = os.path.join(base_url, urllib.parse.quote(path))

    return load_numpy(path)

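The two variants differ in how they join the URL: os.path.join uses the OS separator (a backslash on Windows), while pathlib normalizes separators and also collapses the double slash after the scheme, so Path("https://x").as_posix() becomes "https:/x". A small standalone illustration (the relative path is made up):

    import os
    import urllib.parse
    from pathlib import Path

    base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
    rel = urllib.parse.quote("unet/sample image.npy")  # made-up relative path

    # pathlib collapses the "//" after the scheme and always emits "/":
    print(Path(base_url, rel).as_posix())  # https:/huggingface.co/.../unet/sample%20image.npy
    print(os.path.join(base_url, rel))     # full URL on POSIX; uses "\" as the separator on Windows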
@@ -35,6 +35,7 @@ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
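The remaining hunks replace exact np.allclose checks with a cosine-similarity distance, which tolerates small elementwise drift across hardware while still catching structural changes in the slice. A rough standalone equivalent of that metric (an illustration, not necessarily the exact testing_utils implementation; the slice values are made up):

    import numpy as np

    def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
        # 1 - cosine similarity of the flattened arrays; 0.0 means identical direction.
        a = a.flatten().astype(np.float64)
        b = b.flatten().astype(np.float64)
        cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return float(1.0 - cos)

    image_slice = np.array([0.2715, 0.2405, 0.2217])
    expected_slice = np.array([0.2729, 0.2402, 0.2195])
    assert cosine_similarity_distance(image_slice, expected_slice) < 5e-4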
@@ -119,7 +120,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):

        expected_slice = np.array([0.80810547, 0.88183594, 0.9296875, 0.9189453, 0.9848633, 1.0, 0.97021484, 1.0, 1.0])

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

@@ -131,7 +133,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
            [0.30444336, 0.26513672, 0.22436523, 0.2758789, 0.25585938, 0.20751953, 0.25390625, 0.24633789, 0.21923828]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_image_to_image(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -149,7 +152,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
            [0.22167969, 0.21875, 0.21728516, 0.22607422, 0.21948242, 0.23925781, 0.22387695, 0.25268555, 0.2722168]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

@@ -161,7 +165,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
            [0.35913086, 0.265625, 0.26367188, 0.24658203, 0.19750977, 0.39990234, 0.15258789, 0.20336914, 0.5517578]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_inpainting(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -179,7 +184,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
            [0.27148438, 0.24047852, 0.22167969, 0.23217773, 0.21118164, 0.21142578, 0.21875, 0.20751953, 0.20019531]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

@@ -187,11 +193,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array(
            [0.27294922, 0.24023438, 0.21948242, 0.23242188, 0.20825195, 0.2055664, 0.21679688, 0.20336914, 0.19360352]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_text_to_image_model_cpu_offload(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -233,11 +236,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array(
            [0.18115234, 0.13500977, 0.13427734, 0.24194336, 0.17138672, 0.16625977, 0.4260254, 0.43359375, 0.4416504]
        )
        expected_slice = np.array([0.1958, 0.1475, 0.1396, 0.2412, 0.1658, 0.1533, 0.3997, 0.4055, 0.4128])

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_unload(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -277,7 +279,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
        expected_slice = np.array(
            [0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
        )
        assert np.allclose(image_slice, expected_slice, atol=1e-3)

        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4


@slow
@@ -314,7 +318,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
            ]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")

@@ -339,7 +344,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
            [0.0576596, 0.05600825, 0.04479006, 0.05288461, 0.05461192, 0.05137569, 0.04867965, 0.05301541, 0.04939842]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_image_to_image_sdxl(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
@@ -432,7 +438,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
            [0.14181179, 0.1493012, 0.14283323, 0.14602411, 0.14915377, 0.15015268, 0.14725655, 0.15009224, 0.15164584]
        )

        assert np.allclose(image_slice, expected_slice, atol=1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
@@ -457,4 +464,5 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):

        expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4