Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-18 18:34:37 +08:00)

Compare commits: ci-style-b...temp/swigl (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a38a74f3ab |  |
|  | 19b181e628 |  |
|  | 3ed26304e1 |  |
|  | 505777dd98 |  |
|  | 0e71a296d2 |  |
|  | 93aaea1da7 |  |
```diff
@@ -18,6 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import logging
 from ..utils.import_utils import is_xformers_available
 from .cross_attention import CrossAttention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -29,6 +30,8 @@ if is_xformers_available():
 else:
     xformers = None
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
 
 class AttentionBlock(nn.Module):
     """
@@ -208,6 +211,7 @@ class BasicTransformerBlock(nn.Module):
         final_dropout: bool = False,
     ):
         super().__init__()
+        print(f"Using {activation_fn} as activation_fn in BasicTransformerBlock.")
         self.only_cross_attention = only_cross_attention
 
         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
@@ -353,15 +357,22 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
+        use_bias = True
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)
         if activation_fn == "gelu-approximate":
             act_fn = GELU(dim, inner_dim, approximate="tanh")
         elif activation_fn == "geglu":
+            print("Using GEGLU as the activation function in the FFN.")
             act_fn = GEGLU(dim, inner_dim)
         elif activation_fn == "geglu-approximate":
             act_fn = ApproximateGELU(dim, inner_dim)
+        elif activation_fn == "swiglu":
+            print("Using SwiGLU as the activation function in the FFN.")
+            inner_dim = int(2 * dim_out / 3)
+            act_fn = SwiGLU(dim, inner_dim)
+            use_bias = False
 
         self.net = nn.ModuleList([])
         # project in
@@ -369,7 +380,7 @@ class FeedForward(nn.Module):
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(nn.Linear(inner_dim, dim_out))
+        self.net.append(nn.Linear(inner_dim, dim_out, bias=use_bias))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
@@ -442,6 +453,22 @@ class ApproximateGELU(nn.Module):
         return x * torch.sigmoid(1.702 * x)
 
 
+class SwiGLU(nn.Module):
+    """
+    GEGLU-like that uses SiLU instead of GELU on the gates. SwiGLU is used in works like PaLM.
+
+    Reference: https://arxiv.org/abs/2002.05202
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.w1 = nn.Linear(dim_in, dim_out, bias=False)
+        self.w3 = nn.Linear(dim_in, dim_out, bias=False)
+
+    def forward(self, x):
+        return F.silu(self.w1(x)) * self.w3(x)
+
+
 class AdaLayerNorm(nn.Module):
     """
     Norm layer modified to incorporate timestep embeddings.
```
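Taken together, the FeedForward hunks do three things when `activation_fn == "swiglu"`: the hidden width is overridden to `int(2 * dim_out / 3)`, the gate uses SiLU via the new `SwiGLU` module, and both the gate projections and the output projection drop their bias. The 2/3 factor is the usual compensation for a gated variant carrying two input projections (Shazeer, 2020, https://arxiv.org/abs/2002.05202); note that here it is applied to `dim_out` rather than to `dim * mult`, so `mult` has no effect on the swiglu path. A minimal, self-contained sketch of the module stack these hunks produce (plain PyTorch; the diffusers plumbing around it is omitted):

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """Gated projection from the diff: SiLU(w1(x)) * w3(x), with both linears bias-free."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out, bias=False)
        self.w3 = nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x):
        return F.silu(self.w1(x)) * self.w3(x)


def feed_forward_swiglu(dim: int, dropout: float = 0.0) -> nn.Sequential:
    # Mirrors what FeedForward.__init__ builds for activation_fn="swiglu" on this branch:
    # inner_dim is overridden to int(2 * dim_out / 3) and the output projection loses its bias.
    dim_out = dim
    inner_dim = int(2 * dim_out / 3)
    return nn.Sequential(
        SwiGLU(dim, inner_dim),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim_out, bias=False),
    )


if __name__ == "__main__":
    ff = feed_forward_swiglu(dim=320)
    x = torch.randn(2, 77, 320)
    print(ff(x).shape)  # torch.Size([2, 77, 320])
```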
```diff
@@ -104,6 +104,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
+        print(f"Using {activation_fn} as activation_fn in Transformer2DModel.")
 
         # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
```
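Transformer2DModel already accepts `activation_fn` and forwards it to each BasicTransformerBlock, which hands it to FeedForward, so the only change in this hunk is the debug print. A quick check that the value actually reaches the FFN, assuming the temp/swigl branch is installed (the small shapes are arbitrary; upstream diffusers does not recognize "swiglu" and fails inside FeedForward):

```python
from diffusers.models import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=16,
    in_channels=32,          # keep divisible by the default norm_num_groups=32
    num_layers=1,
    activation_fn="swiglu",  # only understood by FeedForward on this branch
)
print(type(model.transformer_blocks[0].ff.net[0]).__name__)  # expected: "SwiGLU"
```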
```diff
@@ -42,6 +42,7 @@ def get_down_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    ff_activation_fn="geglu",
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -103,6 +104,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            ff_activation_fn=ff_activation_fn,
         )
     elif down_block_type == "SimpleCrossAttnDownBlock2D":
         if cross_attention_dim is None:
@@ -214,6 +216,7 @@ def get_up_block(
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
+    ff_activation_fn="geglu",
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -262,6 +265,7 @@ def get_up_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            ff_activation_fn=ff_activation_fn,
         )
     elif up_block_type == "SimpleCrossAttnUpBlock2D":
         if cross_attention_dim is None:
@@ -465,8 +469,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
+        print(f"Using {ff_activation_fn} as ff_activation_fn in UNetMidBlock2DCrossAttn")
 
         self.has_cross_attention = True
         self.attn_num_head_channels = attn_num_head_channels
@@ -501,6 +507,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             else:
@@ -512,6 +519,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             resnets.append(
@@ -742,6 +750,7 @@ class CrossAttnDownBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
         resnets = []
@@ -778,6 +787,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             else:
@@ -789,6 +799,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        activation_fn=ff_activation_fn,
                     )
                 )
         self.attentions = nn.ModuleList(attentions)
@@ -1712,6 +1723,7 @@ class CrossAttnUpBlock2D(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
         resnets = []
@@ -1750,6 +1762,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        activation_fn=ff_activation_fn,
                     )
                 )
             else:
@@ -1761,6 +1774,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        activation_fn=ff_activation_fn,
                     )
                 )
         self.attentions = nn.ModuleList(attentions)
```
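Every hunk in this file follows the same pattern: the `get_down_block`/`get_up_block` helpers and the cross-attention blocks grow a `ff_activation_fn="geglu"` keyword and forward it as `activation_fn` to the inner Transformer2DModel (or DualTransformer2DModel). A construction sketch for the mid block, assuming the branch is installed and that `UNetMidBlock2DCrossAttn` is importable from `diffusers.models.unet_2d_blocks`; the channel sizes are arbitrary but chosen to satisfy the default `resnet_groups=32`:

```python
from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn

mid = UNetMidBlock2DCrossAttn(
    in_channels=64,             # divisible by the default resnet_groups=32
    temb_channels=128,
    cross_attention_dim=96,
    ff_activation_fn="swiglu",  # new keyword on this branch, threaded into every FeedForward
)
print(type(mid.attentions[0].transformer_blocks[0].ff.net[0]).__name__)  # expected: "SwiGLU"
```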
```diff
@@ -148,8 +148,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        ff_activation_fn="geglu",
     ):
         super().__init__()
+        print(f"Using {ff_activation_fn} as ff_activation_fn in UNet2DConditionModel")
 
         self.sample_size = sample_size
 
@@ -264,6 +266,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                ff_activation_fn=ff_activation_fn,
             )
             self.down_blocks.append(down_block)
 
@@ -282,6 +285,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 upcast_attention=upcast_attention,
+                ff_activation_fn=ff_activation_fn,
             )
         elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
             self.mid_block = UNetMidBlock2DSimpleCrossAttn(
@@ -341,6 +345,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                ff_activation_fn=ff_activation_fn,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
```
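At the top level, UNet2DConditionModel gains the same `ff_activation_fn="geglu"` keyword and threads it through `get_down_block`, the mid block, and `get_up_block`, so a single constructor argument switches every transformer FFN in the UNet. A usage sketch under the assumption that the temp/swigl branch is installed; the reduced widths only keep the example small, and the keyword is not accepted by upstream diffusers:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(64, 128),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    attention_head_dim=8,
    cross_attention_dim=96,
    ff_activation_fn="swiglu",  # only valid on this branch
)

# The first down block is a CrossAttnDownBlock2D, so its FFN should now start with SwiGLU.
first_ff = unet.down_blocks[0].attentions[0].transformer_blocks[0].ff
print(type(first_ff.net[0]).__name__)  # expected: "SwiGLU"
```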