Compare commits

...

6 Commits

Author      SHA1        Date                        Message
Sayak Paul  a38a74f3ab  2023-02-26 16:35:20 +05:30  adding ff_activation_fn properly.
Sayak Paul  19b181e628  2023-02-26 16:17:34 +05:30  adding: print statements to all ff_attn_fn affected blocks.
Sayak Paul  3ed26304e1  2023-02-26 15:52:15 +05:30  logger -> print.
Sayak Paul  505777dd98  2023-02-26 15:46:09 +05:30  add logging for geglu, default act_fn for ffn.
Sayak Paul  0e71a296d2  2023-02-26 15:42:23 +05:30  add logging for swiglu for being extra cautious.
Sayak Paul  93aaea1da7  2023-02-26 11:38:50 +05:30  add: support for Swiglu.
4 changed files with 48 additions and 1 deletion

View File

@@ -18,6 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import logging
from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention
from .embeddings import CombinedTimestepLabelEmbeddings
@@ -29,6 +30,8 @@ if is_xformers_available():
else:
xformers = None
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class AttentionBlock(nn.Module):
"""
@@ -208,6 +211,7 @@ class BasicTransformerBlock(nn.Module):
final_dropout: bool = False,
):
super().__init__()
print(f"Using {activation_fn} as activation_fn in BasicTransformerBlock.")
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
@@ -353,15 +357,22 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
use_bias = True
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
print("Using GEGLU as the activation function in the FFN.")
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
elif activation_fn == "swiglu":
print("Using SwiGLU as the activation function in the FFN.")
inner_dim = int(2 * dim_out / 3)
act_fn = SwiGLU(dim, inner_dim)
use_bias = False
self.net = nn.ModuleList([])
# project in
@@ -369,7 +380,7 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
self.net.append(nn.Linear(inner_dim, dim_out, bias=use_bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
@@ -442,6 +453,22 @@ class ApproximateGELU(nn.Module):
return x * torch.sigmoid(1.702 * x)
class SwiGLU(nn.Module):
"""
A GEGLU-like gated linear unit that uses SiLU instead of GELU on the gate. SwiGLU is used in works like PaLM.
Reference: https://arxiv.org/abs/2002.05202
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.w1 = nn.Linear(dim_in, dim_out, bias=False)
self.w3 = nn.Linear(dim_in, dim_out, bias=False)
def forward(self, x):
return F.silu(self.w1(x)) * self.w3(x)
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
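
For reference, a minimal standalone sketch of what the new "swiglu" branch in FeedForward builds, based only on the diff above (the 2/3 width reduction and the bias-free projections come straight from it; the 320-dim sizes are just illustrative):

    import torch
    import torch.nn.functional as F
    from torch import nn

    class SwiGLU(nn.Module):
        """Gating layer from the diff: SiLU(w1(x)) * w3(x), both projections without bias."""
        def __init__(self, dim_in: int, dim_out: int):
            super().__init__()
            self.w1 = nn.Linear(dim_in, dim_out, bias=False)
            self.w3 = nn.Linear(dim_in, dim_out, bias=False)

        def forward(self, x):
            return F.silu(self.w1(x)) * self.w3(x)

    # Mirror of the FeedForward construction when activation_fn == "swiglu":
    # inner_dim is overridden to 2/3 of dim_out and the output projection drops its bias.
    dim = dim_out = 320
    inner_dim = int(2 * dim_out / 3)  # 213
    ffn = nn.Sequential(SwiGLU(dim, inner_dim), nn.Dropout(0.0), nn.Linear(inner_dim, dim_out, bias=False))

    x = torch.randn(2, 77, dim)
    print(ffn(x).shape)  # torch.Size([2, 77, 320])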

View File

@@ -104,6 +104,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
print(f"Using {activation_fn} as activation_fn in Transformer2DModel.")
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` and quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
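
A quick way to exercise the continuous-image path described above together with the new activation; a hedged sketch, assuming the standard Transformer2DModel constructor arguments, with "swiglu" only understood once the commits in this compare are applied:

    import torch
    from diffusers.models import Transformer2DModel  # import path assumed for this branch

    model = Transformer2DModel(
        num_attention_heads=8,
        attention_head_dim=40,
        in_channels=320,         # continuous input of shape (batch_size, num_channels, width, height)
        activation_fn="swiglu",  # prints "Using swiglu as activation_fn in Transformer2DModel."
    )
    sample = torch.randn(1, 320, 16, 16)
    out = model(sample).sample   # no encoder_hidden_states here, so self-attention only
    print(out.shape)             # expected: torch.Size([1, 320, 16, 16])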

View File

@@ -42,6 +42,7 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
ff_activation_fn="geglu",
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -103,6 +104,7 @@ def get_down_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
ff_activation_fn=ff_activation_fn,
)
elif down_block_type == "SimpleCrossAttnDownBlock2D":
if cross_attention_dim is None:
@@ -214,6 +216,7 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
ff_activation_fn="geglu",
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -262,6 +265,7 @@ def get_up_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
ff_activation_fn=ff_activation_fn,
)
elif up_block_type == "SimpleCrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -465,8 +469,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
ff_activation_fn="geglu",
):
super().__init__()
print(f"Using {ff_activation_fn} as ff_activation_fn in UNetMidBlock2DCrossAttn")
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
@@ -501,6 +507,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
activation_fn=ff_activation_fn,
)
)
else:
@@ -512,6 +519,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
activation_fn=ff_activation_fn,
)
)
resnets.append(
@@ -742,6 +750,7 @@ class CrossAttnDownBlock2D(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
ff_activation_fn="geglu",
):
super().__init__()
resnets = []
@@ -778,6 +787,7 @@ class CrossAttnDownBlock2D(nn.Module):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
activation_fn=ff_activation_fn,
)
)
else:
@@ -789,6 +799,7 @@ class CrossAttnDownBlock2D(nn.Module):
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
activation_fn=ff_activation_fn,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -1712,6 +1723,7 @@ class CrossAttnUpBlock2D(nn.Module):
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
ff_activation_fn="geglu",
):
super().__init__()
resnets = []
@@ -1750,6 +1762,7 @@ class CrossAttnUpBlock2D(nn.Module):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
activation_fn=ff_activation_fn,
)
)
else:
@@ -1761,6 +1774,7 @@ class CrossAttnUpBlock2D(nn.Module):
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
activation_fn=ff_activation_fn,
)
)
self.attentions = nn.ModuleList(attentions)
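
The same keyword is threaded through the down, mid, and up blocks above and forwarded to Transformer2DModel as activation_fn. A construction-only check on this branch (the arguments other than ff_activation_fn are assumed to be the existing CrossAttnDownBlock2D parameters, so treat the exact names as an assumption):

    from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D

    block = CrossAttnDownBlock2D(
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        attn_num_head_channels=8,
        cross_attention_dim=768,
        ff_activation_fn="swiglu",  # new keyword added in this diff
    )
    # With the defaults here (one layer, one transformer block) this should print 1:
    # every BasicTransformerBlock's FeedForward now holds exactly one SwiGLU module.
    print(sum(type(m).__name__ == "SwiGLU" for m in block.modules()))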

View File

@@ -148,8 +148,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
ff_activation_fn="geglu",
):
super().__init__()
print(f"Using {ff_activation_fn} as ff_activation_fn in UNet2DConditionModel")
self.sample_size = sample_size
@@ -264,6 +266,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
ff_activation_fn=ff_activation_fn,
)
self.down_blocks.append(down_block)
@@ -282,6 +285,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
ff_activation_fn=ff_activation_fn,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
@@ -341,6 +345,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
ff_activation_fn=ff_activation_fn,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
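
End to end, the new config argument reaches every FeedForward in the UNet. A small smoke test on this branch; the tiny block/channel configuration below is an illustrative choice (similar to the reduced configs used in the diffusers test suite), not anything mandated by the diff:

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        layers_per_block=1,
        cross_attention_dim=32,
        attention_head_dim=8,
        ff_activation_fn="swiglu",  # new config argument; prints "Using swiglu as ff_activation_fn in UNet2DConditionModel"
    )

    sample = torch.randn(1, 4, 32, 32)              # default in_channels=4
    encoder_hidden_states = torch.randn(1, 77, 32)  # matches cross_attention_dim
    out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
    print(out.shape)                                # expected: torch.Size([1, 4, 32, 32])

During construction, the print statements added in the earlier commits (UNet2DConditionModel, UNetMidBlock2DCrossAttn, Transformer2DModel, BasicTransformerBlock, and the FeedForward SwiGLU message) should all report the swiglu activation, which is exactly what they were added to verify.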