mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
8 Commits
custom-blo
...
ip-11
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2d153176a | ||
|
|
7741fb0059 | ||
|
|
37444bcdba | ||
|
|
89f548ca33 | ||
|
|
cd8702e27b | ||
|
|
4475c0bccb | ||
|
|
a7af2b26b3 | ||
|
|
ac1c26d2ea |
@@ -33,16 +33,14 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -284,7 +282,9 @@ class IPAdapterMixin:
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
@@ -342,7 +342,9 @@ class IPAdapterMixin:
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
@@ -765,6 +765,7 @@ class UNet2DConditionLoadersMixin:
|
||||
from ..models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
@@ -804,11 +805,15 @@ class UNet2DConditionLoadersMixin:
|
||||
if cross_attention_dim is None or "motion_modules" in name:
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
)
|
||||
if "XFormers" in str(self.attn_processors[name].__class__):
|
||||
attn_processor_class = IPAdapterXFormersAttnProcessor
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else IPAdapterAttnProcessor
|
||||
)
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
|
||||
@@ -318,7 +318,10 @@ class Attention(nn.Module):
|
||||
XFormersAttnAddedKVProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
is_ip_adapter = hasattr(self, "processor") and isinstance(
|
||||
self.processor,
|
||||
(IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
|
||||
)
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if is_added_kv_processor and is_custom_diffusion:
|
||||
raise NotImplementedError(
|
||||
@@ -368,6 +371,19 @@ class Attention(nn.Module):
|
||||
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
||||
)
|
||||
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
||||
elif is_ip_adapter:
|
||||
processor = IPAdapterXFormersAttnProcessor(
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
scale=self.processor.scale,
|
||||
num_tokens=self.processor.num_tokens,
|
||||
attention_op=attention_op,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_ip"):
|
||||
processor.to(
|
||||
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
|
||||
)
|
||||
else:
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
@@ -4542,6 +4558,238 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapter using xFormers.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
attention_op (`Callable`, *optional*, defaults to `None`):
|
||||
The base
|
||||
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
||||
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
||||
operator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
cross_attention_dim=None,
|
||||
num_tokens=(4,),
|
||||
scale=1.0,
|
||||
attention_op: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_op = attention_op
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
ip_adapter_masks: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# expand our mask's singleton query_tokens dimension:
|
||||
# [batch*heads, 1, key_tokens] ->
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# so that it can be added as a bias onto the attention scores that xformers computes:
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
||||
_, query_tokens, _ = hidden_states.shape
|
||||
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if ip_hidden_states:
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
raise ValueError(
|
||||
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
|
||||
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
|
||||
f"({len(ip_hidden_states)})"
|
||||
)
|
||||
else:
|
||||
for index, (mask, scale, ip_state) in enumerate(
|
||||
zip(ip_adapter_masks, self.scale, ip_hidden_states)
|
||||
):
|
||||
if mask is None:
|
||||
continue
|
||||
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
||||
raise ValueError(
|
||||
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
||||
"[1, num_images_for_ip_adapter, height, width]."
|
||||
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
|
||||
)
|
||||
if mask.shape[1] != ip_state.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of masks ({mask.shape[1]}) does not match "
|
||||
f"number of ip images ({ip_state.shape[1]}) at index {index}"
|
||||
)
|
||||
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of masks ({mask.shape[1]}) does not match "
|
||||
f"number of scales ({len(scale)}) at index {index}"
|
||||
)
|
||||
else:
|
||||
ip_adapter_masks = [None] * len(self.scale)
|
||||
|
||||
# for ip-adapter
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
|
||||
):
|
||||
skip = False
|
||||
if isinstance(scale, list):
|
||||
if all(s == 0 for s in scale):
|
||||
skip = True
|
||||
elif scale == 0:
|
||||
skip = True
|
||||
if not skip:
|
||||
if mask is not None:
|
||||
mask = mask.to(torch.float16)
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * mask.shape[1]
|
||||
|
||||
current_num_images = mask.shape[1]
|
||||
for i in range(current_num_images):
|
||||
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
|
||||
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
||||
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
||||
|
||||
_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, ip_key, ip_value, op=self.attention_op
|
||||
)
|
||||
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
|
||||
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
|
||||
|
||||
mask_downsample = IPAdapterMaskProcessor.downsample(
|
||||
mask[:, i, :, :],
|
||||
batch_size,
|
||||
_current_ip_hidden_states.shape[1],
|
||||
_current_ip_hidden_states.shape[2],
|
||||
)
|
||||
|
||||
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
|
||||
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
|
||||
else:
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
||||
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
||||
|
||||
current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, ip_key, ip_value, op=self.attention_op
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
|
||||
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + scale * current_ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
|
||||
Reference in New Issue
Block a user