Compare commits


6 Commits

- 5b902948a6 · update · Dhruv Nair · 2024-02-07 07:38:11 +00:00

- e6a48db633 · Refactor Deepfloyd IF tests. (#6855) · Dhruv Nair · 2024-02-06 16:43:17 +05:30

    * update
    * update
    * update

- 4f1df69d1a · Revert "add attention_head_dim" · sayakpaul · 2024-02-06 14:48:49 +05:30

    This reverts commit 15f6b22466.

- 15f6b22466 · add attention_head_dim · sayakpaul · 2024-02-06 14:48:07 +05:30

- e6fd9ada3a · [I2vGenXL] clean up things (#6845) · Sayak Paul · 2024-02-06 09:22:07 +05:30

    * remove _to_tensor
    * remove _to_tensor definition
    * remove _collapse_frames_into_batch
    * remove lora for not bloating the code.
    * remove sample_size.
    * simplify code a bit more
    * ensure timesteps are always in tensor.

- 493228a708 · Fix AutoencoderTiny with use_slicing (#6850) · Edward Li · 2024-02-05 09:18:22 -10:00

    * Fix `AutoencoderTiny` with `use_slicing`

      When using slicing with AutoencoderTiny, the encoder mistakenly encodes the entire batch for every image in the batch.

    * Fixed formatting issue
10 changed files with 295 additions and 307 deletions

File: src/diffusers/models/autoencoders/autoencoder_tiny.py

@@ -292,7 +292,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         self, x: torch.FloatTensor, return_dict: bool = True
     ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
         if self.use_slicing and x.shape[0] > 1:
-            output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
+            output = [
+                self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
+            ]
             output = torch.cat(output)
         else:
             output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
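To make the one-character fix above concrete: with slicing enabled, the old code encoded the full batch `x` once per slice instead of each `x_slice`. A minimal sketch of the behavior the fix restores; the checkpoint name and shapes here are illustrative assumptions, not part of the commit:

```python
import torch
from diffusers import AutoencoderTiny

# Hypothetical repro, not part of the diff: encode a batch of two distinct
# images with slicing enabled and check that the latents stay per-image.
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
vae.enable_slicing()  # sets use_slicing, taking the patched branch above

x = torch.randn(2, 3, 64, 64)  # two distinct images
with torch.no_grad():
    latents = vae.encode(x).latents

# With the fix, each slice is encoded from its own x_slice, so the batch
# size is preserved and the two latents differ.
assert latents.shape[0] == 2
assert not torch.allclose(latents[0], latents[1])
```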

File: src/diffusers/models/unets/unet_i2vgen_xl.py

@@ -48,29 +48,6 @@ from .unet_3d_condition import UNet3DConditionOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-def _to_tensor(inputs, device):
-    if not torch.is_tensor(inputs):
-        # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
-        # This would be a good case for the `match` statement (Python 3.10+)
-        is_mps = device.type == "mps"
-        if isinstance(inputs, float):
-            dtype = torch.float32 if is_mps else torch.float64
-        else:
-            dtype = torch.int32 if is_mps else torch.int64
-        inputs = torch.tensor([inputs], dtype=dtype, device=device)
-    elif len(inputs.shape) == 0:
-        inputs = inputs[None].to(device)
-    return inputs
-def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
-    batch_size, channels, num_frames, height, width = sample.shape
-    sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
-    return sample
 class I2VGenXLTransformerTemporalEncoder(nn.Module):
     def __init__(
         self,
@@ -174,8 +151,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     ):
         super().__init__()
-        self.sample_size = sample_size
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -543,7 +518,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             forward_upsample_size = True
         # 1. time
-        timesteps = _to_tensor(timestep, sample.device)
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timesteps, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
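The block inlined above only normalizes `timestep` before broadcasting it. A standalone sketch of the same normalization, with the `mps` dtype branch omitted for brevity (illustrative, not the UNet's actual code path):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 2

# A plain float forces a host-to-device copy on every call, which is what
# the TODO warns about; a 0-d tensor is just unsqueezed and moved.
timesteps = 10.0  # could also be an int or a 0-d/1-d tensor
if not torch.is_tensor(timesteps):
    dtype = torch.float64 if isinstance(timesteps, float) else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(device)

# broadcast to the batch dimension, as the forward pass does
timesteps = timesteps.expand(batch_size)
assert timesteps.shape == (batch_size,)
```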
@@ -572,7 +558,13 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
             context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)
-        image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
+        image_latents_for_context_embds = image_latents[:, :, :1, :]
+        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
+            image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
+            image_latents_for_context_embds.shape[1],
+            image_latents_for_context_embds.shape[3],
+            image_latents_for_context_embds.shape[4],
+        )
         image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)
         _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -586,7 +578,12 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         context_emb = torch.cat([context_emb, image_emb], dim=1)
         context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
-        image_latents = _collapse_frames_into_batch(image_latents)
+        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
+            image_latents.shape[0] * image_latents.shape[2],
+            image_latents.shape[1],
+            image_latents.shape[3],
+            image_latents.shape[4],
+        )
         image_latents = self.image_latents_proj_in(image_latents)
         image_latents = (
             image_latents[None, :]
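Both permute/reshape blocks above inline what the deleted `_collapse_frames_into_batch` helper did. A quick shape check of that transform, with illustrative dimensions:

```python
import torch

# (batch, channels, frames, height, width) -> (batch * frames, channels, height, width);
# permute moves frames next to batch so reshape folds them together without mixing data.
sample = torch.randn(2, 4, 8, 32, 32)
batch_size, channels, num_frames, height, width = sample.shape
collapsed = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
assert collapsed.shape == (2 * 8, 4, 32, 32)
```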

File: src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py

@@ -22,18 +22,13 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin
 from ...models import AutoencoderKL
-from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
 from ...schedulers import DDIMScheduler
 from ...utils import (
-    USE_PEFT_BACKEND,
     BaseOutput,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -207,7 +202,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
         negative_prompt=None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
         r"""
@@ -233,23 +227,10 @@ class I2VGenXLPipeline(DiffusionPipeline):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            lora_scale (`float`, *optional*):
-                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
         """
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
-            self._lora_scale = lora_scale
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -380,10 +361,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
         return prompt_embeds, negative_prompt_embeds

     def _encode_image(self, image, device, num_videos_per_prompt):
@@ -706,9 +683,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
         self._guidance_scale = guidance_scale

         # 3.1 Encode input text prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -716,7 +690,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
             clip_skip=clip_skip,
         )

         # For classifier free guidance, we need to do two forward passes.

File: tests/pipelines/deepfloyd_if/test_if.py

@@ -14,22 +14,16 @@
 # limitations under the License.
 import gc
-import random
 import unittest
 import torch
 from diffusers import (
-    IFImg2ImgPipeline,
-    IFImg2ImgSuperResolutionPipeline,
-    IFInpaintingPipeline,
-    IFInpaintingSuperResolutionPipeline,
     IFPipeline,
-    IFSuperResolutionPipeline,
 )
 from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, skip_mps, slow, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -97,77 +91,18 @@ class IFPipelineSlowTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
-    def test_all(self):
-        # if
+    def test_if_text_to_image(self):
+        pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
+        pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+        pipe.enable_model_cpu_offload()
-        pipe_1 = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
-        pipe_2 = IFSuperResolutionPipeline.from_pretrained(
-            "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16, text_encoder=None, tokenizer=None
-        )
-        # pre compute text embeddings and remove T5 to save memory
-        pipe_1.text_encoder.to("cuda")
-        prompt_embeds, negative_prompt_embeds = pipe_1.encode_prompt("anime turtle", device="cuda")
-        del pipe_1.tokenizer
-        del pipe_1.text_encoder
-        gc.collect()
-        pipe_1.tokenizer = None
-        pipe_1.text_encoder = None
-        pipe_1.enable_model_cpu_offload()
-        pipe_2.enable_model_cpu_offload()
-        pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())
-        self._test_if(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)
-        pipe_1.remove_all_hooks()
-        pipe_2.remove_all_hooks()
-        # img2img
-        pipe_1 = IFImg2ImgPipeline(**pipe_1.components)
-        pipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)
-        pipe_1.enable_model_cpu_offload()
-        pipe_2.enable_model_cpu_offload()
-        pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())
-        self._test_if_img2img(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)
-        pipe_1.remove_all_hooks()
-        pipe_2.remove_all_hooks()
-        # inpainting
-        pipe_1 = IFInpaintingPipeline(**pipe_1.components)
-        pipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)
-        pipe_1.enable_model_cpu_offload()
-        pipe_2.enable_model_cpu_offload()
-        pipe_1.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe_2.unet.set_attn_processor(AttnAddedKVProcessor())
-        self._test_if_inpainting(pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds)
-    def _test_if(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
-        # pipeline 1
-        _start_torch_memory_measurement()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
         generator = torch.Generator(device="cpu").manual_seed(0)
-        output = pipe_1(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
+        output = pipe(
+            prompt="anime turtle",
             num_inference_steps=2,
             generator=generator,
             output_type="np",
@@ -175,172 +110,11 @@ class IFPipelineSlowTests(unittest.TestCase):
         image = output.images[0]
         assert image.shape == (64, 64, 3)
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes < 13 * 10**9
+        assert mem_bytes < 12 * 10**9
         expected_image = load_numpy(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if.npy"
         )
         assert_mean_pixel_difference(image, expected_image)
-        # pipeline 2
-        _start_torch_memory_measurement()
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
-        output = pipe_2(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            image=image,
-            generator=generator,
-            num_inference_steps=2,
-            output_type="np",
-        )
-        image = output.images[0]
-        assert image.shape == (256, 256, 3)
-        mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes < 4 * 10**9
-        expected_image = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_superresolution_stage_II.npy"
-        )
-        assert_mean_pixel_difference(image, expected_image)
-    def _test_if_img2img(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
-        # pipeline 1
-        _start_torch_memory_measurement()
-        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        output = pipe_1(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            image=image,
-            num_inference_steps=2,
-            generator=generator,
-            output_type="np",
-        )
-        image = output.images[0]
-        assert image.shape == (64, 64, 3)
-        mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes < 10 * 10**9
-        expected_image = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img.npy"
-        )
-        assert_mean_pixel_difference(image, expected_image)
-        # pipeline 2
-        _start_torch_memory_measurement()
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
-        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
-        output = pipe_2(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            image=image,
-            original_image=original_image,
-            generator=generator,
-            num_inference_steps=2,
-            output_type="np",
-        )
-        image = output.images[0]
-        assert image.shape == (256, 256, 3)
-        mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes < 4 * 10**9
-        expected_image = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img_superresolution_stage_II.npy"
-        )
-        assert_mean_pixel_difference(image, expected_image)
-    def _test_if_inpainting(self, pipe_1, pipe_2, prompt_embeds, negative_prompt_embeds):
-        # pipeline 1
-        _start_torch_memory_measurement()
-        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
-        mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device)
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        output = pipe_1(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            image=image,
-            mask_image=mask_image,
-            num_inference_steps=2,
-            generator=generator,
-            output_type="np",
-        )
-        image = output.images[0]
-        assert image.shape == (64, 64, 3)
-        mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes < 10 * 10**9
-        expected_image = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting.npy"
-        )
-        assert_mean_pixel_difference(image, expected_image)
-        # pipeline 2
-        _start_torch_memory_measurement()
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
-        original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
-        mask_image = floats_tensor((1, 3, 256, 256), rng=random.Random(1)).to(torch_device)
-        output = pipe_2(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            image=image,
-            mask_image=mask_image,
-            original_image=original_image,
-            generator=generator,
-            num_inference_steps=2,
-            output_type="np",
-        )
-        image = output.images[0]
-        assert image.shape == (256, 256, 3)
-        mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes < 4 * 10**9
-        expected_image = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting_superresolution_stage_II.npy"
-        )
-        assert_mean_pixel_difference(image, expected_image)
-def _start_torch_memory_measurement():
-    torch.cuda.empty_cache()
-    torch.cuda.reset_max_memory_allocated()
-    torch.cuda.reset_peak_memory_stats()
+        pipe.remove_all_hooks()
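A practical upside of splitting `test_all` is that each stage is now individually addressable. A hypothetical invocation of just the text-to-image slow test; the module path and the `RUN_SLOW` gate read by the `slow` decorator are assumptions about the test harness, not shown in this diff:

```python
import os
import unittest

# Assumed behavior: RUN_SLOW must be set before diffusers.utils.testing_utils
# is imported, since the @slow decorator reads the flag at import time.
os.environ["RUN_SLOW"] = "1"

# Load and run a single refactored slow test instead of the old monolithic test_all.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.pipelines.deepfloyd_if.test_if.IFPipelineSlowTests.test_if_text_to_image"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```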

File: tests/pipelines/deepfloyd_if/test_if_img2img.py

@@ -13,20 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import random
 import unittest
 import torch
 from diffusers import IFImg2ImgPipeline
+from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
+from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
 from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
 )
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
 from . import IFPipelineTesterMixin
@@ -87,3 +89,43 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
         self._test_inference_batch_single_identical(
             expected_max_diff=1e-2,
         )
+
+
+@slow
+@require_torch_gpu
+class IFImg2ImgPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_if_img2img(self):
+        pipe = IFImg2ImgPipeline.from_pretrained(
+            "DeepFloyd/IF-I-L-v1.0",
+            variant="fp16",
+            torch_dtype=torch.float16,
+        )
+        pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+        pipe.enable_model_cpu_offload()
+
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        output = pipe(
+            prompt="anime turtle",
+            image=image,
+            num_inference_steps=2,
+            generator=generator,
+            output_type="np",
+        )
+
+        image = output.images[0]
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes < 12 * 10**9
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img.npy"
+        )
+        assert_mean_pixel_difference(image, expected_image)
+
+        pipe.remove_all_hooks()

File: tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py

@@ -13,17 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import random
 import unittest
 import torch
 from diffusers import IFImg2ImgSuperResolutionPipeline
+from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
+from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import (
+    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
 from . import IFPipelineTesterMixin
@@ -82,3 +87,50 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
         self._test_inference_batch_single_identical(
             expected_max_diff=1e-2,
         )
+
+
+@slow
+@require_torch_gpu
+class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_if_img2img_superresolution(self):
+        pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
+            "DeepFloyd/IF-II-L-v1.0",
+            variant="fp16",
+            torch_dtype=torch.float16,
+        )
+        pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+        pipe.enable_model_cpu_offload()
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
+
+        output = pipe(
+            prompt="anime turtle",
+            image=image,
+            original_image=original_image,
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        )
+
+        image = output.images[0]
+
+        assert image.shape == (256, 256, 3)
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes < 12 * 10**9
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_img2img_superresolution_stage_II.npy"
+        )
+        assert_mean_pixel_difference(image, expected_image)
+
+        pipe.remove_all_hooks()

File: tests/pipelines/deepfloyd_if/test_if_inpainting.py

@@ -13,20 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import random
 import unittest
 import torch
 from diffusers import IFInpaintingPipeline
+from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
+from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
 from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
 )
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
 from . import IFPipelineTesterMixin
@@ -85,3 +87,48 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
         self._test_inference_batch_single_identical(
             expected_max_diff=1e-2,
         )
+
+
+@slow
+@require_torch_gpu
+class IFInpaintingPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_if_inpainting(self):
+        pipe = IFInpaintingPipeline.from_pretrained(
+            "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
+        )
+        pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+        pipe.enable_model_cpu_offload()
+
+        # Super resolution test
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
+        mask_image = floats_tensor((1, 3, 64, 64), rng=random.Random(1)).to(torch_device)
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        output = pipe(
+            prompt="anime prompts",
+            image=image,
+            mask_image=mask_image,
+            num_inference_steps=2,
+            generator=generator,
+            output_type="np",
+        )
+
+        image = output.images[0]
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes < 12 * 10**9
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting.npy"
+        )
+        assert_mean_pixel_difference(image, expected_image)
+
+        pipe.remove_all_hooks()

File: tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py

@@ -13,20 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import random
 import unittest
 import torch
 from diffusers import IFInpaintingSuperResolutionPipeline
+from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
+from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
 from ..pipeline_params import (
     TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
 )
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
 from . import IFPipelineTesterMixin
@@ -87,3 +89,55 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
         self._test_inference_batch_single_identical(
             expected_max_diff=1e-2,
         )
+
+
+@slow
+@require_torch_gpu
+class IFInpaintingSuperResolutionPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_if_inpainting_superresolution(self):
+        pipe = IFInpaintingSuperResolutionPipeline.from_pretrained(
+            "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
+        )
+        pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+        pipe.enable_model_cpu_offload()
+
+        # Super resolution test
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
+        original_image = floats_tensor((1, 3, 256, 256), rng=random.Random(0)).to(torch_device)
+        mask_image = floats_tensor((1, 3, 256, 256), rng=random.Random(1)).to(torch_device)
+
+        output = pipe(
+            prompt="anime turtle",
+            image=image,
+            original_image=original_image,
+            mask_image=mask_image,
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        )
+
+        image = output.images[0]
+
+        assert image.shape == (256, 256, 3)
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes < 12 * 10**9
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_inpainting_superresolution_stage_II.npy"
+        )
+        assert_mean_pixel_difference(image, expected_image)
+
+        pipe.remove_all_hooks()

File: tests/pipelines/deepfloyd_if/test_if_superresolution.py

@@ -13,17 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import random
 import unittest
 import torch
 from diffusers import IFSuperResolutionPipeline
+from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import floats_tensor, skip_mps, torch_device
+from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
 from . import IFPipelineTesterMixin
@@ -80,3 +82,49 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
         self._test_inference_batch_single_identical(
             expected_max_diff=1e-2,
         )
+
+
+@slow
+@require_torch_gpu
+class IFSuperResolutionPipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_if_superresolution(self):
+        pipe = IFSuperResolutionPipeline.from_pretrained(
+            "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
+        )
+        pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+        pipe.enable_model_cpu_offload()
+
+        # Super resolution test
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
+        generator = torch.Generator(device="cpu").manual_seed(0)
+
+        output = pipe(
+            prompt="anime turtle",
+            image=image,
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+        )
+
+        image = output.images[0]
+
+        assert image.shape == (256, 256, 3)
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes < 12 * 10**9
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/if/test_if_superresolution_stage_II.npy"
+        )
+        assert_mean_pixel_difference(image, expected_image)
+
+        pipe.remove_all_hooks()

File: tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py

@@ -235,8 +235,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         inputs = self.get_dummy_inputs()
         images = pipeline(**inputs).images
         image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.1958, 0.1475, 0.1396, 0.2412, 0.1658, 0.1533, 0.3997, 0.4055, 0.4128])
+        expected_slice = np.array([0.1704, 0.1296, 0.1272, 0.2212, 0.1514, 0.1479, 0.4172, 0.4263, 0.4360])
         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
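For context on the assertion: `numpy_cosine_similarity_distance` compares the two slices by cosine distance rather than element-wise tolerance. A sketch of that metric, under the assumption it mirrors the helper in `diffusers.utils.testing_utils`:

```python
import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 minus the cosine similarity of the flattened slices: insensitive to
    # uniform scaling of all values, but sensitive to changes in their pattern.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)

image_slice = np.array([0.1704, 0.1296, 0.1272, 0.2212, 0.1514, 0.1479, 0.4172, 0.4263, 0.4360])
expected_slice = np.array([0.1704, 0.1296, 0.1272, 0.2212, 0.1514, 0.1479, 0.4172, 0.4263, 0.4360])
assert cosine_distance(image_slice, expected_slice) < 5e-4  # identical slices give distance 0.0
```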