Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
Compare commits: attn-refac...rename-att (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | fd7f9c74e1 |  |
|  | b71f35b908 |  |
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 from torch import nn

-from ...utils import is_torch_version
+from ...utils import deprecate, is_torch_version
 from ...utils.torch_utils import apply_freeu
 from ..attention import Attention
 from ..resnet import (
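The only import change pulls `deprecate` into `unet_3d_blocks`; the hunks below all call it with the same shape. A rough usage sketch (the message text is paraphrased, and it assumes an installed diffusers version below 1.0.0):

```python
# Rough usage sketch of the helper this hunk imports: deprecate() takes the
# name of the deprecated thing, the version in which support is removed, and
# a message; standard_warn=False keeps the output to just the custom message.
from diffusers.utils import deprecate

deprecation_message = (
    "Passing `num_attention_heads` is deprecated. Please use `attention_head_dim` instead."
)
deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
```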
@@ -44,7 +44,8 @@ def get_down_block(
     add_downsample: bool,
     resnet_eps: float,
     resnet_act_fn: str,
-    num_attention_heads: int,
+    num_attention_heads: Optional[int] = None,
+    attention_head_dim: Optional[int] = None,
     resnet_groups: Optional[int] = None,
     cross_attention_dim: Optional[int] = None,
     downsample_padding: Optional[int] = None,
@@ -80,6 +81,16 @@ def get_down_block(
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+        if num_attention_heads is not None:
+            deprecation_message = (
+                "Passing `num_attention_heads` to `unet_3d_blocks.get_down_block` for CrossAttnDownBlock3D is deprecated. "
+                "Please use `attention_head_dim` instead."
+            )
+            deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
+            if attention_head_dim is None:
+                attention_head_dim = num_attention_heads
+        if attention_head_dim is None:
+            raise ValueError("`attention_head_dim` must be specified for CrossAttnDownBlock3D")
         return CrossAttnDownBlock3D(
             num_layers=num_layers,
             in_channels=in_channels,
@@ -91,7 +102,8 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads,
+            num_attention_heads=None,
+            attention_head_dim=attention_head_dim,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
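Taken together, the three `get_down_block` hunks change the contract: `attention_head_dim` becomes the required knob, `num_attention_heads` survives only as a deprecated fallback, and the block itself is then constructed with `num_attention_heads=None` so the deprecation inside `CrossAttnDownBlock3D` does not fire a second time. A standalone sketch of that resolution order (not the library code; `warnings.warn` stands in for `deprecate`):

```python
import warnings
from typing import Optional


def resolve_attention_head_dim(
    num_attention_heads: Optional[int],
    attention_head_dim: Optional[int],
) -> int:
    """Mirror of the precedence the new get_down_block applies."""
    if num_attention_heads is not None:
        # The diff calls diffusers.utils.deprecate here; a plain warning
        # stands in for it in this sketch.
        warnings.warn(
            "passing `num_attention_heads` is deprecated, use `attention_head_dim`",
            FutureWarning,
        )
        if attention_head_dim is None:
            attention_head_dim = num_attention_heads
    if attention_head_dim is None:
        raise ValueError("`attention_head_dim` must be specified for CrossAttnDownBlock3D")
    return attention_head_dim


assert resolve_attention_head_dim(None, 64) == 64  # new-style call
assert resolve_attention_head_dim(8, None) == 8    # legacy call, warns
assert resolve_attention_head_dim(8, 64) == 64     # an explicit attention_head_dim wins
```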
@@ -173,7 +185,8 @@ def get_up_block(
     add_upsample: bool,
     resnet_eps: float,
     resnet_act_fn: str,
-    num_attention_heads: int,
+    num_attention_heads: Optional[int] = None,
+    attention_head_dim: Optional[int] = None,
     resolution_idx: Optional[int] = None,
     resnet_groups: Optional[int] = None,
     cross_attention_dim: Optional[int] = None,
@@ -212,6 +225,16 @@ def get_up_block(
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+        if num_attention_heads is not None:
+            deprecation_message = (
+                "Passing `num_attention_heads` to `unet_3d_blocks.get_up_block` for CrossAttnUpBlock3D is deprecated. "
+                "Please use `attention_head_dim` instead."
+            )
+            deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
+            if attention_head_dim is None:
+                attention_head_dim = num_attention_heads
+        if attention_head_dim is None:
+            raise ValueError("`attention_head_dim` must be specified for CrossAttnUpBlock3D")
         return CrossAttnUpBlock3D(
             num_layers=num_layers,
             in_channels=in_channels,
@@ -223,7 +246,8 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads,
+            num_attention_heads=None,
+            attention_head_dim=attention_head_dim,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -314,7 +338,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
+        num_attention_heads: Optional[int] = 1,
+        attention_head_dim: Optional[int] = None,
         output_scale_factor: float = 1.0,
         cross_attention_dim: int = 1280,
         dual_cross_attention: bool = False,
@@ -322,9 +347,19 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         upcast_attention: bool = False,
     ):
         super().__init__()
+
+        if num_attention_heads is not None:
+            deprecation_message = (
+                "Passing `num_attention_heads` to `unet_3d_blocks.UNetMidBlock3DCrossAttn` is deprecated. "
+                "Please use `attention_head_dim` instead."
+            )
+            deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
+            if attention_head_dim is None:
+                attention_head_dim = num_attention_heads
+        self.num_attention_heads = num_attention_heads
+        if attention_head_dim is None:
+            raise ValueError("`attention_head_dim` must be specified for UNetMidBlock3DCrossAttn")
         self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads
-
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

         # there is always at least one resnet
@@ -356,8 +391,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         for _ in range(num_layers):
             attentions.append(
                 Transformer2DModel(
-                    in_channels // num_attention_heads,
-                    num_attention_heads,
+                    in_channels // attention_head_dim,
+                    attention_head_dim,
                     in_channels=in_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -368,8 +403,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    in_channels // num_attention_heads,
-                    num_attention_heads,
+                    in_channels // attention_head_dim,
+                    attention_head_dim,
                     in_channels=in_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
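The two mid-block hunks flip which quantity gets divided: `Transformer2DModel` and `TransformerTemporalModel` take the head count and the per-head dimension as their first two positional arguments, so the head count is now derived as `in_channels // attention_head_dim` rather than treating the config value itself as a head count. Illustrative arithmetic (the numbers are made up, not a shipped config):

```python
# Illustrative only: how the first two positional arguments are now computed.
in_channels = 1280        # example mid-block channel width
attention_head_dim = 64   # example per-head dimension

num_heads = in_channels // attention_head_dim
assert num_heads == 20
assert num_heads * attention_head_dim == in_channels  # heads * head_dim == channels
```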
@@ -449,7 +484,8 @@ class CrossAttnDownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
+        num_attention_heads: Optional[int] = 1,
+        attention_head_dim: Optional[int] = None,
         cross_attention_dim: int = 1280,
         output_scale_factor: float = 1.0,
         downsample_padding: int = 1,
@@ -460,13 +496,23 @@ class CrossAttnDownBlock3D(nn.Module):
         upcast_attention: bool = False,
     ):
         super().__init__()
+        if num_attention_heads is not None:
+            deprecation_message = (
+                "Passing `num_attention_heads` to `unet_3d_blocks.CrossAttnDownBlock3D` is deprecated. "
+                "Please use `attention_head_dim` instead."
+            )
+            deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
+            if attention_head_dim is None:
+                attention_head_dim = num_attention_heads
+        self.num_attention_heads = num_attention_heads
+        if attention_head_dim is None:
+            raise ValueError("`attention_head_dim` must be specified for CrossAttnDownBlock3D")
         resnets = []
         attentions = []
         temp_attentions = []
         temp_convs = []

         self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads

         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -494,8 +540,8 @@ class CrossAttnDownBlock3D(nn.Module):
             )
             attentions.append(
                 Transformer2DModel(
-                    out_channels // num_attention_heads,
-                    num_attention_heads,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -507,8 +553,8 @@ class CrossAttnDownBlock3D(nn.Module):
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    out_channels // num_attention_heads,
-                    num_attention_heads,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -681,7 +727,8 @@ class CrossAttnUpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        num_attention_heads: int = 1,
+        num_attention_heads: Optional[int] = 1,
+        attention_head_dim: Optional[int] = None,
         cross_attention_dim: int = 1280,
         output_scale_factor: float = 1.0,
         add_upsample: bool = True,
@@ -692,13 +739,25 @@ class CrossAttnUpBlock3D(nn.Module):
         resolution_idx: Optional[int] = None,
     ):
         super().__init__()
+        if num_attention_heads is not None:
+            deprecation_message = (
+                "Passing `num_attention_heads` to `unet_3d_blocks.CrossAttnUpBlock3D` is deprecated. "
+                "Please use `attention_head_dim` instead."
+            )
+            deprecate("num_attention_heads not None", "1.0.0", deprecation_message, standard_warn=False)
+            if attention_head_dim is None:
+                attention_head_dim = num_attention_heads
+        self.num_attention_heads = num_attention_heads
+
+        if attention_head_dim is None:
+            raise ValueError("`attention_head_dim` must be specified for CrossAttnUpBlock3D")
+
         resnets = []
         temp_convs = []
         attentions = []
         temp_attentions = []

         self.has_cross_attention = True
-        self.num_attention_heads = num_attention_heads

         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -728,8 +787,8 @@ class CrossAttnUpBlock3D(nn.Module):
             )
             attentions.append(
                 Transformer2DModel(
-                    out_channels // num_attention_heads,
-                    num_attention_heads,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -741,8 +800,8 @@ class CrossAttnUpBlock3D(nn.Module):
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    out_channels // num_attention_heads,
-                    num_attention_heads,
+                    out_channels // attention_head_dim,
+                    attention_head_dim,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -132,14 +132,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
             )

-        # If `num_attention_heads` is not defined (which is the case for most models)
-        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
-        # The reason for this behavior is to correct for incorrectly named variables that were introduced
-        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
-        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
-        # which is why we correct for the naming here.
-        num_attention_heads = num_attention_heads or attention_head_dim
-
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
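The removed block was the old workaround for the naming issue referenced above: `num_attention_heads` was silently aliased to `attention_head_dim`, and the rest of the constructor then worked with the misnamed variable. With the 3D blocks now accepting `attention_head_dim` directly, the alias is no longer needed. A before/after sketch of just that aliasing step (simplified, not the verbatim constructor):

```python
# Before: the constructor quietly reused `attention_head_dim` under the wrong
# name whenever `num_attention_heads` was not given.
num_attention_heads = None
attention_head_dim = 64
num_attention_heads = num_attention_heads or attention_head_dim
assert num_attention_heads == 64  # actually a head *dimension*, despite the name

# After: `attention_head_dim` is used as-is, and the deprecation/ValueError
# logic added to the 3D blocks (see the hunks above) handles legacy callers.
```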
@@ -151,9 +143,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )

-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
             raise ValueError(
-                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
             )

         # input
@@ -187,8 +179,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])

-        if isinstance(num_attention_heads, int):
-            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)

         # down
         output_channel = block_out_channels[0]
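As before with `num_attention_heads`, a single integer `attention_head_dim` is broadcast to one value per down block so the later `attention_head_dim[i]` indexing works. A quick illustration with made-up values:

```python
# Illustrative values only; not a shipped configuration.
down_block_types = (
    "CrossAttnDownBlock3D",
    "CrossAttnDownBlock3D",
    "CrossAttnDownBlock3D",
    "DownBlock3D",
)

attention_head_dim = 64
if isinstance(attention_head_dim, int):
    attention_head_dim = (attention_head_dim,) * len(down_block_types)

print(attention_head_dim)  # (64, 64, 64, 64), one entry per down block
```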
@@ -208,7 +200,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                num_attention_heads=num_attention_heads[i],
+                attention_head_dim=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=False,
             )
@@ -222,7 +214,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             cross_attention_dim=cross_attention_dim,
-            num_attention_heads=num_attention_heads[-1],
+            num_attention_heads=None,
+            attention_head_dim=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=False,
         )
@@ -232,7 +225,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)

         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))

         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -261,7 +254,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                num_attention_heads=reversed_num_attention_heads[i],
+                attention_head_dim=reversed_attention_head_dim[i],
                 dual_cross_attention=False,
                 resolution_idx=i,
             )
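On the up path the per-block tuple is reversed so that the block mirroring each down block receives the matching head dimension, and the final hunk indexes into that reversed list. A small illustration with made-up per-block values:

```python
# Illustrative values only; not a shipped configuration.
attention_head_dim = (8, 16, 32, 32)  # one entry per down block
up_block_types = (
    "UpBlock3D",
    "CrossAttnUpBlock3D",
    "CrossAttnUpBlock3D",
    "CrossAttnUpBlock3D",
)

reversed_attention_head_dim = list(reversed(attention_head_dim))
for i, up_block_type in enumerate(up_block_types):
    print(up_block_type, reversed_attention_head_dim[i])
# UpBlock3D 32
# CrossAttnUpBlock3D 32
# CrossAttnUpBlock3D 16
# CrossAttnUpBlock3D 8
```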