fix rl model tests

This commit is contained in:
Patrick von Platen
2022-06-28 09:50:21 +00:00
parent 85d991a12a
commit a859b1992b
2 changed files with 52 additions and 23 deletions

View File

@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__(
self,
training_horizon,
transition_dim,
cond_dim,
training_horizon=128,
transition_dim=14,
cond_dim=3,
predict_epsilon=False,
clip_denoised=True,
dim=32,
dim_mults=(1, 2, 4, 8),
dim_mults=(1, 4, 8),
):
super().__init__()
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim
self.time_mlp = nn.Sequential(
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, x, time):
def forward(self, x, timesteps):
"""
x : [ batch x horizon x transition ]
"""
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
# x = einops.rearrange(x, "b h t -> b t h")
x = x.permute(0, 2, 1)
t = self.time_mlp(time)
t = self.time_mlp(timesteps)
h = []
for resnet, resnet2, downsample in self.downs:

View File

@@ -190,7 +190,7 @@ class ModelTesterMixin:
model.to(torch_device)
model.train()
output = model(**inputs_dict)
noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step}
@property
def get_input_shape(self):
def input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step, "low_res": low_res}
@property
def get_input_shape(self):
def input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
def output_shape(self):
return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step, "transformer_out": emb}
@property
def get_input_shape(self):
def input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
def output_shape(self):
return (6, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step}
@property
def get_input_shape(self):
def input_shape(self):
return (4, 32, 32)
@property
def get_output_shape(self):
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}
@property
def get_input_shape(self):
def input_shape(self):
return (4, 32, 16)
@property
def get_output_shape(self):
def output_shape(self):
return (4, 32, 16)
def prepare_init_args_and_inputs_for_common(self):
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = TemporalUNet
@property
def dummy_input(self):
    """Build the keyword arguments for one dummy forward pass.

    Returns:
        A dict with ``"x"``, the tensor produced by ``floats_tensor`` for the
        size ``(4, 16, 14)`` (batch, sequence length, features), and
        ``"timesteps"``, a tensor holding the value 10 once per batch
        element. Both tensors are moved to ``torch_device``.
    """
    n_samples, horizon, n_features = 4, 16, 14

    sample = floats_tensor((n_samples, horizon, n_features))
    step_ids = torch.tensor([10] * n_samples)

    return {"x": sample.to(torch_device), "timesteps": step_ids.to(torch_device)}
@property
def input_shape(self):
    """Per-sample input shape as ``(seq_len, num_features)``, without batch.

    ``dummy_input`` builds ``x`` with size ``(4, 16, 14)`` where the leading
    4 is the batch size, so the per-sample shape is ``(16, 14)``.

    NOTE(review): the other model test classes in this file return
    three-element shapes such as ``(3, 32, 32)`` that look per-sample --
    confirm that no caller expects the batch dimension to be included here.
    """
    return (16, 14)
@property
def output_shape(self):
    """Per-sample output shape as ``(seq_len, num_features)``, without batch.

    The shared training test builds its target as
    ``(inputs_dict["x"].shape[0],) + self.output_shape``, i.e. it prepends
    the batch dimension itself. Returning the batch size here as well would
    give a 4-D target of size ``(4, 4, 16, 14)``, one dimension more than the
    3-D input ``x`` of size ``(4, 16, 14)``.
    """
    return (16, 14)
def prepare_init_args_and_inputs_for_common(self):
    """Return the model constructor kwargs and forward inputs for the shared tests.

    Returns:
        A ``(init_dict, inputs_dict)`` pair: ``init_dict`` holds the keyword
        arguments for the model under test, and ``inputs_dict`` is whatever
        ``self.dummy_input`` provides.
    """
    model_kwargs = dict(
        training_horizon=128,
        dim=32,
        dim_mults=[1, 4, 8],
        predict_epsilon=False,
        clip_denoised=True,
        transition_dim=14,
        cond_dim=3,
    )
    return model_kwargs, self.dummy_input
def test_from_pretrained_hub(self):
model, loading_info = TemporalUNet.from_pretrained(
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580,
-0.0584])
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
return {"x": noise, "timesteps": time_step}
@property
def get_input_shape(self):
def input_shape(self):
return (3, 32, 32)
@property
def get_output_shape(self):
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):