mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
2 Commits
thomas/sma
...
test_devic
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a0b42a932 | ||
|
|
491483701b |
@@ -105,7 +105,9 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model, loading_info = UNet2DModel.from_pretrained(
|
||||
"fusing/unet-ldm-dummy-update", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
@@ -186,7 +188,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
assert peak_accelerate < peak_normal
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
|
||||
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", device_map="auto")
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
|
||||
@@ -381,7 +383,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
def test_output_pretrained_ve_large(self):
|
||||
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
|
||||
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update", device_map="auto")
|
||||
model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -69,7 +69,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
model, loading_info = AutoencoderKL.from_pretrained(
|
||||
"fusing/autoencoder-kl-dummy", output_loading_info=True, device_map="auto"
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
@@ -79,7 +81,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
|
||||
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", device_map="auto")
|
||||
model = model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
|
||||
@@ -65,7 +65,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
|
||||
model, loading_info = VQModel.from_pretrained(
|
||||
"fusing/vqgan-dummy", device_map="auto", output_loading_info=True
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
@@ -75,7 +77,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = VQModel.from_pretrained("fusing/vqgan-dummy")
|
||||
model = VQModel.from_pretrained("fusing/vqgan-dummy", device_map="auto")
|
||||
model.to(torch_device).eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user