Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: add-quanto...1d_blocks (17 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 0169697ba2 |  |
|  | ea7ac3bd89 |  |
|  | d70f8dc057 |  |
|  | dc2c3992d1 |  |
|  | 198fd951ec |  |
|  | 0dff586964 |  |
|  | cd1225c8ae |  |
|  | 511ebe5a84 |  |
|  | 07771ebea2 |  |
|  | 64c5688284 |  |
|  | 8b7f2e301d |  |
|  | 81a666d52c |  |
|  | b98c62eb61 |  |
|  | 741122e722 |  |
|  | 5df4c8b81f |  |
|  | 084b51ac30 |  |
|  | 1c693f9b68 |  |
```diff
@@ -288,8 +288,16 @@ _kernels = {
 }


-class Downsample1d(nn.Module):
-    def __init__(self, kernel="linear", pad_mode="reflect"):
+class KernelDownsample1D(nn.Module):
+    """
+    A static downsample module that is not updated by the optimizer.
+
+    Parameters:
+        kernel (`str`): `linear`, `cubic`, or `lanczos3` for different static kernels used in convolution.
+        pad_mode (`str`): defaults to `reflect`, use with torch.nn.functional.pad.
+    """
+
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
         super().__init__()
         self.pad_mode = pad_mode
         kernel_1d = torch.tensor(_kernels[kernel])
@@ -304,8 +312,16 @@ class Downsample1d(nn.Module):
         return F.conv1d(hidden_states, weight, stride=2)


-class Upsample1d(nn.Module):
-    def __init__(self, kernel="linear", pad_mode="reflect"):
+class KernelUpsample1D(nn.Module):
+    """
+    A static upsample module that is not updated by the optimizer.
+
+    Parameters:
+        kernel (`str`): `linear`, `cubic`, or `lanczos3` for different static kernels used in convolution.
+        pad_mode (`str`): defaults to `reflect`, use with torch.nn.functional.pad.
+    """
+
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
         super().__init__()
         self.pad_mode = pad_mode
         kernel_1d = torch.tensor(_kernels[kernel]) * 2
```
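The new docstrings describe fixed, non-trainable resampling kernels applied with a strided convolution, as in the `F.conv1d(hidden_states, weight, stride=2)` line shown in the hunk. Below is a minimal standalone sketch of that idea, not the library's implementation: the kernel values, the depthwise weight construction, and the padding arithmetic are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative fixed smoothing kernel (assumption; the real _kernels dict in
# diffusers defines "linear", "cubic", and "lanczos3" entries).
kernel_1d = torch.tensor([0.125, 0.375, 0.375, 0.125])


def static_downsample_1d(hidden_states, kernel_1d, pad_mode="reflect"):
    """Halve the sequence length with a fixed, non-trainable kernel."""
    channels = hidden_states.shape[1]
    pad = kernel_1d.shape[0] // 2 - 1
    hidden_states = F.pad(hidden_states, (pad, pad), mode=pad_mode)
    # One copy of the same kernel per channel (depthwise convolution).
    weight = kernel_1d.view(1, 1, -1).repeat(channels, 1, 1)
    return F.conv1d(hidden_states, weight, stride=2, groups=channels)


def static_upsample_1d(hidden_states, kernel_1d, pad_mode="reflect"):
    """Double the sequence length with a fixed kernel and a transposed convolution."""
    channels = hidden_states.shape[1]
    kernel_1d = kernel_1d * 2
    pad = kernel_1d.shape[0] // 2 - 1
    hidden_states = F.pad(hidden_states, ((pad + 1) // 2,) * 2, mode=pad_mode)
    weight = kernel_1d.view(1, 1, -1).repeat(channels, 1, 1)
    return F.conv_transpose1d(hidden_states, weight, stride=2, padding=pad * 2 + 1, groups=channels)


x = torch.randn(1, 8, 32)                 # (batch, channels, length)
y = static_downsample_1d(x, kernel_1d)
z = static_upsample_1d(y, kernel_1d)
print(x.shape, "->", y.shape, "->", z.shape)  # length 32 -> 16 -> 32
```

Because the kernel is built from a plain tensor rather than an `nn.Parameter`, the optimizer never updates it, which is what the added docstring means by "static".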
```diff
@@ -321,7 +337,7 @@ class Upsample1d(nn.Module):


 class SelfAttention1d(nn.Module):
-    def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
+    def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
         super().__init__()
         self.channels = in_channels
         self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
@@ -379,7 +395,7 @@ class SelfAttention1d(nn.Module):


 class ResConvBlock(nn.Module):
-    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
+    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
         super().__init__()
         self.is_last = is_last
         self.has_conv_skip = in_channels != out_channels
```
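The `self.has_conv_skip = in_channels != out_channels` context line above is the usual residual-block trick for mismatched widths: an identity shortcut only works when input and output channel counts agree, otherwise the shortcut needs a projection. A minimal standalone sketch of that pattern follows (a simplified block of my own, not the diffusers `ResConvBlock` itself; layer choices are illustrative).

```python
import torch
import torch.nn as nn


class TinyResConv1D(nn.Module):
    """Simplified residual 1-D conv block showing the conditional skip projection."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Project the shortcut only when the channel count changes.
        self.has_conv_skip = in_channels != out_channels
        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=5, padding=2)
        self.act = nn.GELU()

    def forward(self, x):
        residual = self.conv_skip(x) if self.has_conv_skip else x
        return self.act(self.conv(x)) + residual


x = torch.randn(2, 32, 64)
print(TinyResConv1D(32, 32)(x).shape)  # identity shortcut
print(TinyResConv1D(32, 64)(x).shape)  # 1x1 projection on the shortcut
```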
```diff
@@ -413,13 +429,12 @@ class ResConvBlock(nn.Module):


 class UNetMidBlock1D(nn.Module):
-    def __init__(self, mid_channels, in_channels, out_channels=None):
+    def __init__(self, mid_channels: int, in_channels: int, out_channels: int = None):
         super().__init__()

         out_channels = in_channels if out_channels is None else out_channels

-        # there is always at least one resnet
-        self.down = Downsample1d("cubic")
+        self.down = KernelDownsample1D("cubic")
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
             ResConvBlock(mid_channels, mid_channels, mid_channels),
@@ -436,7 +451,7 @@ class UNetMidBlock1D(nn.Module):
             SelfAttention1d(mid_channels, mid_channels // 32),
             SelfAttention1d(out_channels, out_channels // 32),
         ]
-        self.up = Upsample1d(kernel="cubic")
+        self.up = KernelUpsample1D(kernel="cubic")

         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
```
```diff
@@ -453,21 +468,26 @@ class UNetMidBlock1D(nn.Module):


 class AttnDownBlock1D(nn.Module):
-    def __init__(self, out_channels, in_channels, mid_channels=None):
+    def __init__(self, out_channels: int, in_channels: int, num_layers: int = 3, mid_channels: int = None):
         super().__init__()
+
+        if num_layers < 1:
+            raise ValueError("AttnDownBlock1D requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

-        self.down = Downsample1d("cubic")
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
-        ]
+        self.down = KernelDownsample1D("cubic")
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+                attentions.append(SelfAttention1d(out_channels, out_channels // 32))

         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
```
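This hunk replaces the hard-coded three-entry `resnets`/`attentions` lists with a loop driven by the new `num_layers` argument (default 3), guarded by a `ValueError` for `num_layers < 1`. The small sketch below reproduces only the channel bookkeeping of that loop to show that the default matches the old static layout; `channel_plan` is a helper name of my own, not part of diffusers.

```python
def channel_plan(in_channels: int, mid_channels: int, out_channels: int, num_layers: int = 3):
    """Reproduce the (in, mid, out) triples the loop in the diff feeds to ResConvBlock."""
    plan = []
    for i in range(num_layers):
        block_in = in_channels if i == 0 else mid_channels
        if i < (num_layers - 1):
            plan.append((block_in, mid_channels, mid_channels))
        else:
            plan.append((mid_channels, mid_channels, out_channels))
    return plan


# With the default num_layers=3 this matches the old hard-coded lists:
# [(32, 64, 64), (64, 64, 64), (64, 64, 128)]
print(channel_plan(in_channels=32, mid_channels=64, out_channels=128))
```

The attention layers follow the same pattern, with `channels // 32` heads per layer, i.e. roughly one head per 32 channels. The `DownBlock1D`, `DownBlock1DNoSkip`, and the up blocks in the hunks below apply the same loop-based construction.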
```diff
@@ -483,16 +503,22 @@ class AttnDownBlock1D(nn.Module):


 class DownBlock1D(nn.Module):
-    def __init__(self, out_channels, in_channels, mid_channels=None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("DownBlock1D requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

-        self.down = Downsample1d("cubic")
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
+        self.down = KernelDownsample1D("cubic")
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

         self.resnets = nn.ModuleList(resnets)

@@ -506,15 +532,21 @@ class DownBlock1D(nn.Module):


 class DownBlock1DNoSkip(nn.Module):
-    def __init__(self, out_channels, in_channels, mid_channels=None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("DownBlock1DNoSkip requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

         self.resnets = nn.ModuleList(resnets)

```
```diff
@@ -527,24 +559,28 @@ class DownBlock1DNoSkip(nn.Module):


 class AttnUpBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("AttnUpBlock1D requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
-        ]
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+                attentions.append(SelfAttention1d(out_channels, out_channels // 32))

         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
-        self.up = Upsample1d(kernel="cubic")
+        self.up = KernelUpsample1D(kernel="cubic")

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         res_hidden_states = res_hidden_states_tuple[-1]
@@ -560,18 +596,24 @@ class AttnUpBlock1D(nn.Module):


 class UpBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("UpBlock1D requires added num_layers >= 1")
+
         mid_channels = in_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

         self.resnets = nn.ModuleList(resnets)
-        self.up = Upsample1d(kernel="cubic")
+        self.up = KernelUpsample1D(kernel="cubic")

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         res_hidden_states = res_hidden_states_tuple[-1]
@@ -586,15 +628,21 @@ class UpBlock1D(nn.Module):


 class UpBlock1DNoSkip(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("UpBlock1D requires added num_layers >= 1")
+
         mid_channels = in_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
-        ]
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))

         self.resnets = nn.ModuleList(resnets)

```
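In the up blocks, the first resnet takes `2 * in_channels` while `forward` pops `res_hidden_states_tuple[-1]`. This matches the usual UNet skip-connection pattern, where the decoder features are joined with the stored encoder features along the channel dimension; the concatenation itself is not shown in these hunks, so treat that as an assumption. A minimal sketch of why the first layer sees doubled channels:

```python
import torch

in_channels = 32
hidden_states = torch.randn(1, in_channels, 64)       # decoder features
res_hidden_states = torch.randn(1, in_channels, 64)   # matching skip connection

# Channel-wise concatenation doubles the width, hence ResConvBlock(2 * in_channels, ...)
merged = torch.cat([hidden_states, res_hidden_states], dim=1)
print(merged.shape)  # torch.Size([1, 64, 64])
```

The `NoSkip` variants follow the same loop-based construction but, as their name suggests, omit the kernel-based down/upsampling module.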