Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)
Compare commits
auto-pipel...temp/swigl (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a38a74f3ab |  |
|  | 19b181e628 |  |
|  | 3ed26304e1 |  |
|  | 505777dd98 |  |
|  | 0e71a296d2 |  |
|  | 93aaea1da7 |  |
@@ -18,6 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import logging
 from ..utils.import_utils import is_xformers_available
 from .cross_attention import CrossAttention
 from .embeddings import CombinedTimestepLabelEmbeddings

@@ -29,6 +30,8 @@ if is_xformers_available():
 else:
     xformers = None
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
 
 class AttentionBlock(nn.Module):
     """
@@ -208,6 +211,7 @@ class BasicTransformerBlock(nn.Module):
         final_dropout: bool = False,
     ):
         super().__init__()
+        print(f"Using {activation_fn} as activation_fn in BasicTransformerBlock.")
         self.only_cross_attention = only_cross_attention
 
         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
@@ -353,15 +357,22 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
+        use_bias = True
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)
         if activation_fn == "gelu-approximate":
             act_fn = GELU(dim, inner_dim, approximate="tanh")
         elif activation_fn == "geglu":
+            print("Using GEGLU as the activation function in the FFN.")
             act_fn = GEGLU(dim, inner_dim)
         elif activation_fn == "geglu-approximate":
             act_fn = ApproximateGELU(dim, inner_dim)
+        elif activation_fn == "swiglu":
+            print("Using SwiGLU as the activation function in the FFN.")
+            inner_dim = int(2 * dim_out / 3)
+            act_fn = SwiGLU(dim, inner_dim)
+            use_bias = False
 
         self.net = nn.ModuleList([])
         # project in
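As written, the `swiglu` branch derives `inner_dim` from `dim_out` alone and ignores `mult`, so this feed-forward is considerably narrower than the default GEGLU one rather than parameter-matched. A rough count for a typical `dim = dim_out = 320`, `mult = 4` block (plain arithmetic mirroring the constructor above; illustrative numbers, not taken from the source):

```python
# Approximate FeedForward parameter counts for dim = dim_out = 320, mult = 4.
dim, mult = 320, 4

inner_geglu = dim * mult                             # 1280
geglu = dim * (2 * inner_geglu) + 2 * inner_geglu    # GEGLU fused projection, with bias
geglu += inner_geglu * dim + dim                     # output Linear, with bias

inner_swiglu = int(2 * dim / 3)                      # 213 (mult is unused on this path)
swiglu = 2 * dim * inner_swiglu                      # w1 and w3, bias-free
swiglu += inner_swiglu * dim                         # bias-free output Linear

print(geglu, swiglu)                                 # 1231680 204480
```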
@@ -369,7 +380,7 @@
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(nn.Linear(inner_dim, dim_out))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=use_bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
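Putting the two `FeedForward` hunks together, the `swiglu` path builds SwiGLU, then Dropout, then a bias-free output projection. A self-contained sketch of that wiring (a simplified stand-in that assumes `dim_out` defaults to `dim`; not the actual diffusers class):

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLUFeedForward(nn.Module):
    """Simplified stand-in for FeedForward(activation_fn="swiglu") as built above."""

    def __init__(self, dim: int, dropout: float = 0.0):
        super().__init__()
        inner_dim = int(2 * dim / 3)                      # shrunken hidden width
        self.w1 = nn.Linear(dim, inner_dim, bias=False)   # gate projection
        self.w3 = nn.Linear(dim, inner_dim, bias=False)   # value projection
        self.drop = nn.Dropout(dropout)
        self.out = nn.Linear(inner_dim, dim, bias=False)  # use_bias=False on this path

    def forward(self, x):
        return self.out(self.drop(F.silu(self.w1(x)) * self.w3(x)))


x = torch.randn(2, 77, 320)                               # (batch, tokens, channels)
print(SwiGLUFeedForward(320)(x).shape)                     # torch.Size([2, 77, 320])
```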
@@ -442,6 +453,22 @@ class ApproximateGELU(nn.Module):
         return x * torch.sigmoid(1.702 * x)
 
 
+class SwiGLU(nn.Module):
+    """
+    GEGLU-like that uses SiLU instead of GELU on the gates. SwiGLU is used in works like PaLM.
+
+    Reference: https://arxiv.org/abs/2002.05202
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.w1 = nn.Linear(dim_in, dim_out, bias=False)
+        self.w3 = nn.Linear(dim_in, dim_out, bias=False)
+
+    def forward(self, x):
+        return F.silu(self.w1(x)) * self.w3(x)
+
+
 class AdaLayerNorm(nn.Module):
     """
     Norm layer modified to incorporate timestep embeddings.
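Compared with the existing `GEGLU` module, the new class swaps the GELU gate for SiLU (`x * sigmoid(x)`), drops the projection biases, and uses two separate linear layers instead of one fused, chunked projection. A quick standalone check of the gate identity and the forward formula (illustrative only):

```python
import torch
import torch.nn.functional as F
from torch import nn

# SiLU (a.k.a. swish) is x * sigmoid(x); SwiGLU gates with it where GEGLU uses F.gelu.
x = torch.randn(4, 8)
assert torch.allclose(F.silu(x), x * torch.sigmoid(x), atol=1e-6)

# The forward pass above, written out with two bias-free projections.
w1, w3 = nn.Linear(8, 6, bias=False), nn.Linear(8, 6, bias=False)
out = F.silu(w1(x)) * w3(x)                    # same formula as SwiGLU.forward
print(out.shape)                               # torch.Size([4, 6])
```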
@@ -104,6 +104,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
+        print(f"Using {activation_fn} as activation_fn in Transformer2DModel.")
 
         # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
@@ -42,6 +42,7 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    ff_activation_fn="geglu",
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":

@@ -103,6 +104,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            ff_activation_fn=ff_activation_fn,
         )
     elif down_block_type == "SimpleCrossAttnDownBlock2D":
         if cross_attention_dim is None:
@@ -214,6 +216,7 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    ff_activation_fn="geglu",
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":

@@ -262,6 +265,7 @@ def get_up_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            ff_activation_fn=ff_activation_fn,
         )
     elif up_block_type == "SimpleCrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -465,8 +469,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
+        print(f"Using {ff_activation_fn} as ff_activation_fn in UNetMidBlock2DCrossAttn")
 
         self.has_cross_attention = True
         self.attn_num_head_channels = attn_num_head_channels

@@ -501,6 +507,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             else:

@@ -512,6 +519,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             resnets.append(
@@ -742,6 +750,7 @@ class CrossAttnDownBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
         resnets = []

@@ -778,6 +787,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             else:

@@ -789,6 +799,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        activation_fn=ff_activation_fn,
                     )
                 )
         self.attentions = nn.ModuleList(attentions)
@@ -1712,6 +1723,7 @@ class CrossAttnUpBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
         resnets = []

@@ -1750,6 +1762,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             else:

@@ -1761,6 +1774,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        activation_fn=ff_activation_fn,
                     )
                 )
         self.attentions = nn.ModuleList(attentions)
@@ -148,8 +148,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
+        print(f"Using {ff_activation_fn} as ff_activation_fn in UNet2DConditionModel")
 
         self.sample_size = sample_size
 
@@ -264,6 +266,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                ff_activation_fn=ff_activation_fn,
             )
             self.down_blocks.append(down_block)
 

@@ -282,6 +285,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 upcast_attention=upcast_attention,
+                ff_activation_fn=ff_activation_fn,
             )
         elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
             self.mid_block = UNetMidBlock2DSimpleCrossAttn(

@@ -341,6 +345,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                ff_activation_fn=ff_activation_fn,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
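Taken together, the UNet hunks thread one `ff_activation_fn` flag from `UNet2DConditionModel.__init__` through `get_down_block`, `get_up_block`, and the mid block into each `Transformer2DModel` (as `activation_fn`), and from there down to every `FeedForward`. On this branch it would be selected at construction time roughly as below; a sketch assuming the signature shown in the diff (the keyword does not exist in upstream diffusers, and the tiny config values are made up for illustration):

```python
from diffusers import UNet2DConditionModel

# Hypothetical, deliberately small UNet; only ff_activation_fn matters here.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
    attention_head_dim=8,
    ff_activation_fn="swiglu",   # forwarded to every cross-attention block's FFN
)
```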