mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-01 00:15:00 +08:00
Compare commits
20 Commits
modular-lo
...
feat/ip_ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa3edf9317 | ||
|
|
bed9eee4b5 | ||
|
|
2e83d6c1b1 | ||
|
|
351180f7d0 | ||
|
|
cacee6dd3a | ||
|
|
3d696884fb | ||
|
|
95e38ac87a | ||
|
|
6031383d80 | ||
|
|
f051c9ebaf | ||
|
|
f10eb255db | ||
|
|
651302b352 | ||
|
|
dded7c41ba | ||
|
|
8fe3064c5f | ||
|
|
023c2b7c20 | ||
|
|
a45292b4c9 | ||
|
|
f9aaa54aa6 | ||
|
|
5887af07e7 | ||
|
|
f3755d4905 | ||
|
|
c4646f876a | ||
|
|
08a182835f |
@@ -252,6 +252,7 @@ else:
|
|||||||
"StableDiffusionInpaintPipeline",
|
"StableDiffusionInpaintPipeline",
|
||||||
"StableDiffusionInpaintPipelineLegacy",
|
"StableDiffusionInpaintPipelineLegacy",
|
||||||
"StableDiffusionInstructPix2PixPipeline",
|
"StableDiffusionInstructPix2PixPipeline",
|
||||||
|
"StableDiffusionIPAdapterPipeline",
|
||||||
"StableDiffusionLatentUpscalePipeline",
|
"StableDiffusionLatentUpscalePipeline",
|
||||||
"StableDiffusionLDM3DPipeline",
|
"StableDiffusionLDM3DPipeline",
|
||||||
"StableDiffusionModelEditingPipeline",
|
"StableDiffusionModelEditingPipeline",
|
||||||
@@ -596,6 +597,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
StableDiffusionInpaintPipeline,
|
StableDiffusionInpaintPipeline,
|
||||||
StableDiffusionInpaintPipelineLegacy,
|
StableDiffusionInpaintPipelineLegacy,
|
||||||
StableDiffusionInstructPix2PixPipeline,
|
StableDiffusionInstructPix2PixPipeline,
|
||||||
|
StableDiffusionIPAdapterPipeline,
|
||||||
StableDiffusionLatentUpscalePipeline,
|
StableDiffusionLatentUpscalePipeline,
|
||||||
StableDiffusionLDM3DPipeline,
|
StableDiffusionLDM3DPipeline,
|
||||||
StableDiffusionModelEditingPipeline,
|
StableDiffusionModelEditingPipeline,
|
||||||
|
|||||||
@@ -1969,6 +1969,412 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
|||||||
return attn.processor(attn, hidden_states, *args, **kwargs)
|
return attn.processor(attn, hidden_states, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterAttnProcessor(nn.Module):
|
||||||
|
r"""
|
||||||
|
Attention processor for IP-Adapater.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_size (`int`):
|
||||||
|
The hidden size of the attention layer.
|
||||||
|
cross_attention_dim (`int`):
|
||||||
|
The number of channels in the `encoder_hidden_states`.
|
||||||
|
text_context_len (`int`, defaults to 77):
|
||||||
|
The context length of the text features.
|
||||||
|
scale (`float`, defaults to 1.0):
|
||||||
|
the weight scale of image prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.text_context_len = text_context_len
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
|
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn,
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
attention_mask=None,
|
||||||
|
temb=None,
|
||||||
|
scale=1.0,
|
||||||
|
):
|
||||||
|
if scale != 1.0:
|
||||||
|
logger.warning("`scale` of IPAttnProcessor should be set with `IPAdapterPipeline.set_scale`.")
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
# split hidden states
|
||||||
|
encoder_hidden_states, ip_hidden_states = (
|
||||||
|
encoder_hidden_states[:, : self.text_context_len, :],
|
||||||
|
encoder_hidden_states[:, self.text_context_len :, :],
|
||||||
|
)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
query = attn.head_to_batch_dim(query)
|
||||||
|
key = attn.head_to_batch_dim(key)
|
||||||
|
value = attn.head_to_batch_dim(value)
|
||||||
|
|
||||||
|
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||||
|
hidden_states = torch.bmm(attention_probs, value)
|
||||||
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
|
# for ip-adapter
|
||||||
|
ip_key = self.to_k_ip(ip_hidden_states)
|
||||||
|
ip_value = self.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
|
ip_key = attn.head_to_batch_dim(ip_key)
|
||||||
|
ip_value = attn.head_to_batch_dim(ip_value)
|
||||||
|
|
||||||
|
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||||
|
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||||
|
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||||
|
r"""
|
||||||
|
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hidden_size (`int`):
|
||||||
|
The hidden size of the attention layer.
|
||||||
|
cross_attention_dim (`int`):
|
||||||
|
The number of channels in the `encoder_hidden_states`.
|
||||||
|
text_context_len (`int`, defaults to 77):
|
||||||
|
The context length of the text features.
|
||||||
|
scale (`float`, defaults to 1.0):
|
||||||
|
the weight scale of image prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(
|
||||||
|
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.text_context_len = text_context_len
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
|
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn,
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
attention_mask=None,
|
||||||
|
temb=None,
|
||||||
|
scale=1.0,
|
||||||
|
):
|
||||||
|
if scale != 1.0:
|
||||||
|
logger.warning("`scale` of IPAttnProcessor should be set by " "`IPAdapterPipeline.set_scale`")
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
# scaled_dot_product_attention expects attention_mask shape to be
|
||||||
|
# (batch, heads, source_length, target_length)
|
||||||
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
# split hidden states
|
||||||
|
encoder_hidden_states, ip_hidden_states = (
|
||||||
|
encoder_hidden_states[:, : self.text_context_len, :],
|
||||||
|
encoder_hidden_states[:, self.text_context_len :, :],
|
||||||
|
)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# for ip-adapter
|
||||||
|
ip_key = self.to_k_ip(ip_hidden_states)
|
||||||
|
ip_value = self.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
|
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
ip_hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterControlNetAttnProcessor:
|
||||||
|
r"""
|
||||||
|
Default processor for performing attention-related computations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, text_context_len=77):
|
||||||
|
self.text_context_len = text_context_len
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn,
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
attention_mask=None,
|
||||||
|
temb=None,
|
||||||
|
scale=1.0,
|
||||||
|
):
|
||||||
|
if scale != 1.0:
|
||||||
|
logger.warning("`scale` of IPAttnProcessor should be set by " "`IPAdapterPipeline.set_scale`")
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = encoder_hidden_states[:, : self.text_context_len] # only use text
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
query = attn.head_to_batch_dim(query)
|
||||||
|
key = attn.head_to_batch_dim(key)
|
||||||
|
value = attn.head_to_batch_dim(value)
|
||||||
|
|
||||||
|
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||||
|
hidden_states = torch.bmm(attention_probs, value)
|
||||||
|
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterControlNetAttnProcessor2_0:
|
||||||
|
r"""
|
||||||
|
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, text_context_len=77):
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(
|
||||||
|
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||||
|
)
|
||||||
|
self.text_context_len = text_context_len
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn,
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
attention_mask=None,
|
||||||
|
temb=None,
|
||||||
|
scale=1.0,
|
||||||
|
):
|
||||||
|
if scale != 1.0:
|
||||||
|
logger.warning("`scale` of IPAttnProcessor should be set by " "`IPAdapterPipeline.set_scale`")
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||||
|
# scaled_dot_product_attention expects attention_mask shape to be
|
||||||
|
# (batch, heads, source_length, target_length)
|
||||||
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
elif attn.norm_cross:
|
||||||
|
encoder_hidden_states = encoder_hidden_states[:, : self.text_context_len]
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
LORA_ATTENTION_PROCESSORS = (
|
LORA_ATTENTION_PROCESSORS = (
|
||||||
LoRAAttnProcessor,
|
LoRAAttnProcessor,
|
||||||
LoRAAttnProcessor2_0,
|
LoRAAttnProcessor2_0,
|
||||||
@@ -1992,6 +2398,8 @@ CROSS_ATTENTION_PROCESSORS = (
|
|||||||
LoRAAttnProcessor,
|
LoRAAttnProcessor,
|
||||||
LoRAAttnProcessor2_0,
|
LoRAAttnProcessor2_0,
|
||||||
LoRAXFormersAttnProcessor,
|
LoRAXFormersAttnProcessor,
|
||||||
|
IPAdapterAttnProcessor,
|
||||||
|
IPAdapterAttnProcessor2_0,
|
||||||
)
|
)
|
||||||
|
|
||||||
AttentionProcessor = Union[
|
AttentionProcessor = Union[
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ else:
|
|||||||
"IFPipeline",
|
"IFPipeline",
|
||||||
"IFSuperResolutionPipeline",
|
"IFSuperResolutionPipeline",
|
||||||
]
|
]
|
||||||
|
_import_structure["ip_adapter"] = ["StableDiffusionIPAdapterPipeline"]
|
||||||
_import_structure["kandinsky"] = [
|
_import_structure["kandinsky"] = [
|
||||||
"KandinskyCombinedPipeline",
|
"KandinskyCombinedPipeline",
|
||||||
"KandinskyImg2ImgCombinedPipeline",
|
"KandinskyImg2ImgCombinedPipeline",
|
||||||
@@ -316,6 +317,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
IFPipeline,
|
IFPipeline,
|
||||||
IFSuperResolutionPipeline,
|
IFSuperResolutionPipeline,
|
||||||
)
|
)
|
||||||
|
from .ip_adapter import StableDiffusionIPAdapterPipeline
|
||||||
from .kandinsky import (
|
from .kandinsky import (
|
||||||
KandinskyCombinedPipeline,
|
KandinskyCombinedPipeline,
|
||||||
KandinskyImg2ImgCombinedPipeline,
|
KandinskyImg2ImgCombinedPipeline,
|
||||||
|
|||||||
2
src/diffusers/pipelines/ip_adapter/__init__.py
Normal file
2
src/diffusers/pipelines/ip_adapter/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from .image_projection import ImageProjectionModel
|
||||||
|
from .pipeline_ip_adapter import StableDiffusionIPAdapterPipeline
|
||||||
39
src/diffusers/pipelines/ip_adapter/image_projection.py
Normal file
39
src/diffusers/pipelines/ip_adapter/image_projection.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from ...models.modeling_utils import ModelMixin
|
||||||
|
|
||||||
|
|
||||||
|
class ImageProjectionModel(ModelMixin, ConfigMixin):
|
||||||
|
"""Image Projection Model."""
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(self, cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||||
|
self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||||
|
self.norm = nn.LayerNorm(cross_attention_dim)
|
||||||
|
|
||||||
|
def forward(self, image_embeds):
|
||||||
|
embeds = image_embeds
|
||||||
|
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||||
|
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
||||||
|
)
|
||||||
|
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
||||||
|
return clip_extra_context_tokens
|
||||||
625
src/diffusers/pipelines/ip_adapter/pipeline_ip_adapter.py
Normal file
625
src/diffusers/pipelines/ip_adapter/pipeline_ip_adapter.py
Normal file
@@ -0,0 +1,625 @@
|
|||||||
|
# Copyright 2023 IP Adapter Authors and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
|
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from ...models.attention_processor import (
|
||||||
|
AttnProcessor,
|
||||||
|
AttnProcessor2_0,
|
||||||
|
IPAdapterAttnProcessor,
|
||||||
|
IPAdapterAttnProcessor2_0,
|
||||||
|
)
|
||||||
|
from ...models.lora import adjust_lora_scale_text_encoder
|
||||||
|
from ...schedulers import KarrasDiffusionSchedulers
|
||||||
|
from ...utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, logging
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ..pipeline_utils import DiffusionPipeline
|
||||||
|
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
from .image_projection import ImageProjectionModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||||
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||||
|
"""
|
||||||
|
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||||
|
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||||
|
"""
|
||||||
|
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||||
|
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||||
|
# rescale the results from guidance (fixes overexposure)
|
||||||
|
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||||
|
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||||
|
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||||
|
return noise_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionIPAdapterPipeline(DiffusionPipeline):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
tokenizer: CLIPTokenizer,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
ip_adapter_image_processor: CLIPImageProcessor,
|
||||||
|
image_encoder: CLIPVisionModelWithProjection,
|
||||||
|
scheduler: KarrasDiffusionSchedulers,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
vae=vae,
|
||||||
|
unet=unet,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
image_encoder=image_encoder,
|
||||||
|
ip_adapter_image_processor=ip_adapter_image_processor,
|
||||||
|
scheduler=scheduler,
|
||||||
|
)
|
||||||
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||||
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||||
|
|
||||||
|
def _set_ip_adapter(self):
|
||||||
|
unet = self.unet
|
||||||
|
attn_procs = {}
|
||||||
|
for name in unet.attn_processors.keys():
|
||||||
|
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||||
|
if name.startswith("mid_block"):
|
||||||
|
hidden_size = unet.config.block_out_channels[-1]
|
||||||
|
elif name.startswith("up_blocks"):
|
||||||
|
block_id = int(name[len("up_blocks.")])
|
||||||
|
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||||
|
elif name.startswith("down_blocks"):
|
||||||
|
block_id = int(name[len("down_blocks.")])
|
||||||
|
hidden_size = unet.config.block_out_channels[block_id]
|
||||||
|
if cross_attention_dim is None:
|
||||||
|
attn_processor_class = (
|
||||||
|
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
|
||||||
|
)
|
||||||
|
attn_procs[name] = attn_processor_class()
|
||||||
|
else:
|
||||||
|
attn_processor_class = (
|
||||||
|
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||||
|
)
|
||||||
|
attn_procs[name] = attn_processor_class(
|
||||||
|
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||||
|
).to(dtype=unet.dtype, device=unet.device)
|
||||||
|
|
||||||
|
unet.set_attn_processor(attn_procs)
|
||||||
|
|
||||||
|
# TODO: create a separate pipeline for this: `StableDiffusionControlNetIPAdapterPipeline`.
|
||||||
|
# if hasattr(self.pipeline, "controlnet"):
|
||||||
|
# attn_processor_class = (
|
||||||
|
# CNAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CNAttnProcessor
|
||||||
|
# )
|
||||||
|
# self.pipeline.controlnet.set_attn_processor(attn_processor_class())
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||||
|
r"""
|
||||||
|
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||||
|
computing decoding in one step.
|
||||||
|
"""
|
||||||
|
self.vae.disable_slicing()
|
||||||
|
|
||||||
|
def disable_vae_slicing(self):
|
||||||
|
r"""
|
||||||
|
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||||
|
computing decoding in one step.
|
||||||
|
"""
|
||||||
|
self.vae.disable_slicing()
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||||
|
def enable_vae_tiling(self):
|
||||||
|
r"""
|
||||||
|
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||||
|
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||||
|
processing larger images.
|
||||||
|
"""
|
||||||
|
self.vae.enable_tiling()
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||||
|
def disable_vae_tiling(self):
|
||||||
|
r"""
|
||||||
|
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||||
|
computing decoding in one step.
|
||||||
|
"""
|
||||||
|
self.vae.disable_tiling()
|
||||||
|
|
||||||
|
def load_ip_adapter(
|
||||||
|
self,
|
||||||
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||||
|
the Hub.
|
||||||
|
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||||
|
with [`ModelMixin.save_pretrained`].
|
||||||
|
- A [torch state
|
||||||
|
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||||
|
|
||||||
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||||
|
is not used.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
resume_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||||
|
incompletely downloaded files are deleted.
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||||
|
won't be downloaded from the Hub.
|
||||||
|
use_auth_token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||||
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||||
|
allowed by Git.
|
||||||
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||||
|
"""
|
||||||
|
self._set_ip_adapter()
|
||||||
|
|
||||||
|
# Load the main state dict first/
|
||||||
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||||
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
|
weight_name = kwargs.pop("weight_name", None)
|
||||||
|
|
||||||
|
user_agent = {
|
||||||
|
"file_type": "attn_procs_weights",
|
||||||
|
"framework": "pytorch",
|
||||||
|
}
|
||||||
|
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
model_file = _get_model_file(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weights_name=weight_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
state_dict = torch.load(model_file, map_location="cpu")
|
||||||
|
else:
|
||||||
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
if keys != ["image_proj", "ip_adapter"]:
|
||||||
|
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing.")
|
||||||
|
|
||||||
|
# Handle image projection layers.
|
||||||
|
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
|
||||||
|
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
|
||||||
|
image_projection = ImageProjectionModel(
|
||||||
|
cross_attention_dim=cross_attention_dim, clip_embeddings_dim=clip_embeddings_dim
|
||||||
|
)
|
||||||
|
image_projection.to(dtype=self.unet.dtype, device=self.unet.device)
|
||||||
|
image_projection.load_state_dict(state_dict["image_proj"])
|
||||||
|
self.image_projection = image_projection
|
||||||
|
|
||||||
|
# Handle IP-Adapter cross-attention layers.
|
||||||
|
ip_layers = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
module if isinstance(module, nn.Module) else nn.Identity()
|
||||||
|
for module in self.unet.attn_processors.values()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||||
|
|
||||||
|
def set_scale(self, scale):
|
||||||
|
for attn_processor in self.unet.attn_processors.values():
|
||||||
|
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||||
|
attn_processor.scale = scale
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||||
|
def prepare_extra_step_kwargs(self, generator, eta):
|
||||||
|
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||||
|
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||||
|
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||||
|
# and should be between [0, 1]
|
||||||
|
|
||||||
|
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||||
|
extra_step_kwargs = {}
|
||||||
|
if accepts_eta:
|
||||||
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
|
# check if the scheduler accepts generator
|
||||||
|
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||||
|
if accepts_generator:
|
||||||
|
extra_step_kwargs["generator"] = generator
|
||||||
|
return extra_step_kwargs
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
|
||||||
|
def encode_prompt(
|
||||||
|
self,
|
||||||
|
prompt,
|
||||||
|
device,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt=None,
|
||||||
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
lora_scale: Optional[float] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Encodes the prompt into text encoder hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
prompt to be encoded
|
||||||
|
device: (`torch.device`):
|
||||||
|
torch device
|
||||||
|
num_images_per_prompt (`int`):
|
||||||
|
number of images that should be generated per prompt
|
||||||
|
do_classifier_free_guidance (`bool`):
|
||||||
|
whether to use classifier free guidance or not
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||||
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
|
argument.
|
||||||
|
lora_scale (`float`, *optional*):
|
||||||
|
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||||
|
"""
|
||||||
|
# set lora scale so that monkey patched LoRA
|
||||||
|
# function of text encoder can correctly access it
|
||||||
|
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||||
|
self._lora_scale = lora_scale
|
||||||
|
|
||||||
|
# dynamically adjust the LoRA scale
|
||||||
|
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||||
|
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
|
if prompt_embeds is None:
|
||||||
|
# textual inversion: procecss multi-vector tokens if necessary
|
||||||
|
if isinstance(self, TextualInversionLoaderMixin):
|
||||||
|
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||||
|
|
||||||
|
text_inputs = self.tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=self.tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
text_input_ids = text_inputs.input_ids
|
||||||
|
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||||
|
text_input_ids, untruncated_ids
|
||||||
|
):
|
||||||
|
removed_text = self.tokenizer.batch_decode(
|
||||||
|
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||||
|
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||||
|
attention_mask = text_inputs.attention_mask.to(device)
|
||||||
|
else:
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
|
prompt_embeds = self.text_encoder(
|
||||||
|
text_input_ids.to(device),
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
prompt_embeds = prompt_embeds[0]
|
||||||
|
|
||||||
|
if self.text_encoder is not None:
|
||||||
|
prompt_embeds_dtype = self.text_encoder.dtype
|
||||||
|
elif self.unet is not None:
|
||||||
|
prompt_embeds_dtype = self.unet.dtype
|
||||||
|
else:
|
||||||
|
prompt_embeds_dtype = prompt_embeds.dtype
|
||||||
|
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||||
|
|
||||||
|
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||||
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
# get unconditional embeddings for classifier free guidance
|
||||||
|
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||||
|
uncond_tokens: List[str]
|
||||||
|
if negative_prompt is None:
|
||||||
|
uncond_tokens = [""] * batch_size
|
||||||
|
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||||
|
raise TypeError(
|
||||||
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
|
f" {type(prompt)}."
|
||||||
|
)
|
||||||
|
elif isinstance(negative_prompt, str):
|
||||||
|
uncond_tokens = [negative_prompt]
|
||||||
|
elif batch_size != len(negative_prompt):
|
||||||
|
raise ValueError(
|
||||||
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||||
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||||
|
" the batch size of `prompt`."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
uncond_tokens = negative_prompt
|
||||||
|
|
||||||
|
# textual inversion: procecss multi-vector tokens if necessary
|
||||||
|
if isinstance(self, TextualInversionLoaderMixin):
|
||||||
|
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||||
|
|
||||||
|
max_length = prompt_embeds.shape[1]
|
||||||
|
uncond_input = self.tokenizer(
|
||||||
|
uncond_tokens,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||||
|
attention_mask = uncond_input.attention_mask.to(device)
|
||||||
|
else:
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
|
negative_prompt_embeds = self.text_encoder(
|
||||||
|
uncond_input.input_ids.to(device),
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||||
|
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||||
|
seq_len = negative_prompt_embeds.shape[1]
|
||||||
|
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||||
|
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
def encode_image(self, image, device, num_images_per_prompt):
|
||||||
|
dtype = next(self.image_encoder.parameters()).dtype
|
||||||
|
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
image = self.ip_adapter_image_processor(image, return_tensors="pt").pixel_values
|
||||||
|
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
(image_embeddings,) = self.image_encoder(image).image_embeds
|
||||||
|
image_prompt_embeds = self.image_projection(image_embeddings)
|
||||||
|
uncond_image_prompt_embeds = self.image_projection(torch.zeros_like(image_embeddings))
|
||||||
|
|
||||||
|
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||||
|
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||||
|
image_prompt_embeds = image_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
|
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
|
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||||
|
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
latents = latents.to(device)
|
||||||
|
|
||||||
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
|
latents = latents * self.scheduler.init_noise_sigma
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, List[str]] = None,
|
||||||
|
image: PipelineImageInput = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
num_inference_steps: int = 50,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
guidance_rescale: float = 0.0,
|
||||||
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
ip_adapter_scale: float = 1.0,
|
||||||
|
eta: float = 0.0,
|
||||||
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
|
latents: Optional[torch.FloatTensor] = None,
|
||||||
|
output_type: Optional[str] = "pil",
|
||||||
|
return_dict: bool = True,
|
||||||
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||||
|
callback_steps: int = 1,
|
||||||
|
):
|
||||||
|
# 0. Set IP Adapter scale
|
||||||
|
self.set_scale(ip_adapter_scale)
|
||||||
|
|
||||||
|
# 1. Check inputs and raise error if needed.
|
||||||
|
if hasattr(self, "image_projection") and getattr(self, "image_projection") is None:
|
||||||
|
raise (
|
||||||
|
"This pipeline cannot be called without having an `image_projection` module. Did you call `load_ip_adapter()` before running the pipeline?"
|
||||||
|
)
|
||||||
|
# TODO
|
||||||
|
|
||||||
|
# 1. Define call parameters
|
||||||
|
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
device = self._execution_device
|
||||||
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
|
# corresponds to doing no classifier free guidance.
|
||||||
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
|
# 2. Encode input image
|
||||||
|
image_embeddings, uncond_image_embeddings = self.encode_image(image, device, num_images_per_prompt)
|
||||||
|
|
||||||
|
# 3. Encode prompt
|
||||||
|
text_encoder_lora_scale = (
|
||||||
|
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||||
|
)
|
||||||
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||||
|
prompt,
|
||||||
|
device,
|
||||||
|
num_images_per_prompt,
|
||||||
|
do_classifier_free_guidance,
|
||||||
|
negative_prompt,
|
||||||
|
prompt_embeds=prompt_embeds,
|
||||||
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
lora_scale=text_encoder_lora_scale,
|
||||||
|
)
|
||||||
|
prompt_embeds = torch.cat([prompt_embeds, image_embeddings], dim=1)
|
||||||
|
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_embeddings], dim=1)
|
||||||
|
|
||||||
|
# For classifier free guidance, we need to do two forward passes.
|
||||||
|
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||||
|
# to avoid doing two forward passes
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||||
|
|
||||||
|
# 4. Prepare timesteps
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
timesteps = self.scheduler.timesteps
|
||||||
|
|
||||||
|
# 5. Prepare latent variables
|
||||||
|
num_channels_latents = self.unet.config.in_channels
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
batch_size * num_images_per_prompt,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
prompt_embeds.dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||||
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||||
|
|
||||||
|
# 7. Denoising loop
|
||||||
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(timesteps):
|
||||||
|
# expand the latents if we are doing classifier free guidance
|
||||||
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
noise_pred = self.unet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=prompt_embeds,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
|
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||||
|
|
||||||
|
# call the callback, if provided
|
||||||
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||||
|
progress_bar.update()
|
||||||
|
if callback is not None and i % callback_steps == 0:
|
||||||
|
callback(i, t, latents)
|
||||||
|
|
||||||
|
if not output_type == "latent":
|
||||||
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||||
|
has_nsfw_concept = None
|
||||||
|
# image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||||
|
else:
|
||||||
|
image = latents
|
||||||
|
has_nsfw_concept = None
|
||||||
|
|
||||||
|
if has_nsfw_concept is None:
|
||||||
|
do_denormalize = [True] * image.shape[0]
|
||||||
|
else:
|
||||||
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||||
|
|
||||||
|
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||||
|
|
||||||
|
# # Offload last model to CPU
|
||||||
|
# TODO
|
||||||
|
# if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||||
|
# self.final_offload_hook.offload()
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (image, has_nsfw_concept)
|
||||||
|
|
||||||
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||||
@@ -17,10 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...image_processor import VaeImageProcessor
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
from ...models.lora import adjust_lora_scale_text_encoder
|
from ...models.lora import adjust_lora_scale_text_encoder
|
||||||
@@ -103,7 +103,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||||
_optional_components = ["safety_checker", "feature_extractor"]
|
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||||
_exclude_from_cpu_offload = ["safety_checker"]
|
_exclude_from_cpu_offload = ["safety_checker"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -111,6 +111,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
vae: AutoencoderKL,
|
vae: AutoencoderKL,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
|
image_encoder: CLIPVisionModelWithProjection,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
scheduler: KarrasDiffusionSchedulers,
|
scheduler: KarrasDiffusionSchedulers,
|
||||||
safety_checker: StableDiffusionSafetyChecker,
|
safety_checker: StableDiffusionSafetyChecker,
|
||||||
@@ -191,6 +192,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
safety_checker=safety_checker,
|
safety_checker=safety_checker,
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
|
image_encoder=image_encoder,
|
||||||
)
|
)
|
||||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||||
@@ -438,6 +440,19 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
|
|
||||||
return prompt_embeds, negative_prompt_embeds
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
def encode_image(self, image, device, num_images_per_prompt):
|
||||||
|
dtype = next(self.image_encoder.parameters()).dtype
|
||||||
|
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||||
|
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
image_embeds = self.image_encoder(image).image_embeds
|
||||||
|
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||||
|
|
||||||
|
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||||
|
return image_embeds, uncond_image_embeds
|
||||||
|
|
||||||
def run_safety_checker(self, image, device, dtype):
|
def run_safety_checker(self, image, device, dtype):
|
||||||
if self.safety_checker is None:
|
if self.safety_checker is None:
|
||||||
has_nsfw_concept = None
|
has_nsfw_concept = None
|
||||||
@@ -575,6 +590,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]] = None,
|
prompt: Union[str, List[str]] = None,
|
||||||
|
image_prompt: PipelineImageInput = None,
|
||||||
height: Optional[int] = None,
|
height: Optional[int] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
@@ -706,6 +722,11 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||||
|
|
||||||
|
if image_prompt is not None:
|
||||||
|
image_embeds, negative_image_embeds = self.image_encoder(image_prompt, device, num_images_per_prompt)
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||||
|
|
||||||
# 4. Prepare timesteps
|
# 4. Prepare timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
timesteps = self.scheduler.timesteps
|
timesteps = self.scheduler.timesteps
|
||||||
@@ -733,13 +754,15 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
|||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
if image_prompt is not None:
|
||||||
|
added_cond_kwargs = {"image_embeds": image_embeds}
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
t,
|
t,
|
||||||
encoder_hidden_states=prompt_embeds,
|
encoder_hidden_states=prompt_embeds,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user