[WIP] Refactor UniDiffuser Pipeline and Tests (#4948)

* Add VAE slicing and tiling methods.

* Switch to using VaeImageProcessor for preprocessing and postprocessing of images.

* Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor).

* Remove the postprocess() function because we're using a VaeImageProcessor instead.

* Remove UniDiffuserPipeline.decode_image_latents because we're using VaeImageProcessor instead.

* Refactor generating text from text latents into a decode_text_latents method.

* Add enable_full_determinism() to UniDiffuser tests.

* make style

* Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests.

* Remove enable_model_cpu_offload since it is now part of DiffusionPipeline.

* Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash.

* Update UniDiffuser conversion script.

* Make safe_serialization configurable in UniDiffuser conversion script.

* Rename image_processor to clip_image_processor in UniDiffuser tests.

* Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests.

* Add initial test for compiling the UniDiffuser model (not tested yet).

* Update encode_prompt and _encode_prompt to match those of StableDiffusionPipeline.

* Turn off standard classifier-free guidance for now.

* make style

* make fix-copies

* apply suggestions from review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
dg845 authored on 2023-10-02 09:24:55 -07:00 (committed by GitHub)
commit cd1b8d7ca8, parent db91e710da
3 changed files with 270 additions and 153 deletions
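Before the per-file hunks, a minimal usage sketch of the pipeline surface after this refactor (illustrative only; the checkpoint id comes from the slow tests, dtype/device and sampling settings are assumptions):

import torch
from diffusers import UniDiffuserPipeline

# The CLIP preprocessor component is now registered as `clip_image_processor`,
# while `pipe.image_processor` is a VaeImageProcessor used for pre/postprocessing.
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# With no prompt and no image, the pipeline runs in joint mode and returns both outputs.
sample = pipe(num_inference_steps=20, guidance_scale=8.0)
image, text = sample.images[0], sample.text[0]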


@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
return mapping
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
# config.num_head_channels => num_head_channels
def assign_to_checkpoint(
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
checkpoint.
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
shape = old_checkpoint[path["old"]].shape
if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif is_attn_weight and len(shape) == 4:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
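The reworked assign_to_checkpoint above folds the old proj_attn special case into a general is_attn_weight check; the underlying operation is the same squeeze of a 1x1 convolution weight into an nn.Linear weight. A standalone sketch with hypothetical shapes:

import torch

# A 1x1 conv projection weight has shape (out_ch, in_ch, 1, 1);
# the equivalent nn.Linear weight is just (out_ch, in_ch).
conv_weight = torch.randn(8, 8, 1, 1)      # hypothetical 4-dim attention projection weight
linear_weight = conv_weight[:, :, 0, 0]    # what the converter does for len(shape) == 4

# For 1D convs of shape (out_ch, in_ch, 1), one fewer index is dropped.
conv1d_weight = torch.randn(8, 8, 1)
linear_from_1d = conv1d_weight[:, :, 0]    # what the converter does for len(shape) == 3

assert linear_weight.shape == (8, 8)
assert linear_from_1d.shape == (8, 8)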
def create_vae_diffusers_config(config_type):
# Hardcoded for now
if args.config_type == "test":
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
return text_decoder_config
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
"""
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
@@ -674,6 +679,11 @@ if __name__ == "__main__":
type=int,
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
)
parser.add_argument(
"--safe_serialization",
action="store_true",
help="Whether to use safetensors/safe seialization when saving the pipeline.",
)
args = parser.parse_args()
@@ -766,11 +776,11 @@ if __name__ == "__main__":
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_processor=image_processor,
clip_image_processor=image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
unet=unet,
scheduler=scheduler,
)
pipeline.save_pretrained(args.pipeline_output_path)
pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
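The new --safe_serialization flag is simply forwarded to DiffusionPipeline.save_pretrained; the equivalent direct call looks roughly like this (output path is illustrative):

# `pipeline` is the assembled UniDiffuserPipeline from the hunk above.
pipeline.save_pretrained("unidiffuser-v1-diffusers", safe_serialization=True)
# Loading the result back picks up the .safetensors weights automatically:
# pipe = UniDiffuserPipeline.from_pretrained("unidiffuser-v1-diffusers")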


@@ -13,9 +13,12 @@ from transformers import (
GPT2Tokenizer,
)
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
from ...utils import deprecate, logging
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -26,30 +29,6 @@ from .modeling_uvit import UniDiffuserModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
# New BaseOutput child class for joint image-text output
@dataclass
class ImageTextPipelineOutput(BaseOutput):
@@ -111,7 +90,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
image_encoder: CLIPVisionModelWithProjection,
image_processor: CLIPImageProcessor,
clip_image_processor: CLIPImageProcessor,
clip_tokenizer: CLIPTokenizer,
text_decoder: UniDiffuserTextDecoder,
text_tokenizer: GPT2Tokenizer,
@@ -130,7 +109,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
vae=vae,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_processor=image_processor,
clip_image_processor=clip_image_processor,
clip_tokenizer=clip_tokenizer,
text_decoder=text_decoder,
text_tokenizer=text_tokenizer,
@@ -139,6 +118,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.num_channels_latents = vae.config.latent_channels
self.text_encoder_seq_len = text_encoder.config.max_position_embeddings
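For reference, a sketch of the VaeImageProcessor wiring the hunk above adds (shapes and scale factor are illustrative):

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

# The scale factor follows the VAE config; 8 corresponds to a standard 4-level AutoencoderKL,
# i.e. 2 ** (len(vae.config.block_out_channels) - 1).
vae_scale_factor = 8
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

# Replaces the deprecated module-level preprocess(): PIL image -> float tensor normalized to [-1, 1].
pil_image = Image.new("RGB", (512, 512))
image_tensor = image_processor.preprocess(pil_image)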
@@ -155,43 +135,38 @@ class UniDiffuserPipeline(DiffusionPipeline):
# TODO: handle safety checking?
self.safety_checker = None
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
# Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
def enable_model_cpu_offload(self, gpu_id=0):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
self.vae.enable_slicing()
device = torch.device(f"cuda:{gpu_id}")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
hook = None
for cpu_offloaded_model in [
self.text_encoder.text_model,
self.image_encoder,
self.unet,
self.vae,
self.text_decoder.encode_prefix,
self.text_decoder.decode_prefix,
self.text_decoder,
]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
if self.safety_checker is not None:
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
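A minimal sketch of the new memory helpers in use (checkpoint id from the slow tests; device, prompt, and step count are assumptions):

from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1").to("cuda")

pipe.enable_vae_slicing()   # decode the VAE batch slice by slice to save memory
pipe.enable_vae_tiling()    # decode/encode large images tile by tile

image = pipe(prompt="an astronaut riding a horse", num_inference_steps=20).images[0]

pipe.disable_vae_slicing()
pipe.disable_vae_tiling()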
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -370,8 +345,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
)
return batch_size, multiplier
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# self.tokenizer => self.clip_tokenizer
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
@@ -381,6 +355,41 @@ class UniDiffuserPipeline(DiffusionPipeline):
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -396,8 +405,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -405,7 +414,20 @@ class UniDiffuserPipeline(DiffusionPipeline):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -414,6 +436,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer)
text_inputs = self.clip_tokenizer(
prompt,
padding="max_length",
@@ -440,13 +466,31 @@ class UniDiffuserPipeline(DiffusionPipeline):
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -458,7 +502,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
@@ -474,6 +518,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.clip_tokenizer(
uncond_tokens,
@@ -498,17 +546,12 @@ class UniDiffuserPipeline(DiffusionPipeline):
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
return prompt_embeds, negative_prompt_embeds
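The encode_prompt hunks above switch the pipeline to the tuple-returning API; the deprecated _encode_prompt reconstructs the old single-tensor output by concatenation, roughly as follows (`pipe` is a hypothetical loaded UniDiffuserPipeline):

import torch

prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    prompt="a photo of a cat",
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    clip_skip=None,  # new optional argument; None keeps the final hidden states
)
# Backwards-compatible tensor, as returned by the deprecated _encode_prompt():
legacy_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])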
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
# Add num_prompts_per_image argument, sample from autoencoder moment distribution
@@ -587,7 +630,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
preprocessed_image = self.image_processor.preprocess(
preprocessed_image = self.clip_image_processor.preprocess(
image,
return_tensors="pt",
)
@@ -628,17 +671,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
return image_latents
# Note that the CLIP latents are not decoded for image generation.
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Rename: decode_latents -> decode_image_latents
def decode_image_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_text_latents(
self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
):
@@ -720,6 +752,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_text_latents(self, text_latents, device):
output_token_list, seq_lengths = self.text_decoder.generate_captions(
text_latents, self.text_tokenizer.eos_token_id, device=device
)
output_list = output_token_list.cpu().numpy()
generated_text = [
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
]
return generated_text
def _split(self, x, height, width):
r"""
Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
@@ -1181,7 +1224,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Note that this differs from the formulation in the unidiffusers paper!
# do_classifier_free_guidance = guidance_scale > 1.0
do_classifier_free_guidance = guidance_scale > 1.0
# check if scheduler is in sigmas space
# scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
@@ -1194,15 +1237,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
if mode in ["text2img"]:
# 3.1. Encode input prompt, if available
assert prompt is not None or prompt_embeds is not None
prompt_embeds = self._encode_prompt(
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=multiplier,
do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# if do_classifier_free_guidance:
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
else:
# 3.2. Prepare text latent variables, if input not available
prompt_embeds = self.prepare_text_latents(
@@ -1224,7 +1270,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
# 4.1. Encode images, if available
assert image is not None, "`img2text` requires a conditioning image"
# Encode image using VAE
image_vae = preprocess(image)
image_vae = self.image_processor.preprocess(image)
height, width = image_vae.shape[-2:]
image_vae_latents = self.encode_image_vae_latents(
image=image_vae,
@@ -1324,48 +1370,42 @@ class UniDiffuserPipeline(DiffusionPipeline):
callback(i, t, latents)
# 9. Post-processing
gen_image = None
gen_text = None
image = None
text = None
if mode == "joint":
image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)
# Map latent VAE image back to pixel space
gen_image = self.decode_image_latents(image_vae_latents)
if not output_type == "latent":
# Map latent VAE image back to pixel space
image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
image = image_vae_latents
# Generate text using the text decoder
output_token_list, seq_lengths = self.text_decoder.generate_captions(
text_latents, self.text_tokenizer.eos_token_id, device=device
)
output_list = output_token_list.cpu().numpy()
gen_text = [
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
]
text = self.decode_text_latents(text_latents, device)
elif mode in ["text2img", "img"]:
image_vae_latents, image_clip_latents = self._split(latents, height, width)
gen_image = self.decode_image_latents(image_vae_latents)
if not output_type == "latent":
# Map latent VAE image back to pixel space
image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
image = image_vae_latents
elif mode in ["img2text", "text"]:
text_latents = latents
output_token_list, seq_lengths = self.text_decoder.generate_captions(
text_latents, self.text_tokenizer.eos_token_id, device=device
)
output_list = output_token_list.cpu().numpy()
gen_text = [
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
for output, length in zip(output_list, seq_lengths)
]
text = self.decode_text_latents(text_latents, device)
self.maybe_free_model_hooks()
# 10. Convert to PIL
if output_type == "pil" and gen_image is not None:
gen_image = self.numpy_to_pil(gen_image)
# 10. Postprocess the image, if necessary
if image is not None:
do_denormalize = [True] * image.shape[0]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (gen_image, gen_text)
return (image, text)
return ImageTextPipelineOutput(images=gen_image, text=gen_text)
return ImageTextPipelineOutput(images=image, text=text)
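The postprocessing above replaces decode_image_latents with a direct VAE decode plus VaeImageProcessor.postprocess; in isolation the new path is roughly (`pipe` and `image_vae_latents` are hypothetical stand-ins for the pipeline and its unscaled VAE latents):

image = pipe.vae.decode(image_vae_latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
do_denormalize = [True] * image.shape[0]
pil_images = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)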


@@ -1,5 +1,6 @@
import gc
import random
import traceback
import unittest
import numpy as np
@@ -20,17 +21,70 @@ from diffusers import (
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
nightly,
require_torch_2,
require_torch_gpu,
run_test_in_subprocess,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
enable_full_determinism()
# Will be run via run_test_in_subprocess
def _test_unidiffuser_compile(in_queue, out_queue, timeout):
error = None
try:
inputs = in_queue.get(timeout=timeout)
torch_device = inputs.pop("torch_device")
seed = inputs.pop("seed")
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(torch_device)
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.set_progress_bar_config(disable=None)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
assert np.abs(image_slice - expected_slice).max() < 1e-1
except Exception:
error = f"{traceback.format_exc()}"
results = {"error": error}
out_queue.put(results, timeout=timeout)
out_queue.join()
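Outside the subprocess harness, the torch.compile path this test exercises boils down to the following sketch (requires PyTorch 2.x; prompt and step count are assumptions):

import torch
from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1").to("cuda")
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe(prompt="an astronaut riding a horse", num_inference_steps=20).images[0]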
class UniDiffuserPipelineFastTests(
PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
):
pipeline_class = UniDiffuserPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
# vae_latents, not latents, is the argument that corresponds to VAE latent inputs
image_latents_params = frozenset(["vae_latents"])
def get_dummy_components(self):
unet = UniDiffuserModel.from_pretrained(
@@ -64,7 +118,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
subfolder="image_encoder",
)
# From the Stable Diffusion Image Variation pipeline tests
image_processor = CLIPImageProcessor(crop_size=32, size=32)
clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
# image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
text_tokenizer = GPT2Tokenizer.from_pretrained(
@@ -80,7 +134,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"vae": vae,
"text_encoder": text_encoder,
"image_encoder": image_encoder,
"image_processor": image_processor,
"clip_image_processor": clip_image_processor,
"clip_tokenizer": clip_tokenizer,
"text_decoder": text_decoder,
"text_tokenizer": text_tokenizer,
@@ -619,6 +673,19 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
expected_text_prefix = "An astronaut"
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
@unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.")
@require_torch_2
def test_unidiffuser_compile(self, seed=0):
inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
# Can't pickle a Generator object
del inputs["generator"]
inputs["torch_device"] = torch_device
inputs["seed"] = seed
run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs)
@nightly
@require_torch_gpu