mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
fix rl model tests
This commit is contained in:
@@ -122,13 +122,13 @@ class ResidualTemporalBlock(nn.Module):
|
||||
class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
training_horizon,
|
||||
transition_dim,
|
||||
cond_dim,
|
||||
training_horizon=128,
|
||||
transition_dim=14,
|
||||
cond_dim=3,
|
||||
predict_epsilon=False,
|
||||
clip_denoised=True,
|
||||
dim=32,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
dim_mults=(1, 4, 8),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -139,7 +139,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
|
||||
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
# print(f'[ models/temporal ] Channel dimensions: {in_out}')
|
||||
|
||||
time_dim = dim
|
||||
self.time_mlp = nn.Sequential(
|
||||
@@ -153,7 +152,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
print(in_out)
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
|
||||
@@ -195,7 +193,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
nn.Conv1d(dim, transition_dim, 1),
|
||||
)
|
||||
|
||||
def forward(self, x, time):
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
x : [ batch x horizon x transition ]
|
||||
"""
|
||||
@@ -203,7 +201,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||
# x = einops.rearrange(x, "b h t -> b t h")
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
t = self.time_mlp(time)
|
||||
t = self.time_mlp(timesteps)
|
||||
h = []
|
||||
|
||||
for resnet, resnet2, downsample in self.downs:
|
||||
|
||||
@@ -190,7 +190,7 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict)
|
||||
noise = torch.randn((inputs_dict["x"].shape[0],) + self.get_output_shape).to(torch_device)
|
||||
noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
@@ -210,11 +210,11 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
return {"x": noise, "timesteps": time_step}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@@ -276,11 +276,11 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
return {"x": noise, "timesteps": time_step, "low_res": low_res}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
def output_shape(self):
|
||||
return (6, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@@ -367,11 +367,11 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
return {"x": noise, "timesteps": time_step, "transformer_out": emb}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
def output_shape(self):
|
||||
return (6, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@@ -459,11 +459,11 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
return {"x": noise, "timesteps": time_step}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@@ -552,11 +552,11 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
def input_shape(self):
|
||||
return (4, 32, 16)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
def output_shape(self):
|
||||
return (4, 32, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@@ -610,6 +610,38 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = TemporalUNet
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"x": noise, "timesteps": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 16, 14)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 16, 14)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"training_horizon": 128,
|
||||
"dim": 32,
|
||||
"dim_mults": [1, 4, 8],
|
||||
"predict_epsilon": False,
|
||||
"clip_denoised": True,
|
||||
"transition_dim": 14,
|
||||
"cond_dim": 3,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = TemporalUNet.from_pretrained(
|
||||
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
|
||||
@@ -640,8 +672,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
output_slice = output[0, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580,
|
||||
-0.0584])
|
||||
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
@@ -662,11 +693,11 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
return {"x": noise, "timesteps": time_step}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
|
||||
Reference in New Issue
Block a user