Compare commits

...

4 Commits

Author     SHA1         Message      Date
yiyixuxu   ef8c0bf51d   fix          2024-02-15 02:45:31 +00:00
yiyixuxu   04be74ed94   fix          2024-02-15 00:45:41 +00:00
yiyixuxu   e4bee5d8df   fix a error  2024-02-15 00:42:19 +00:00
yiyixuxu   9b1ff58b40   first draft  2024-02-14 23:44:15 +00:00
3 changed files with 237 additions and 63 deletions

View File

@@ -20,7 +20,7 @@ from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..utils import BaseOutput, logging
from ..utils import BaseOutput, deprecate, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -43,6 +43,20 @@ from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type):
incorrect_attention_head_dim_name = False
if "CrossAttnDownBlock2D" in down_block_types or mid_block_type == "UNetMidBlock2DCrossAttn":
incorrect_attention_head_dim_name = True
if incorrect_attention_head_dim_name:
num_attention_heads = attention_head_dim
attention_head_dimension = None
else:
num_attention_heads = None
attention_head_dimension = attention_head_dim
return num_attention_heads, attention_head_dimension
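
To make the intent of this helper concrete, here is a small illustration (argument values are assumed for the example, not taken from any shipped config): when cross-attention blocks are present, the old `attention_head_dim` value is treated as a head count; otherwise it keeps its literal meaning.

num_heads, head_dim = correct_incorrect_names(
    attention_head_dim=(5, 10, 20, 20),
    down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
    mid_block_type="UNetMidBlock2DCrossAttn",
)
# cross-attention blocks present -> the old value was really a head count:
# num_heads == (5, 10, 20, 20), head_dim is None

num_heads, head_dim = correct_incorrect_names(
    attention_head_dim=8,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    mid_block_type="UNetMidBlock2D",
)
# no cross-attention anywhere -> the old value really was a head dimension:
# num_heads is None, head_dim == 8
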
@dataclass
class ControlNetOutput(BaseOutput):
"""
@@ -206,6 +220,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
attention_head_dimension: Optional[Union[int, Tuple[int]]] = None,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
@@ -222,15 +237,21 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
if attention_head_dim is not None:
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads` and `attention_head_dimension` instead."
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
num_attention_heads, attention_head_dimension = correct_incorrect_names(
attention_head_dim, down_block_types, mid_block_type
)
logger.warning(
f"corrected potentially incorrect arguments, the model will be configured with `num_attention_heads` {num_attention_heads} and `attention_head_dimension` {attention_head_dimension}."
)
# Check inputs
if attention_head_dimension is not None and num_attention_heads is not None:
raise ValueError(
"You can only define either `attention_head_dimension` or `num_attention_heads` but not both."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
@@ -241,11 +262,43 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
if (
num_attention_heads is not None
and not isinstance(num_attention_heads, int)
and len(num_attention_heads) != len(down_block_types)
):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if (
attention_head_dimension is not None
and not isinstance(attention_head_dimension, int)
and len(attention_head_dimension) != len(down_block_types)
):
raise ValueError(
f"Must provide the same number of `attention_head_dimension` as `down_block_types`. `attention_head_dimension`: {attention_head_dimension}. `down_block_types`: {down_block_types}."
)
# we use attention_head_dim to calculate num_attention_heads
if attention_head_dimension is not None:
if isinstance(attention_head_dimension, int):
num_attention_heads = [out_channels // attention_head_dimension for out_channels in block_out_channels]
else:
num_attention_heads = [
out_channels // attn_dim
for out_channels, attn_dim in zip(block_out_channels, attention_head_dimension)
]
# we use num_attention_heads to calculate attention_head_dimension
elif num_attention_heads is not None:
if isinstance(num_attention_heads, int):
attention_head_dimension = [out_channels // num_attention_heads for out_channels in block_out_channels]
else:
attention_head_dimension = [
out_channels // num_heads
for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
]
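
As a quick arithmetic check of the two conversions above, with assumed block widths (not taken from any particular checkpoint):

block_out_channels = (320, 640, 1280, 1280)

# `attention_head_dimension` given -> derive the per-block head count
attention_head_dimension = 64
num_attention_heads = [c // attention_head_dimension for c in block_out_channels]
# [5, 10, 20, 20]

# `num_attention_heads` given -> derive the per-block head dimension
num_attention_heads = 8
attention_head_dimension = [c // num_attention_heads for c in block_out_channels]
# [40, 80, 160, 160]
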
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
@@ -354,8 +407,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(attention_head_dimension, int):
attention_head_dimension = (attention_head_dimension,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
@@ -385,7 +438,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dimension[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
@@ -422,6 +475,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dimension[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,

View File

@@ -119,6 +119,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
@@ -140,6 +141,7 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
@@ -161,6 +163,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -191,6 +194,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -218,6 +222,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -243,6 +248,7 @@ def get_down_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
add_self_attention=True if not add_downsample else False,
)
@@ -335,6 +341,7 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
@@ -358,6 +365,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -382,6 +390,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
@@ -412,6 +421,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -440,6 +450,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
@@ -468,6 +479,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
@@ -555,6 +567,7 @@ class UNetMidBlock2D(nn.Module):
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
@@ -602,13 +615,15 @@ class UNetMidBlock2D(nn.Module):
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
if num_attention_heads is None:
num_attention_heads = in_channels // attention_head_dim
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
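
The change above is behavior-preserving when callers leave `num_attention_heads` unset, because the new default reproduces exactly what the old `heads=` expression computed inline; a minimal sketch with assumed values:

in_channels = 512
attention_head_dim = 64
num_attention_heads = in_channels // attention_head_dim  # 8, same as the old `heads=in_channels // attention_head_dim`
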
@@ -680,6 +695,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
@@ -693,6 +709,9 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
if attention_head_dim is None:
attention_head_dim = in_channels // num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -718,8 +737,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -732,8 +751,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -824,6 +843,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
@@ -838,7 +858,9 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attention_head_dim
if num_attention_heads is None:
num_attention_heads = in_channels // attention_head_dim
self.num_heads = num_attention_heads
# there is always at least one resnet
resnets = [
@@ -949,6 +971,7 @@ class AttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
@@ -965,6 +988,9 @@ class AttnDownBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -984,7 +1010,7 @@ class AttnDownBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1074,6 +1100,7 @@ class CrossAttnDownBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
@@ -1090,6 +1117,9 @@ class CrossAttnDownBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if attention_head_dim is None:
attention_head_dim = out_channels // num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -1112,8 +1142,8 @@ class CrossAttnDownBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -1127,8 +1157,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -1395,6 +1425,7 @@ class AttnDownEncoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
@@ -1410,6 +1441,9 @@ class AttnDownEncoderBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
if resnet_time_scale_shift == "spatial":
@@ -1444,7 +1478,7 @@ class AttnDownEncoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1495,6 +1529,7 @@ class AttnSkipDownBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = np.sqrt(2.0),
add_downsample: bool = True,
@@ -1509,6 +1544,9 @@ class AttnSkipDownBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
@@ -1529,7 +1567,7 @@ class AttnSkipDownBlock2D(nn.Module):
self.attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1789,6 +1827,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
@@ -1805,7 +1844,9 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
attentions = []
self.attention_head_dim = attention_head_dim
self.num_heads = out_channels // self.attention_head_dim
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.num_heads = num_attention_heads
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -1833,7 +1874,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
heads=num_attention_heads,
dim_head=attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -2027,6 +2068,7 @@ class KCrossAttnDownBlock2D(nn.Module):
num_layers: int = 4,
resnet_group_size: int = 32,
add_downsample: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 64,
add_self_attention: bool = False,
resnet_eps: float = 1e-5,
@@ -2036,6 +2078,9 @@ class KCrossAttnDownBlock2D(nn.Module):
resnets = []
attentions = []
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.has_cross_attention = True
for i in range(num_layers):
@@ -2059,9 +2104,9 @@ class KCrossAttnDownBlock2D(nn.Module):
)
attentions.append(
KAttentionBlock(
out_channels,
out_channels // attention_head_dim,
attention_head_dim,
dim=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
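
For readers tracking the positional-to-keyword change in the `KAttentionBlock` call above, the mapping as shown in this hunk is:

# old: KAttentionBlock(out_channels, out_channels // attention_head_dim, attention_head_dim, ...)
# new: KAttentionBlock(dim=out_channels, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, ...)
# where `num_attention_heads` defaults to out_channels // attention_head_dim a few lines earlier when it is not passed.
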
@@ -2158,6 +2203,7 @@ class AttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
upsample_type: str = "conv",
@@ -2174,6 +2220,9 @@ class AttnUpBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -2195,7 +2244,7 @@ class AttnUpBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -2280,6 +2329,7 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
@@ -2296,6 +2346,9 @@ class CrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if attention_head_dim is None:
attention_head_dim = out_channels // num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -2320,8 +2373,8 @@ class CrossAttnUpBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -2335,8 +2388,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -2634,6 +2687,7 @@ class AttnUpDecoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
@@ -2649,6 +2703,9 @@ class AttnUpDecoderBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
@@ -2685,7 +2742,7 @@ class AttnUpDecoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -2737,6 +2794,7 @@ class AttnSkipUpBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = np.sqrt(2.0),
add_upsample: bool = True,
@@ -2771,10 +2829,13 @@ class AttnSkipUpBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -3082,6 +3143,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
@@ -3097,7 +3159,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.attention_head_dim = attention_head_dim
self.num_heads = out_channels // self.attention_head_dim
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.num_heads = num_attention_heads
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -3127,8 +3191,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
dim_head=self.attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
@@ -3334,6 +3398,7 @@ class KCrossAttnUpBlock2D(nn.Module):
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1, # attention dim_head
cross_attention_dim: int = 768,
add_upsample: bool = True,
@@ -3350,6 +3415,11 @@ class KCrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.attention_head_dim = attention_head_dim
if num_attention_heads is not None:
logger.warn(
"`num_attention_heads` argument is passed but ignored. The number of attention heads is determined by `attention_head_dim`, `in_channels` and `out_channels`."
)
# in_channels, and out_channels for the block (k-unet)
k_in_channels = out_channels if is_first_block else 2 * out_channels
k_out_channels = in_channels
@@ -3383,11 +3453,11 @@ class KCrossAttnUpBlock2D(nn.Module):
)
attentions.append(
KAttentionBlock(
k_out_channels if (i == num_layers - 1) else out_channels,
k_out_channels // attention_head_dim
dim=k_out_channels if (i == num_layers - 1) else out_channels,
num_attention_heads=k_out_channels // attention_head_dim
if (i == num_layers - 1)
else out_channels // attention_head_dim,
attention_head_dim,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,

View File

@@ -55,6 +55,24 @@ from .unet_2d_blocks import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, up_block_types):
incorrect_attention_head_dim_name = False
if (
"CrossAttnDownBlock2D" in down_block_types
or "CrossAttnUpBlock2D" in up_block_types
or mid_block_type == "UNetMidBlock2DCrossAttn"
):
incorrect_attention_head_dim_name = True
if incorrect_attention_head_dim_name:
num_attention_heads = attention_head_dim
attention_head_dimension = None
else:
num_attention_heads = None
attention_head_dimension = attention_head_dim
return num_attention_heads, attention_head_dimension
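
Same correction logic as the ControlNet helper earlier in this diff, except that `up_block_types` now also participates in the detection; a small assumed example where only an up block has cross-attention:

num_heads, head_dim = correct_incorrect_names(
    attention_head_dim=8,
    down_block_types=("DownBlock2D", "DownBlock2D"),
    mid_block_type="UNetMidBlock2D",
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
)
# a cross-attention up block alone is enough to flag the old name:
# num_heads == 8, head_dim is None
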
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
@@ -196,6 +214,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
attention_head_dimension: Optional[Union[int, Tuple[int]]] = None,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
@@ -225,20 +244,22 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self.sample_size = sample_size
if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
if attention_head_dim is not None:
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads` and `attention_head_dimension` instead."
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
num_attention_heads, attention_head_dimension = correct_incorrect_names(
attention_head_dim, down_block_types, mid_block_type, up_block_types
)
logger.warning(
f"corrected potentially incorrect arguments, the model will be configured with `num_attention_heads` {num_attention_heads} and `attention_head_dimension` {attention_head_dimension}."
)
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if attention_head_dimension is not None and num_attention_heads is not None:
raise ValueError(
"You can only define either `attention_head_dimension` or `num_attention_heads` but not both."
)
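
A hypothetical call that trips the check above on this branch (not valid against released diffusers); `attention_head_dim` has to be silenced explicitly because it still defaults to 8 and would otherwise take the deprecation path first:

from diffusers import UNet2DConditionModel  # assumes this PR's branch

UNet2DConditionModel(
    attention_head_dim=None,
    num_attention_heads=8,
    attention_head_dimension=64,
)  # raises ValueError: only one of the two may be defined
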
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
@@ -254,14 +275,22 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
if (
num_attention_heads is not None
and not isinstance(num_attention_heads, int)
and len(num_attention_heads) != len(down_block_types)
):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
if (
attention_head_dimension is not None
and not isinstance(attention_head_dimension, int)
and len(attention_head_dimension) != len(down_block_types)
):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
f"Must provide the same number of `attention_head_dimension` as `down_block_types`. `attention_head_dimension`: {attention_head_dimension}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
@@ -278,6 +307,24 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# we use attention_head_dim to calculate num_attention_heads
if attention_head_dimension is not None:
if isinstance(attention_head_dimension, int):
num_attention_heads = [out_channels // attention_head_dimension for out_channels in block_out_channels]
else:
num_attention_heads = [
out_channels // attn_dim
for out_channels, attn_dim in zip(block_out_channels, attention_head_dimension)
]
# we use num_attention_heads to calculate attention_head_dimension
elif num_attention_heads is not None:
if isinstance(num_attention_heads, int):
attention_head_dimension = [out_channels // num_attention_heads for out_channels in block_out_channels]
else:
attention_head_dimension = [
out_channels // num_heads
for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
]
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
@@ -422,8 +469,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(attention_head_dimension, int):
attention_head_dimension = (attention_head_dimension,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
@@ -472,7 +519,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dimension[i],
dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -490,6 +537,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dimension[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@@ -505,7 +553,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
attention_head_dim=attention_head_dim[-1],
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dimension[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -536,6 +585,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_attention_head_dimension = list(reversed(attention_head_dimension))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = (
@@ -584,7 +634,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=reversed_attention_head_dimension[i],
dropout=dropout,
)
self.up_blocks.append(up_block)