Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-13 16:04:41 +08:00

Compare commits: edit-pypi-...1d_blocks (17 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 0169697ba2 | |
| | ea7ac3bd89 | |
| | d70f8dc057 | |
| | dc2c3992d1 | |
| | 198fd951ec | |
| | 0dff586964 | |
| | cd1225c8ae | |
| | 511ebe5a84 | |
| | 07771ebea2 | |
| | 64c5688284 | |
| | 8b7f2e301d | |
| | 81a666d52c | |
| | b98c62eb61 | |
| | 741122e722 | |
| | 5df4c8b81f | |
| | 084b51ac30 | |
| | 1c693f9b68 | |
@@ -288,8 +288,16 @@ _kernels = {
 }
 
 
-class Downsample1d(nn.Module):
-    def __init__(self, kernel="linear", pad_mode="reflect"):
+class KernelDownsample1D(nn.Module):
+    """
+    A static downsample module that is not updated by the optimizer.
+
+    Parameters:
+        kernel (`str`): `linear`, `cubic`, or `lanczos3` for different static kernels used in convolution.
+        pad_mode (`str`): defaults to `reflect`, use with torch.nn.functional.pad.
+    """
+
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
         super().__init__()
         self.pad_mode = pad_mode
         kernel_1d = torch.tensor(_kernels[kernel])
@@ -304,8 +312,16 @@ class Downsample1d(nn.Module):
         return F.conv1d(hidden_states, weight, stride=2)
 
 
-class Upsample1d(nn.Module):
-    def __init__(self, kernel="linear", pad_mode="reflect"):
+class KernelUpsample1D(nn.Module):
+    """
+    A static upsample module that is not updated by the optimizer.
+
+    Parameters:
+        kernel (`str`): `linear`, `cubic`, or `lanczos3` for different static kernels used in convolution.
+        pad_mode (`str`): defaults to `reflect`, use with torch.nn.functional.pad.
+    """
+
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
         super().__init__()
         self.pad_mode = pad_mode
         kernel_1d = torch.tensor(_kernels[kernel]) * 2
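For orientation, a minimal usage sketch of the two renamed resampling modules. It assumes this branch of the package is installed and that the class bodies outside these hunks are unchanged (reflect padding followed by a strided `F.conv1d`, or a transposed convolution on the upsampling side, with the selected static kernel); the import path and tensor layout `(batch, channels, length)` are assumptions inferred from the surrounding code, not shown in the diff.

```python
import torch

# Assumed import path; KernelDownsample1D / KernelUpsample1D exist only on this branch.
from diffusers.models.unet_1d_blocks import KernelDownsample1D, KernelUpsample1D

down = KernelDownsample1D(kernel="cubic", pad_mode="reflect")
up = KernelUpsample1D(kernel="cubic", pad_mode="reflect")

x = torch.randn(2, 32, 64)   # (batch, channels, length)
h = down(x)                  # stride-2 conv with a fixed cubic kernel, length roughly halved
y = up(h)                    # transposed conv with the same (scaled) kernel, length roughly doubled

print(x.shape, h.shape, y.shape)
```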
@@ -321,7 +337,7 @@ class Upsample1d(nn.Module):
 
 
 class SelfAttention1d(nn.Module):
-    def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
+    def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
         super().__init__()
         self.channels = in_channels
         self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
@@ -379,7 +395,7 @@ class SelfAttention1d(nn.Module):
 
 
 class ResConvBlock(nn.Module):
-    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
+    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
         super().__init__()
         self.is_last = is_last
         self.has_conv_skip = in_channels != out_channels
@@ -413,13 +429,12 @@ class ResConvBlock(nn.Module):
 
 
 class UNetMidBlock1D(nn.Module):
-    def __init__(self, mid_channels, in_channels, out_channels=None):
+    def __init__(self, mid_channels: int, in_channels: int, out_channels: int = None):
         super().__init__()
 
         out_channels = in_channels if out_channels is None else out_channels
 
-        # there is always at least one resnet
-        self.down = Downsample1d("cubic")
+        self.down = KernelDownsample1D("cubic")
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
             ResConvBlock(mid_channels, mid_channels, mid_channels),
@@ -436,7 +451,7 @@ class UNetMidBlock1D(nn.Module):
             SelfAttention1d(mid_channels, mid_channels // 32),
             SelfAttention1d(out_channels, out_channels // 32),
         ]
-        self.up = Upsample1d(kernel="cubic")
+        self.up = KernelUpsample1D(kernel="cubic")
 
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
@@ -453,21 +468,26 @@ class UNetMidBlock1D(nn.Module):
 
 
 class AttnDownBlock1D(nn.Module):
-    def __init__(self, out_channels, in_channels, mid_channels=None):
+    def __init__(self, out_channels: int, in_channels: int, num_layers: int = 3, mid_channels: int = None):
         super().__init__()
+
+        if num_layers < 1:
+            raise ValueError("AttnDownBlock1D requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
-        self.down = Downsample1d("cubic")
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
-        ]
+        self.down = KernelDownsample1D("cubic")
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+                attentions.append(SelfAttention1d(out_channels, out_channels // 32))
 
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
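As a quick sanity check that the loop form preserves the old defaults, the sketch below re-implements just the constructor loop from the hunk above, with plain tuples standing in for `ResConvBlock` and `SelfAttention1d` so it runs on its own; with the default `num_layers=3` it reproduces the three entries of the removed hard-coded lists.

```python
# Stand-alone sketch of the constructor loop shown in the diff; plain tuples
# replace ResConvBlock / SelfAttention1d so the check needs nothing else.
def attn_down_pattern(out_channels, in_channels, num_layers=3, mid_channels=None):
    mid_channels = out_channels if mid_channels is None else mid_channels
    resnets, attentions = [], []
    for i in range(num_layers):
        in_channels = in_channels if i == 0 else mid_channels
        if i < (num_layers - 1):
            resnets.append(("ResConvBlock", in_channels, mid_channels, mid_channels))
            attentions.append(("SelfAttention1d", mid_channels, mid_channels // 32))
        else:
            resnets.append(("ResConvBlock", mid_channels, mid_channels, out_channels))
            attentions.append(("SelfAttention1d", out_channels, out_channels // 32))
    return resnets, attentions

# Default num_layers=3 yields (in, mid, mid), (mid, mid, mid), (mid, mid, out),
# the same channel pattern as the old fixed lists.
print(attn_down_pattern(out_channels=128, in_channels=32))
```

For `num_layers > 3`, the extra layers are all `(mid, mid, mid)` resnets with `(mid, mid // 32)` attention, inserted before the final layer that maps to `out_channels`.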
@@ -483,16 +503,22 @@ class AttnDownBlock1D(nn.Module):
 
 
 class DownBlock1D(nn.Module):
-    def __init__(self, out_channels, in_channels, mid_channels=None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("DownBlock1D requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
-        self.down = Downsample1d("cubic")
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
+        self.down = KernelDownsample1D("cubic")
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
 
         self.resnets = nn.ModuleList(resnets)
 
@@ -506,15 +532,21 @@ class DownBlock1D(nn.Module):
 
 
 class DownBlock1DNoSkip(nn.Module):
-    def __init__(self, out_channels, in_channels, mid_channels=None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("DownBlock1DNoSkip requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
 
         self.resnets = nn.ModuleList(resnets)
 
@@ -527,24 +559,28 @@ class DownBlock1DNoSkip(nn.Module):
 
 
 class AttnUpBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("AttnUpBlock1D requires added num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
-        ]
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+                attentions.append(SelfAttention1d(out_channels, out_channels // 32))
 
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
-        self.up = Upsample1d(kernel="cubic")
+        self.up = KernelUpsample1D(kernel="cubic")
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         res_hidden_states = res_hidden_states_tuple[-1]
@@ -560,18 +596,24 @@ class AttnUpBlock1D(nn.Module):
 
 
 class UpBlock1D(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("UpBlock1D requires added num_layers >= 1")
+
         mid_channels = in_channels if mid_channels is None else mid_channels
 
-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
-        ]
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
 
         self.resnets = nn.ModuleList(resnets)
-        self.up = Upsample1d(kernel="cubic")
+        self.up = KernelUpsample1D(kernel="cubic")
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         res_hidden_states = res_hidden_states_tuple[-1]
@@ -586,15 +628,21 @@ class UpBlock1D(nn.Module):
 
 
 class UpBlock1DNoSkip(nn.Module):
-    def __init__(self, in_channels, out_channels, mid_channels=None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("UpBlock1D requires added num_layers >= 1")
+
         mid_channels = in_channels if mid_channels is None else mid_channels
 
-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
-        ]
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))
 
         self.resnets = nn.ModuleList(resnets)
 
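Finally, a hedged construction example for the extended signatures (the channel sizes below are made up, the import path is assumed, and the classes exist only on this branch): the length of `self.resnets` now tracks `num_layers`, with the default of 3 matching the previous fixed-size blocks, and the up blocks still expect the skip connection concatenated onto the first layer's input, which is why the loop uses `2 * in_channels` at `i == 0`.

```python
# Assumed import path; these signatures exist only on this branch.
from diffusers.models.unet_1d_blocks import DownBlock1D, UpBlock1D

down_default = DownBlock1D(out_channels=64, in_channels=32)                # num_layers defaults to 3
down_deeper = DownBlock1D(out_channels=64, in_channels=32, num_layers=5)   # two extra mid->mid resnets
up_deeper = UpBlock1D(in_channels=64, out_channels=32, num_layers=5)

print(len(down_default.resnets), len(down_deeper.resnets), len(up_deeper.resnets))  # 3 5 5
```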