Compare commits

...

2 Commits

Author SHA1 Message Date
Patrick von Platen
3a0b42a932 Merge branch 'main' into test_device_map_auto_on_cpu 2022-10-31 19:55:04 +01:00
Patrick von Platen
491483701b [Tests] Speed up CPU device auto tests 2022-10-28 15:18:11 +00:00
3 changed files with 13 additions and 7 deletions

View File

@@ -105,7 +105,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model, loading_info = UNet2DModel.from_pretrained(
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -186,7 +188,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
assert peak_accelerate < peak_normal
def test_output_pretrained(self):
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", device_map="auto")
model.eval()
model.to(torch_device)
@@ -381,7 +383,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update", device_map="auto")
model.to(torch_device)
torch.manual_seed(0)

View File

@@ -69,7 +69,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
pass
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
model, loading_info = AutoencoderKL.from_pretrained(
"fusing/autoencoder-kl-dummy", output_loading_info=True, device_map="auto"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -79,7 +81,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", device_map="auto")
model = model.to(torch_device)
model.eval()

View File

@@ -65,7 +65,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
pass
def test_from_pretrained_hub(self):
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
model, loading_info = VQModel.from_pretrained(
"fusing/vqgan-dummy", device_map="auto", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -75,7 +77,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = VQModel.from_pretrained("fusing/vqgan-dummy")
model = VQModel.from_pretrained("fusing/vqgan-dummy", device_map="auto")
model.to(torch_device).eval()
torch.manual_seed(0)