Compare commits

...

20 Commits

Author SHA1 Message Date
Dhruv Nair
43f72b2940 update 2024-02-15 12:34:16 +00:00
Dhruv Nair
36e52ef801 Merge branch 'unet-refactor' of https://github.com/huggingface/diffusers into unet-refactor 2024-02-15 09:36:05 +00:00
Dhruv Nair
95fcd61904 update 2024-02-15 09:35:37 +00:00
Dhruv Nair
8f824bf0ab Merge branch 'main' into unet-refactor 2024-02-05 19:56:09 +05:30
Dhruv Nair
89bcec9a0d update 2024-02-05 13:38:33 +00:00
Dhruv Nair
f404b6926e update 2024-02-05 13:08:56 +00:00
Dhruv Nair
69f4b8ff5a update 2024-02-05 12:55:55 +00:00
Dhruv Nair
f0ec02350a update 2024-02-05 11:30:55 +00:00
Dhruv Nair
8bf046b7fb Add single file and IP Adapter support to PIA Pipeline (#6851)
update
2024-02-05 16:23:18 +05:30
Dhruv Nair
bb99623d09 Update IP Adapter tests to use cosine similarity distance (#6806)
* update

* update
2024-02-05 16:22:59 +05:30
Dhruv Nair
9fdd6de30f update 2024-02-05 10:47:02 +00:00
Dhruv Nair
aa3b85bdd6 Merge branch 'main' into unet-refactor 2024-02-05 08:17:41 +00:00
Dhruv Nair
fdf55b1f1c Fix posix path issue in testing utils (#6849)
update
2024-02-05 08:57:18 +05:30
小咩Goat
c6f8c310c3 Fix forward pass in UNetMotionModel when gradient checkpoint is enabled (#6744)
fix #6742

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-02-05 08:04:01 +05:30
Dhruv Nair
0953fed52b update 2024-01-04 08:54:00 +00:00
Dhruv Nair
bd375a8034 update 2024-01-04 06:57:36 +00:00
Dhruv Nair
17105d973c update 2024-01-03 16:43:09 +00:00
Dhruv Nair
32e04da6cf update 2024-01-03 16:08:10 +00:00
Dhruv Nair
c1e812b8fd Merge branch 'main' into unet-refactor 2024-01-03 12:09:10 +00:00
Dhruv Nair
784f4e9646 update 2023-12-19 11:57:01 +00:00
6 changed files with 1459 additions and 48 deletions

View File

@@ -1031,16 +1031,10 @@ class DownBlockMotion(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet), hidden_states, temb, scale
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states.requires_grad_(),
-                    temb,
-                    num_frames,
-                )
             else:
                 hidden_states = resnet(hidden_states, temb, scale=scale)
-                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]

             output_states = output_states + (hidden_states,)
@@ -1221,10 +1215,10 @@ class CrossAttnDownBlockMotion(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )[0]
+            hidden_states = motion_module(
+                hidden_states,
+                num_frames=num_frames,
+            )[0]

             # apply additional residuals to the output of the last pair of resnet and attention blocks
             if i == len(blocks) - 1 and additional_residuals is not None:
@@ -1425,10 +1419,10 @@ class CrossAttnUpBlockMotion(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = motion_module(
-                    hidden_states,
-                    num_frames=num_frames,
-                )[0]
+            hidden_states = motion_module(
+                hidden_states,
+                num_frames=num_frames,
+            )[0]

             if self.upsamplers is not None:
                 for upsampler in self.upsamplers:
@@ -1563,15 +1557,10 @@ class UpBlockMotion(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet), hidden_states, temb
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                )
             else:
                 hidden_states = resnet(hidden_states, temb, scale=scale)
-                hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]

             if self.upsamplers is not None:
                 for upsampler in self.upsamplers:
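Aside: the `create_custom_forward` wrapper that appears throughout these hunks is the standard diffusers idiom for checkpointing modules whose forward takes extra arguments. The deleted old code also called `.requires_grad_()` on the input, a workaround for reentrant checkpointing, which errors when no input requires grad. A minimal sketch of the pattern, simplified from the surrounding file (`resnet` here is a stand-in `nn.Linear`, not the real block):

import torch
import torch.utils.checkpoint


def create_custom_forward(module, return_dict=None):
    # torch.utils.checkpoint re-invokes this callable during the backward
    # pass, recomputing activations instead of storing them.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


# Hypothetical usage mirroring the hunks above:
resnet = torch.nn.Linear(8, 8)  # stand-in for a ResnetBlock2D
hidden_states = torch.randn(2, 8, requires_grad=True)
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet), hidden_states, use_reentrant=False
)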

File diff suppressed because it is too large.

View File

@@ -0,0 +1,205 @@
from typing import Dict, Optional, Tuple, Union
import torch
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
class UNet2DConditionModelUtilsMixin:
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
        layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
reverse_transformer_layers_per_block: bool,
num_attention_heads: Optional[Union[int, Tuple[int]]],
):
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
for layer_number_per_block in transformer_layers_per_block:
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
@property
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

View File

@@ -24,7 +24,7 @@ import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter
@@ -209,7 +209,9 @@ class PIAPipelineOutput(BaseOutput):
     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]

-class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class PIAPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-video generation.
@@ -685,6 +687,35 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents( def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
@@ -1107,12 +1138,9 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_videos_per_prompt
+            )

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
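What these two changes buy end users: `FromSingleFileMixin` adds `PIAPipeline.from_single_file(...)` for monolithic `.safetensors` checkpoints, and the new `prepare_ip_adapter_image_embeds` routes IP-Adapter images through the per-adapter projection layers. A sketch of the resulting workflow; the checkpoint names and URL follow the PIA documentation and are illustrative, not taken from this diff:

import torch
from diffusers import MotionAdapter, PIAPipeline
from diffusers.utils import export_to_gif, load_image

# PIA pairs a motion adapter with an SD1.5-style base model.
adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter", torch_dtype=torch.float16)
pipe = PIAPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Enabled by IPAdapterMixin; prepare_ip_adapter_image_embeds handles the CFG batching.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
)
output = pipe(
    prompt="a cat in a field, masterpiece, best quality",
    image=image,
    ip_adapter_image=image,
    num_inference_steps=25,
)
export_to_gif(output.frames[0], "pia_cat.gif")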

View File

@@ -522,7 +522,7 @@ def load_hf_numpy(path) -> np.ndarray:
base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main" base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
if not path.startswith("http://") and not path.startswith("https://"): if not path.startswith("http://") and not path.startswith("https://"):
path = Path(base_url, urllib.parse.quote(path)).as_posix() path = os.path.join(base_url, urllib.parse.quote(path))
return load_numpy(path) return load_numpy(path)
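Why the fix matters: pathlib normalizes consecutive slashes, so joining a URL through Path mangles the scheme. A quick illustration, independent of diffusers (outputs shown for a POSIX system):

import os.path
from pathlib import Path

base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"

# pathlib collapses the "//" after the scheme, producing an invalid URL:
print(Path(base_url, "image.npy").as_posix())
# https:/huggingface.co/datasets/fusing/diffusers-testing/resolve/main/image.npy

# os.path.join does plain string joining and leaves the scheme intact:
print(os.path.join(base_url, "image.npy"))
# https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main/image.npy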

View File

@@ -35,6 +35,7 @@ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    numpy_cosine_similarity_distance,
     require_torch_gpu,
     slow,
     torch_device,
@@ -119,7 +120,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         expected_slice = np.array([0.80810547, 0.88183594, 0.9296875, 0.9189453, 0.9848633, 1.0, 0.97021484, 1.0, 1.0])

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

         pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
@@ -131,7 +133,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
             [0.30444336, 0.26513672, 0.22436523, 0.2758789, 0.25585938, 0.20751953, 0.25390625, 0.24633789, 0.21923828]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

     def test_image_to_image(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -149,7 +152,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
             [0.22167969, 0.21875, 0.21728516, 0.22607422, 0.21948242, 0.23925781, 0.22387695, 0.25268555, 0.2722168]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

         pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
@@ -161,7 +165,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
             [0.35913086, 0.265625, 0.26367188, 0.24658203, 0.19750977, 0.39990234, 0.15258789, 0.20336914, 0.5517578]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

     def test_inpainting(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -179,7 +184,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
             [0.27148438, 0.24047852, 0.22167969, 0.23217773, 0.21118164, 0.21142578, 0.21875, 0.20751953, 0.20019531]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

         pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
@@ -187,11 +193,8 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         images = pipeline(**inputs).images
         image_slice = images[0, :3, :3, -1].flatten()

-        expected_slice = np.array(
-            [0.27294922, 0.24023438, 0.21948242, 0.23242188, 0.20825195, 0.2055664, 0.21679688, 0.20336914, 0.19360352]
-        )
-
-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

     def test_text_to_image_model_cpu_offload(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -233,11 +236,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         images = pipeline(**inputs).images
         image_slice = images[0, :3, :3, -1].flatten()

-        expected_slice = np.array(
-            [0.18115234, 0.13500977, 0.13427734, 0.24194336, 0.17138672, 0.16625977, 0.4260254, 0.43359375, 0.4416504]
-        )
+        expected_slice = np.array([0.1958, 0.1475, 0.1396, 0.2412, 0.1658, 0.1533, 0.3997, 0.4055, 0.4128])

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

     def test_unload(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -277,7 +279,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         expected_slice = np.array(
             [0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
         )
-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

 @slow
@@ -314,7 +318,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
             ]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
@@ -339,7 +344,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
             [0.0576596, 0.05600825, 0.04479006, 0.05288461, 0.05461192, 0.05137569, 0.04867965, 0.05301541, 0.04939842]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

     def test_image_to_image_sdxl(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
@@ -432,7 +438,8 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
             [0.14181179, 0.1493012, 0.14283323, 0.14602411, 0.14915377, 0.15015268, 0.14725655, 0.15009224, 0.15164584]
         )

-        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4

         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
         feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
@@ -457,4 +464,5 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
         expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494])

-        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4