Compare commits

...

5 Commits

Author SHA1 Message Date
sayakpaul
c980c53a8a fix-copies 2024-03-07 17:51:18 +05:30
sayakpaul
a0fcca6dd5 remove device. 2024-03-07 17:46:23 +05:30
sayakpaul
c2cc79c2c0 fix more 2024-03-07 17:42:52 +05:30
sayakpaul
ec64d34d5c checking 2024-03-07 17:41:57 +05:30
sayakpaul
43fbc3aec5 debug 2024-03-07 17:40:24 +05:30

View File

@@ -99,14 +99,13 @@ class SDFunctionTesterMixin:
assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2
def test_vae_tiling(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# make sure here that pndm scheduler skips prk
if "safety_checker" in components:
components["safety_checker"] = None
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -126,7 +125,7 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(device)
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):