Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Nair
858dfd6411 update 2023-12-06 12:14:35 +00:00
Dhruv Nair
6cb2178a91 Revert "fix"
This reverts commit f90a5139a2.
2023-12-06 06:44:02 +00:00
6 changed files with 9 additions and 10 deletions

View File

@@ -446,9 +446,8 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
# Relevant to StableDiffusionUpscalePipeline
if "num_class_embeds" in config:
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

View File

View File

@@ -164,7 +164,7 @@ class PriorTransformerIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
@parameterized.expand(
[

View File

@@ -869,7 +869,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32

View File

@@ -485,7 +485,7 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -565,7 +565,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
@@ -820,7 +820,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32

View File

@@ -310,7 +310,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
@@ -531,7 +531,7 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"