Compare commits

...

17 Commits

Author SHA1 Message Date
Patrick von Platen 0169697ba2 Merge branch 'main' into 1d_blocks 2022-12-01 17:53:44 +01:00
Nathan Lambert ea7ac3bd89 reset dance diff test 2022-11-29 16:25:10 -08:00
Nathan Lambert d70f8dc057 revert more changes 2022-11-29 15:49:10 -08:00
Nathan Lambert dc2c3992d1 fix weird layer counting 2022-11-29 15:40:54 -08:00
Nathan Lambert 198fd951ec fix 2022-11-29 15:22:20 -08:00
Nathan Lambert 0dff586964 revert dance diff test 2022-11-29 15:21:38 -08:00
Nathan Lambert cd1225c8ae revert some breaking changes 2022-11-29 15:20:29 -08:00
Nathan Lambert 511ebe5a84 Merge remote-tracking branch 'origin' into 1d_blocks 2022-11-29 15:12:56 -08:00
Nathan Lambert 07771ebea2 change numbers for dummy model dance diffusion 2022-11-17 10:45:11 -08:00
Nathan Lambert 64c5688284 revert change to make less breaking 2022-11-14 14:12:47 -08:00
Nathan Lambert 8b7f2e301d Merge branch 'main' into 1d_blocks 2022-11-14 14:05:48 -08:00
Patrick von Platen 81a666d52c Merge branch 'main' into 1d_blocks 2022-10-31 19:19:03 +01:00
Nathan Lambert b98c62eb61 quality 2022-10-26 16:05:51 -07:00
Nathan Lambert 741122e722 style 2022-10-26 16:00:39 -07:00
Nathan Lambert 5df4c8b81f add layers args 2022-10-26 15:56:04 -07:00
Nathan Lambert 084b51ac30 adding num_layers arg 2022-10-26 15:44:05 -07:00
Nathan Lambert 1c693f9b68 type checking 2022-10-26 15:31:30 -07:00

@@ -288,8 +288,16 @@ _kernels = {
}
class Downsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
class KernelDownsample1D(nn.Module):
"""
A static downsample module that is not updated by the optimizer.
Parameters:
kernel (`str`): `linear`, `cubic`, or `lanczos3` for different static kernels used in convolution.
pad_mode (`str`): padding mode passed to `torch.nn.functional.pad`, defaults to `reflect`.
"""
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
@@ -304,8 +312,16 @@ class Downsample1d(nn.Module):
return F.conv1d(hidden_states, weight, stride=2)
class Upsample1d(nn.Module):
def __init__(self, kernel="linear", pad_mode="reflect"):
class KernelUpsample1D(nn.Module):
"""
A static upsample module that is not updated by the optimizer.
Parameters:
kernel (`str`): `linear`, `cubic`, or `lanczos3` for different static kernels used in convolution.
pad_mode (`str`): padding mode passed to `torch.nn.functional.pad`, defaults to `reflect`.
"""
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
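
For orientation, a minimal self-contained sketch of the idea behind these renamed modules, assuming the omitted forward pass matches the context lines above: the smoothing kernel is registered as a buffer, so it moves with the module but is never updated by the optimizer, and it is applied per channel by a strided 1D convolution. The class name StaticKernelDownsample1D and the kernel values below are illustrative, not taken from the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

class StaticKernelDownsample1D(nn.Module):
    """Halves the sequence length with a fixed, non-trainable smoothing kernel."""

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        # Illustrative "linear" kernel; the real modules look theirs up in _kernels.
        kernel_1d = torch.tensor([1 / 8, 3 / 8, 3 / 8, 1 / 8])
        self.pad = kernel_1d.shape[0] // 2 - 1
        # A buffer travels with the module (device moves, state_dict) but is not an
        # nn.Parameter, so the optimizer never touches it.
        self.register_buffer("kernel", kernel_1d)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, channels, length)
        hidden_states = F.pad(hidden_states, (self.pad,) * 2, mode=self.pad_mode)
        channels = hidden_states.shape[1]
        # A diagonal weight applies the same 1D kernel to every channel independently.
        weight = hidden_states.new_zeros([channels, channels, self.kernel.shape[0]])
        idx = torch.arange(channels, device=hidden_states.device)
        weight[idx, idx] = self.kernel.to(weight)
        # stride=2 is what halves the length.
        return F.conv1d(hidden_states, weight, stride=2)

print(StaticKernelDownsample1D()(torch.randn(1, 4, 16)).shape)  # torch.Size([1, 4, 8])

KernelUpsample1D presumably mirrors this with a transposed convolution; the `* 2` on its kernel_1d line above compensates for spreading the signal over twice as many output samples.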
@@ -321,7 +337,7 @@ class Upsample1d(nn.Module):
class SelfAttention1d(nn.Module):
def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
@@ -379,7 +395,7 @@ class SelfAttention1d(nn.Module):
class ResConvBlock(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
@@ -413,13 +429,12 @@ class ResConvBlock(nn.Module):
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None):
def __init__(self, mid_channels: int, in_channels: int, out_channels: int = None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = Downsample1d("cubic")
self.down = KernelDownsample1D("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
@@ -436,7 +451,7 @@ class UNetMidBlock1D(nn.Module):
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.up = Upsample1d(kernel="cubic")
self.up = KernelUpsample1D(kernel="cubic")
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -453,21 +468,26 @@ class UNetMidBlock1D(nn.Module):
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
def __init__(self, out_channels: int, in_channels: int, num_layers: int = 3, mid_channels: int = None):
super().__init__()
if num_layers < 1:
raise ValueError("AttnDownBlock1D requires added num_layers >= 1")
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.down = KernelDownsample1D("cubic")
resnets = []
attentions = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
attentions.append(SelfAttention1d(out_channels, out_channels // 32))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -483,16 +503,22 @@ class AttnDownBlock1D(nn.Module):
class DownBlock1D(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("DownBlock1D requires added num_layers >= 1")
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.down = KernelDownsample1D("cubic")
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
self.resnets = nn.ModuleList(resnets)
@@ -506,15 +532,21 @@ class DownBlock1D(nn.Module):
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels, in_channels, mid_channels=None):
def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("DownBlock1DNoSkip requires added num_layers >= 1")
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
self.resnets = nn.ModuleList(resnets)
@@ -527,24 +559,28 @@ class DownBlock1DNoSkip(nn.Module):
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("AttnUpBlock1D requires added num_layers >= 1")
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
resnets = []
attentions = []
for i in range(num_layers):
in_channels = 2 * in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
attentions.append(SelfAttention1d(out_channels, out_channels // 32))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
self.up = KernelUpsample1D(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
@@ -560,18 +596,24 @@ class AttnUpBlock1D(nn.Module):
class UpBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("UpBlock1D requires added num_layers >= 1")
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
resnets = []
for i in range(num_layers):
in_channels = 2 * in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
self.up = KernelUpsample1D(kernel="cubic")
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
@@ -586,15 +628,21 @@ class UpBlock1D(nn.Module):
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("UpBlock1D requires added num_layers >= 1")
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
]
resnets = []
for i in range(num_layers):
in_channels = 2 * in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))
self.resnets = nn.ModuleList(resnets)
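
To make the new num_layers argument concrete, here is a small standalone sketch of the channel bookkeeping the refactored constructors now share. The function name and the example channel sizes are illustrative, not from the PR; with the default num_layers=3 the loop reproduces the previously hard-coded three-block stack.

def resnet_channel_plan(in_channels: int, mid_channels: int, out_channels: int, num_layers: int = 3):
    """Channel triples (in, mid, out) for each ResConvBlock built by the new loop."""
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1")
    plan = []
    for i in range(num_layers):
        layer_in = in_channels if i == 0 else mid_channels
        if i < num_layers - 1:
            plan.append((layer_in, mid_channels, mid_channels))
        else:
            # The final block always maps into out_channels
            # (UpBlock1DNoSkip additionally passes is_last=True here).
            plan.append((mid_channels, mid_channels, out_channels))
    return plan

print(resnet_channel_plan(16, 32, 64))
# [(16, 32, 32), (32, 32, 32), (32, 32, 64)] -- same shapes as the old hard-coded list

The up blocks follow the same pattern but start from 2 * in_channels on the first layer, to account for the skip connection from the matching down block being concatenated onto the input.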