mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-07 13:04:15 +08:00

Compare commits: rename-att ... fix-part-t
2 commits

| Author | SHA1 | Date |
|---|---|---|
|  | e8d40b3d5d |  |
|  | d699d686c0 |  |
@@ -20,7 +20,7 @@ from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..utils import BaseOutput, deprecate, logging
from ..utils import BaseOutput, logging
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,

@@ -43,24 +43,6 @@ from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, block_out_channels):
    incorrect_attention_head_dim_name = False
    if "CrossAttnDownBlock2D" in down_block_types or mid_block_type == "UNetMidBlock2DCrossAttn":
        incorrect_attention_head_dim_name = True

    if incorrect_attention_head_dim_name:
        num_attention_heads = attention_head_dim
    else:
        # we use attention_head_dim to calculate num_attention_heads
        if isinstance(attention_head_dim, int):
            num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
        else:
            num_attention_heads = [
                out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
            ]
    return num_attention_heads


@dataclass
class ControlNetOutput(BaseOutput):
    """

@@ -240,22 +222,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
    ):
        super().__init__()

        if attention_head_dim is not None:
            deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads`."
            deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
            num_attention_heads = correct_incorrect_names(
                attention_head_dim, down_block_types, mid_block_type, block_out_channels
            )
            logger.warning(
                f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
                f" the model will be configured with `num_attention_heads` {num_attention_heads}."
            )
            attention_head_dim = None
        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if num_attention_heads is None:
            raise ValueError("`num_attention_heads` cannot be None.")

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."

@@ -270,13 +245,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # we use num_attention_heads to calculate attention_head_dim
        attention_head_dim = [
            out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
        ]

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

@@ -386,6 +354,12 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

@@ -411,7 +385,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                attention_head_dim=attention_head_dim[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],

@@ -448,7 +422,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            attention_head_dim=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
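The ControlNet hunks above all orbit one fact: in cross-attention configs the key `attention_head_dim` has historically stored the number of attention heads, while a genuine per-head width would instead be derived from each block's output channels. A minimal sketch of that arithmetic, using made-up values rather than any real checkpoint config:

```python
# Illustration only: how `attention_head_dim` and `num_attention_heads` relate
# per block. The values below are hypothetical, not taken from a real config.
block_out_channels = (320, 640, 1280, 1280)

# Misnamed config key: for cross-attention blocks this already holds the head
# count, so the constructor can use it as `num_attention_heads` directly.
attention_head_dim = (5, 10, 20, 20)
num_attention_heads = attention_head_dim

# If the key really held a per-head width, the head count would be recovered
# from the block width instead:
head_width = 64
derived_heads = [channels // head_width for channels in block_out_channels]
print(derived_heads)  # [5, 10, 20, 20]
```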
@@ -119,7 +119,6 @@ def get_down_block(
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,

@@ -141,7 +140,6 @@ def get_down_block(
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,

@@ -163,7 +161,6 @@ def get_down_block(
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,

@@ -194,7 +191,6 @@ def get_down_block(
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

@@ -222,7 +218,6 @@ def get_down_block(
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

@@ -248,7 +243,6 @@ def get_down_block(
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            add_self_attention=True if not add_downsample else False,
        )

@@ -341,7 +335,6 @@ def get_up_block(
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,

@@ -365,7 +358,6 @@ def get_up_block(
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,

@@ -390,7 +382,6 @@ def get_up_block(
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            upsample_type=upsample_type,

@@ -421,7 +412,6 @@ def get_up_block(
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

@@ -450,7 +440,6 @@ def get_up_block(
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,

@@ -479,7 +468,6 @@ def get_up_block(
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
        )

@@ -567,7 +555,6 @@ class UNetMidBlock2D(nn.Module):
        attn_groups: Optional[int] = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):

@@ -615,15 +602,13 @@ class UNetMidBlock2D(nn.Module):
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels
        if num_attention_heads is None:
            num_attention_heads = in_channels // attention_head_dim

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=num_attention_heads,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,

@@ -695,7 +680,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        attention_head_dim: Optional[int] = None,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        dual_cross_attention: bool = False,

@@ -709,9 +693,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        if attention_head_dim is None:
            attention_head_dim = in_channels // num_attention_heads

        # support for variable transformer layers per block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

@@ -737,8 +718,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
            if not dual_cross_attention:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,

@@ -751,8 +732,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
            else:
                attentions.append(
                    DualTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,

@@ -843,7 +824,6 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,

@@ -858,9 +838,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
        self.attention_head_dim = attention_head_dim
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        if num_attention_heads is None:
            num_attention_heads = in_channels // attention_head_dim
        self.num_heads = num_attention_heads
        self.num_heads = in_channels // self.attention_head_dim

        # there is always at least one resnet
        resnets = [

@@ -971,7 +949,6 @@ class AttnDownBlock2D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,

@@ -988,9 +965,6 @@ class AttnDownBlock2D(nn.Module):
            )
            attention_head_dim = out_channels

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(

@@ -1010,7 +984,7 @@ class AttnDownBlock2D(nn.Module):
            attentions.append(
                Attention(
                    out_channels,
                    heads=num_attention_heads,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,

@@ -1100,7 +1074,6 @@ class CrossAttnDownBlock2D(nn.Module):
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        attention_head_dim: Optional[int] = None,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,

@@ -1117,9 +1090,6 @@ class CrossAttnDownBlock2D(nn.Module):

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        if attention_head_dim is None:
            attention_head_dim = out_channels // num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

@@ -1142,8 +1112,8 @@ class CrossAttnDownBlock2D(nn.Module):
            if not dual_cross_attention:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,

@@ -1157,8 +1127,8 @@ class CrossAttnDownBlock2D(nn.Module):
            else:
                attentions.append(
                    DualTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,

@@ -1425,7 +1395,6 @@ class AttnDownEncoderBlock2D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,

@@ -1441,9 +1410,6 @@ class AttnDownEncoderBlock2D(nn.Module):
            )
            attention_head_dim = out_channels

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            if resnet_time_scale_shift == "spatial":

@@ -1478,7 +1444,7 @@ class AttnDownEncoderBlock2D(nn.Module):
            attentions.append(
                Attention(
                    out_channels,
                    heads=num_attention_heads,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,

@@ -1529,7 +1495,6 @@ class AttnSkipDownBlock2D(nn.Module):
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = np.sqrt(2.0),
        add_downsample: bool = True,

@@ -1544,9 +1509,6 @@ class AttnSkipDownBlock2D(nn.Module):
            )
            attention_head_dim = out_channels

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.resnets.append(

@@ -1567,7 +1529,7 @@ class AttnSkipDownBlock2D(nn.Module):
            self.attentions.append(
                Attention(
                    out_channels,
                    heads=num_attention_heads,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,

@@ -1827,7 +1789,6 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,

@@ -1844,9 +1805,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
        attentions = []

        self.attention_head_dim = attention_head_dim
        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim
        self.num_heads = num_attention_heads
        self.num_heads = out_channels // self.attention_head_dim

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels

@@ -1874,7 +1833,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                Attention(
                    query_dim=out_channels,
                    cross_attention_dim=out_channels,
                    heads=num_attention_heads,
                    heads=self.num_heads,
                    dim_head=attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,

@@ -2068,7 +2027,6 @@ class KCrossAttnDownBlock2D(nn.Module):
        num_layers: int = 4,
        resnet_group_size: int = 32,
        add_downsample: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 64,
        add_self_attention: bool = False,
        resnet_eps: float = 1e-5,

@@ -2078,9 +2036,6 @@ class KCrossAttnDownBlock2D(nn.Module):
        resnets = []
        attentions = []

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        self.has_cross_attention = True

        for i in range(num_layers):

@@ -2104,9 +2059,9 @@ class KCrossAttnDownBlock2D(nn.Module):
            )
            attentions.append(
                KAttentionBlock(
                    dim=out_channels,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    out_channels,
                    out_channels // attention_head_dim,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    temb_channels=temb_channels,
                    attention_bias=True,

@@ -2203,7 +2158,6 @@ class AttnUpBlock2D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        upsample_type: str = "conv",

@@ -2220,9 +2174,6 @@ class AttnUpBlock2D(nn.Module):
            )
            attention_head_dim = out_channels

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

@@ -2244,7 +2195,7 @@ class AttnUpBlock2D(nn.Module):
            attentions.append(
                Attention(
                    out_channels,
                    heads=num_attention_heads,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,

@@ -2329,7 +2280,6 @@ class CrossAttnUpBlock2D(nn.Module):
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        attention_head_dim: Optional[int] = None,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,

@@ -2346,9 +2296,6 @@ class CrossAttnUpBlock2D(nn.Module):
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if attention_head_dim is None:
            attention_head_dim = out_channels // num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

@@ -2373,8 +2320,8 @@ class CrossAttnUpBlock2D(nn.Module):
            if not dual_cross_attention:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,

@@ -2388,8 +2335,8 @@ class CrossAttnUpBlock2D(nn.Module):
            else:
                attentions.append(
                    DualTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=attention_head_dim,
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,

@@ -2687,7 +2634,6 @@ class AttnUpDecoderBlock2D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,

@@ -2703,9 +2649,6 @@ class AttnUpDecoderBlock2D(nn.Module):
            )
            attention_head_dim = out_channels

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

@@ -2742,7 +2685,7 @@ class AttnUpDecoderBlock2D(nn.Module):
            attentions.append(
                Attention(
                    out_channels,
                    heads=num_attention_heads,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,

@@ -2794,7 +2737,6 @@ class AttnSkipUpBlock2D(nn.Module):
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        output_scale_factor: float = np.sqrt(2.0),
        add_upsample: bool = True,

@@ -2829,13 +2771,10 @@ class AttnSkipUpBlock2D(nn.Module):
            )
            attention_head_dim = out_channels

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim

        self.attentions.append(
            Attention(
                out_channels,
                heads=num_attention_heads,
                heads=out_channels // attention_head_dim,
                dim_head=attention_head_dim,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,

@@ -3143,7 +3082,6 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,

@@ -3159,9 +3097,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        if num_attention_heads is None:
            num_attention_heads = out_channels // attention_head_dim
        self.num_heads = num_attention_heads
        self.num_heads = out_channels // self.attention_head_dim

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels

@@ -3191,8 +3127,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                Attention(
                    query_dim=out_channels,
                    cross_attention_dim=out_channels,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,

@@ -3398,7 +3334,6 @@ class KCrossAttnUpBlock2D(nn.Module):
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
        resnet_group_size: int = 32,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 1, # attention dim_head
        cross_attention_dim: int = 768,
        add_upsample: bool = True,

@@ -3415,11 +3350,6 @@ class KCrossAttnUpBlock2D(nn.Module):
        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        if num_attention_heads is not None:
            logger.warn(
                "`num_attention_heads` argument is passed but ignored. The number of attention heads is determined by `attention_head_dim`, `in_channels` and `out_channels`."
            )

        # in_channels, and out_channels for the block (k-unet)
        k_in_channels = out_channels if is_first_block else 2 * out_channels
        k_out_channels = in_channels

@@ -3453,11 +3383,11 @@ class KCrossAttnUpBlock2D(nn.Module):
            )
            attentions.append(
                KAttentionBlock(
                    dim=k_out_channels if (i == num_layers - 1) else out_channels,
                    num_attention_heads=k_out_channels // attention_head_dim
                    k_out_channels if (i == num_layers - 1) else out_channels,
                    k_out_channels // attention_head_dim
                    if (i == num_layers - 1)
                    else out_channels // attention_head_dim,
                    attention_head_dim=attention_head_dim,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    temb_channels=temb_channels,
                    attention_bias=True,
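The block-level hunks above repeat one pattern: instead of taking a separate `num_attention_heads` argument, each self-attention block hands `Attention` a `dim_head` of `attention_head_dim` and derives the head count from the block width, so heads times head width always covers the channel dimension. A small sketch of that invariant with hypothetical numbers:

```python
# Illustration only: the invariant behind
#   heads=out_channels // attention_head_dim, dim_head=attention_head_dim
# in the blocks above. The numbers are hypothetical.
out_channels = 640
attention_head_dim = 8  # per-head width for this block

heads = out_channels // attention_head_dim
assert heads * attention_head_dim == out_channels  # 80 heads * 8 dims == 640 channels
print(heads)  # 80
```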
@@ -55,28 +55,6 @@ from .unet_2d_blocks import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels):
    incorrect_attention_head_dim_name = False
    if (
        "CrossAttnDownBlock2D" in down_block_types
        or "CrossAttnUpBlock2D" in up_block_types
        or mid_block_type == "UNetMidBlock2DCrossAttn"
    ):
        incorrect_attention_head_dim_name = True

    if incorrect_attention_head_dim_name:
        num_attention_heads = attention_head_dim
    else:
        # we use attention_head_dim to calculate num_attention_heads
        if isinstance(attention_head_dim, int):
            num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
        else:
            num_attention_heads = [
                out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
            ]
    return num_attention_heads


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """

@@ -247,21 +225,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,

        self.sample_size = sample_size

        if attention_head_dim is not None:
            deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads` instead."
            deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
            num_attention_heads = correct_incorrect_names(
                attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels
        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )
            logger.warning(
                f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
                f"the model will be configured with `num_attention_heads` {num_attention_heads}."
            )
            attention_head_dim = None

        if num_attention_heads is None:
            raise ValueError("`num_attention_heads` cannot be None.")
        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."

@@ -282,6 +259,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."

@@ -296,14 +278,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
        if isinstance(layer_number_per_block, list):
            raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

        # make sure num_attention_heads is a tuple
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # we use num_attention_heads to calculate attention_head_dim
        attention_head_dim = [
            out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
        ]
        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(

@@ -445,6 +419,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

@@ -492,7 +472,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

@@ -510,7 +490,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,

@@ -526,7 +505,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,

@@ -558,7 +536,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = (

@@ -607,7 +584,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=reversed_attention_head_dim[i],
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
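For `UNet2DConditionModel` the hunks above restore the long-standing fallback in which a missing `num_attention_heads` silently takes the value of `attention_head_dim`, after which a bare int is broadcast to one entry per down block. A hedged sketch of just that bookkeeping, with invented values rather than a real model config:

```python
# Illustration only: the fallback and broadcasting steps in the constructor
# above, with invented values rather than a real model config.
down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
attention_head_dim = 8
num_attention_heads = None  # most configs leave this unset

# Naming-correction fallback: the value stored under `attention_head_dim`
# is treated as the head count when `num_attention_heads` is absent.
num_attention_heads = num_attention_heads or attention_head_dim

# A single int is broadcast to a per-down-block tuple.
if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)

print(num_attention_heads)  # (8, 8, 8)
```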
@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
        return objs


# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample

@@ -1785,6 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
        return hidden_states, output_states


# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
    def __init__(
        self,

@@ -1895,6 +1897,7 @@ class UpBlockFlat(nn.Module):
        return hidden_states


# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
    def __init__(
        self,

@@ -2068,6 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
        return hidden_states


# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.

@@ -2223,6 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
        return hidden_states


# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
    def __init__(
        self,

@@ -2369,6 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
        return hidden_states


# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
    def __init__(
        self,
@@ -27,7 +27,13 @@ from diffusers import (
    PixArtAlphaPipeline,
    Transformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np

@@ -332,37 +338,35 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
        torch.cuda.empty_cache()

    def test_pixart_1024(self):
        generator = torch.manual_seed(0)
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()
        prompt = self.prompt

        image = pipe(prompt, generator=generator, output_type="np").images
        image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images

        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589])
        expected_slice = np.array([0.1941, 0.2117, 0.2188, 0.1946, 0.218, 0.2124, 0.199, 0.2437, 0.2583])

        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
        self.assertLessEqual(max_diff, 1e-4)

    def test_pixart_512(self):
        generator = torch.manual_seed(0)
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()

        prompt = self.prompt

        image = pipe(prompt, generator=generator, output_type="np").images
        image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images

        image_slice = image[0, -3:, -3:, -1]
        expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958])
        expected_slice = np.array([0.2637, 0.291, 0.2939, 0.207, 0.2512, 0.2783, 0.2168, 0.2324, 0.2817])

        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)
        max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
        self.assertLessEqual(max_diff, 1e-4)

    def test_pixart_1024_without_resolution_binning(self):
        generator = torch.manual_seed(0)

@@ -372,7 +376,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):

        prompt = self.prompt
        height, width = 1024, 768
        num_inference_steps = 10
        num_inference_steps = 2

        image = pipe(
            prompt,

@@ -406,7 +410,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):

        prompt = self.prompt
        height, width = 512, 768
        num_inference_steps = 10
        num_inference_steps = 2

        image = pipe(
            prompt,
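The PixArt test hunks swap a strict max-absolute-difference assertion for `numpy_cosine_similarity_distance` and run far fewer inference steps. The snippet below is a rough, self-contained stand-in (my own helper, not the diffusers implementation) showing why a cosine-distance check tolerates small uniform drift between runtimes better than a per-pixel threshold:

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Stand-in helper, not the diffusers implementation: 1 - cosine similarity.
    a, b = a.flatten(), b.flatten()
    return 1.0 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


expected = np.array([0.19, 0.21, 0.22, 0.19, 0.22, 0.21, 0.20, 0.24, 0.26])
observed = expected + 0.01  # small uniform drift, e.g. from kernel differences

print(np.abs(observed - expected).max())    # ~0.01, fails a 1e-3 max-diff check
print(cosine_distance(observed, expected))  # tiny (well below 1e-4), passes a cosine check
```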