Compare commits

...

1 Commit

Author: Dhruv Nair
SHA1: 80dfbb99b8
Message: update
Date: 2023-12-18 18:09:55 +00:00

6 changed files with 41 additions and 13 deletions

View File

@@ -283,6 +283,9 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
                 f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
             )
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if output_type == "latent":
             return ShapEPipelineOutput(images=latents)
@@ -312,9 +315,6 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
         if output_type == "pil":
             images = [self.numpy_to_pil(image) for image in images]
 
-        # Offload all models
-        self.maybe_free_model_hooks()
-
         if not return_dict:
             return (images,)
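
The two hunks above move the model-offload cleanup ahead of the early latent return, so hooks installed by enable_model_cpu_offload() are released even when the pipeline skips image decoding. A minimal usage sketch of that path (the checkpoint name, input file and call arguments are illustrative, not taken from this diff):

    import torch
    from PIL import Image
    from diffusers import ShapEImg2ImgPipeline

    pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
    pipe.enable_model_cpu_offload()  # installs the hooks that maybe_free_model_hooks() later removes

    image = Image.open("input.png")  # any RGB input image
    # With output_type="latent" the call returns before decoding; after this change the
    # offload hooks are still freed on that early-return path.
    latents = pipe(image, guidance_scale=3.0, output_type="latent").images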

View File

@@ -477,8 +477,9 @@ class UnCLIPPipeline(DiffusionPipeline):
         image = super_res_latents
         # done super res
-        # post processing
+        self.maybe_free_model_hooks()
+        # post processing
         image = image * 0.5 + 0.5
         image = image.clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

View File

@@ -403,6 +403,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
         image = super_res_latents
         # done super res
+        self.maybe_free_model_hooks()
         # post processing
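
Both UnCLIP hunks add the same cleanup call right after super-resolution, so each pipeline call releases its CPU-offload hooks before returning. A short sketch of the end-user pattern this supports (the checkpoint name is illustrative, not part of this diff):

    import torch
    from diffusers import UnCLIPPipeline

    pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    pipe.enable_model_cpu_offload()

    # Each call now ends by invoking maybe_free_model_hooks(), so the prior, decoder and
    # super-resolution models are moved back to CPU before the next call starts.
    image = pipe("a photo of a corgi").images[0]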

View File

@@ -14,7 +14,7 @@ from diffusers import (
     UNet2DConditionModel,
     UNetMotionModel,
 )
-from diffusers.utils import logging
+from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -233,6 +233,35 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
         pipe(**inputs)
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_without_offload = pipe(**inputs).frames[0]
+        output_without_offload = (
+            output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+        )
+
+        pipe.enable_xformers_memory_efficient_attention()
+        inputs = self.get_dummy_inputs(torch_device)
+        output_with_offload = pipe(**inputs).frames[0]
+        output_with_offload = (
+            output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
+        )
+
+        max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
+        self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
+
 
 @slow
 @require_torch_gpu
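
The new test compares pipeline output with and without memory-efficient attention and is skipped unless CUDA and the xformers package are available. Outside the test suite, the feature it covers is enabled the same way on a real pipeline; the checkpoints and call arguments below are illustrative, not part of this diff:

    import torch
    from diffusers import AnimateDiffPipeline, MotionAdapter

    adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
    pipe = AnimateDiffPipeline.from_pretrained(
        "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
    ).to("cuda")

    pipe.enable_xformers_memory_efficient_attention()  # requires xformers and a CUDA device
    frames = pipe("a rocket launching into space", num_frames=8, num_inference_steps=4).frames[0]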

View File

@@ -804,8 +804,7 @@ class StableDiffusionAdapterPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
-        pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        pipe.enable_attention_slicing()
+        pipe.enable_model_cpu_offload()
 
         generator = torch.Generator(device="cpu").manual_seed(0)
         out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
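
Here the slow test swaps attention slicing for full model CPU offload; with enable_model_cpu_offload() the pipeline moves each sub-model to the GPU only while it runs, so explicit device placement and slicing are not needed to bound memory. A sketch of the same setup outside the test (adapter and model ids are illustrative):

    import torch
    from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

    adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2", torch_dtype=torch.float16)
    pipe = StableDiffusionAdapterPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", adapter=adapter, safety_checker=None, torch_dtype=torch.float16
    )
    # Replaces pipe.to("cuda") + enable_attention_slicing(): sub-models are shuttled to the
    # GPU on demand and returned to CPU afterwards.
    pipe.enable_model_cpu_offload()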

View File

@@ -681,7 +681,7 @@ class AdapterSDXLPipelineSlowTests(unittest.TestCase):
             variant="fp16",
         )
         pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_model_cpu_offload()
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -694,8 +694,6 @@ class AdapterSDXLPipelineSlowTests(unittest.TestCase):
         assert images[0].shape == (768, 512, 3)
 
-        original_image = images[0, -3:, -3:, -1].flatten()
-        expected_image = np.array(
-            [0.50346327, 0.50708383, 0.50719553, 0.5135172, 0.5155377, 0.5066059, 0.49680984, 0.5005894, 0.48509413]
-        )
-        assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4
+        image_slice = images[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
+        assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
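
The updated assertion keeps the same cosine-distance check over a flattened 3x3 corner slice but refreshes the reference values for the model-offload code path. A small self-contained illustration of that check; the second array is made up for the example:

    import numpy as np
    from diffusers.utils.testing_utils import numpy_cosine_similarity_distance

    image_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
    fresh_slice = np.array([0.4285, 0.4336, 0.4320, 0.4255, 0.4328, 0.4281, 0.4337, 0.4419, 0.4227])  # hypothetical new run
    # Cosine distance tolerates small per-pixel noise while catching systematic drift.
    assert numpy_cosine_similarity_distance(image_slice, fresh_slice) < 1e-4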