mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
[WIP] Refactor UniDiffuser Pipeline and Tests (#4948)
* Add VAE slicing and tiling methods. * Switch to using VaeImageProcessing for preprocessing and postprocessing of images. * Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor). * Remove the postprocess() function because we're using a VaeImageProcessor instead. * Remove UniDiffuserPipeline.decode_image_latents because we're using VaeImageProcessor instead. * Refactor generating text from text latents into a decode_text_latents method. * Add enable_full_determinism() to UniDiffuser tests. * make style * Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests. * Remove enable_model_cpu_offload since it is now part of DiffusionPipeline. * Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash. * Update UniDiffuser conversion script. * Make safe_serialization configurable in UniDiffuser conversion script. * Rename image_processor to clip_image_processor in UniDiffuser tests. * Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests. * Add initial test for compiling the UniDiffuser model (not tested yet). * Update encode_prompt and _encode_prompt to match that of StableDiffusionPipeline. * Turn off standard classifier-free guidance for now. * make style * make fix-copies * apply suggestions from review --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
|
||||
# config.num_head_channels => num_head_channels
|
||||
def assign_to_checkpoint(
|
||||
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
|
||||
):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
||||
attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
|
||||
checkpoint.
|
||||
attention layers, and takes into account additional replacements that may arise.
|
||||
|
||||
Assigns the weights to the new checkpoint.
|
||||
"""
|
||||
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
|
||||
shape = old_checkpoint[path["old"]].shape
|
||||
if is_attn_weight and len(shape) == 3:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
elif is_attn_weight and len(shape) == 4:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def create_vae_diffusers_config(config_type):
|
||||
# Hardcoded for now
|
||||
if args.config_type == "test":
|
||||
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
|
||||
return text_decoder_config
|
||||
|
||||
|
||||
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
|
||||
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
|
||||
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
|
||||
"""
|
||||
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
|
||||
@@ -674,6 +679,11 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
action="store_true",
|
||||
help="Whether to use safetensors/safe seialization when saving the pipeline.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -766,11 +776,11 @@ if __name__ == "__main__":
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
clip_image_processor=image_processor,
|
||||
clip_tokenizer=clip_tokenizer,
|
||||
text_decoder=text_decoder,
|
||||
text_tokenizer=text_tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipeline.save_pretrained(args.pipeline_output_path)
|
||||
pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
|
||||
|
||||
@@ -13,9 +13,12 @@ from transformers import (
|
||||
GPT2Tokenizer,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils.outputs import BaseOutput
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -26,30 +29,6 @@ from .modeling_uvit import UniDiffuserModel
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
|
||||
def preprocess(image):
|
||||
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
|
||||
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
||||
|
||||
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
# New BaseOutput child class for joint image-text output
|
||||
@dataclass
|
||||
class ImageTextPipelineOutput(BaseOutput):
|
||||
@@ -111,7 +90,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
image_processor: CLIPImageProcessor,
|
||||
clip_image_processor: CLIPImageProcessor,
|
||||
clip_tokenizer: CLIPTokenizer,
|
||||
text_decoder: UniDiffuserTextDecoder,
|
||||
text_tokenizer: GPT2Tokenizer,
|
||||
@@ -130,7 +109,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
clip_image_processor=clip_image_processor,
|
||||
clip_tokenizer=clip_tokenizer,
|
||||
text_decoder=text_decoder,
|
||||
text_tokenizer=text_tokenizer,
|
||||
@@ -139,6 +118,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.num_channels_latents = vae.config.latent_channels
|
||||
self.text_encoder_seq_len = text_encoder.config.max_position_embeddings
|
||||
@@ -155,43 +135,38 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
# TODO: handle safety checking?
|
||||
self.safety_checker = None
|
||||
|
||||
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
|
||||
# Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
self.vae.enable_slicing()
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [
|
||||
self.text_encoder.text_model,
|
||||
self.image_encoder,
|
||||
self.unet,
|
||||
self.vae,
|
||||
self.text_decoder.encode_prefix,
|
||||
self.text_decoder.decode_prefix,
|
||||
self.text_decoder,
|
||||
]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
@@ -370,8 +345,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
)
|
||||
return batch_size, multiplier
|
||||
|
||||
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
# self.tokenizer => self.clip_tokenizer
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
@@ -381,6 +355,41 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
@@ -396,8 +405,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
@@ -405,7 +414,20 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
@@ -414,6 +436,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer)
|
||||
|
||||
text_inputs = self.clip_tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -440,13 +466,31 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -458,7 +502,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
@@ -474,6 +518,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.clip_tokenizer(
|
||||
uncond_tokens,
|
||||
@@ -498,17 +546,12 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
|
||||
# Add num_prompts_per_image argument, sample from autoencoder moment distribution
|
||||
@@ -587,7 +630,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
preprocessed_image = self.image_processor.preprocess(
|
||||
preprocessed_image = self.clip_image_processor.preprocess(
|
||||
image,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -628,17 +671,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
|
||||
return image_latents
|
||||
|
||||
# Note that the CLIP latents are not decoded for image generation.
|
||||
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
# Rename: decode_latents -> decode_image_latents
|
||||
def decode_image_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def prepare_text_latents(
|
||||
self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
|
||||
):
|
||||
@@ -720,6 +752,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_text_latents(self, text_latents, device):
|
||||
output_token_list, seq_lengths = self.text_decoder.generate_captions(
|
||||
text_latents, self.text_tokenizer.eos_token_id, device=device
|
||||
)
|
||||
output_list = output_token_list.cpu().numpy()
|
||||
generated_text = [
|
||||
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
|
||||
for output, length in zip(output_list, seq_lengths)
|
||||
]
|
||||
return generated_text
|
||||
|
||||
def _split(self, x, height, width):
|
||||
r"""
|
||||
Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
|
||||
@@ -1181,7 +1224,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
# Note that this differs from the formulation in the unidiffusers paper!
|
||||
# do_classifier_free_guidance = guidance_scale > 1.0
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# check if scheduler is in sigmas space
|
||||
# scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
|
||||
@@ -1194,15 +1237,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
if mode in ["text2img"]:
|
||||
# 3.1. Encode input prompt, if available
|
||||
assert prompt is not None or prompt_embeds is not None
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=multiplier,
|
||||
do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# if do_classifier_free_guidance:
|
||||
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
else:
|
||||
# 3.2. Prepare text latent variables, if input not available
|
||||
prompt_embeds = self.prepare_text_latents(
|
||||
@@ -1224,7 +1270,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
# 4.1. Encode images, if available
|
||||
assert image is not None, "`img2text` requires a conditioning image"
|
||||
# Encode image using VAE
|
||||
image_vae = preprocess(image)
|
||||
image_vae = self.image_processor.preprocess(image)
|
||||
height, width = image_vae.shape[-2:]
|
||||
image_vae_latents = self.encode_image_vae_latents(
|
||||
image=image_vae,
|
||||
@@ -1324,48 +1370,42 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Post-processing
|
||||
gen_image = None
|
||||
gen_text = None
|
||||
image = None
|
||||
text = None
|
||||
if mode == "joint":
|
||||
image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)
|
||||
|
||||
# Map latent VAE image back to pixel space
|
||||
gen_image = self.decode_image_latents(image_vae_latents)
|
||||
if not output_type == "latent":
|
||||
# Map latent VAE image back to pixel space
|
||||
image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = image_vae_latents
|
||||
|
||||
# Generate text using the text decoder
|
||||
output_token_list, seq_lengths = self.text_decoder.generate_captions(
|
||||
text_latents, self.text_tokenizer.eos_token_id, device=device
|
||||
)
|
||||
output_list = output_token_list.cpu().numpy()
|
||||
gen_text = [
|
||||
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
|
||||
for output, length in zip(output_list, seq_lengths)
|
||||
]
|
||||
text = self.decode_text_latents(text_latents, device)
|
||||
elif mode in ["text2img", "img"]:
|
||||
image_vae_latents, image_clip_latents = self._split(latents, height, width)
|
||||
gen_image = self.decode_image_latents(image_vae_latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
# Map latent VAE image back to pixel space
|
||||
image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
image = image_vae_latents
|
||||
elif mode in ["img2text", "text"]:
|
||||
text_latents = latents
|
||||
output_token_list, seq_lengths = self.text_decoder.generate_captions(
|
||||
text_latents, self.text_tokenizer.eos_token_id, device=device
|
||||
)
|
||||
output_list = output_token_list.cpu().numpy()
|
||||
gen_text = [
|
||||
self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
|
||||
for output, length in zip(output_list, seq_lengths)
|
||||
]
|
||||
text = self.decode_text_latents(text_latents, device)
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil" and gen_image is not None:
|
||||
gen_image = self.numpy_to_pil(gen_image)
|
||||
# 10. Postprocess the image, if necessary
|
||||
if image is not None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (gen_image, gen_text)
|
||||
return (image, text)
|
||||
|
||||
return ImageTextPipelineOutput(images=gen_image, text=gen_text)
|
||||
return ImageTextPipelineOutput(images=image, text=text)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import gc
|
||||
import random
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -20,17 +21,70 @@ from diffusers import (
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
nightly,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_unidiffuser_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
inputs = in_queue.get(timeout=timeout)
|
||||
torch_device = inputs.pop("torch_device")
|
||||
seed = inputs.pop("seed")
|
||||
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
|
||||
# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-1
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class UniDiffuserPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = UniDiffuserPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
# vae_latents, not latents, is the argument that corresponds to VAE latent inputs
|
||||
image_latents_params = frozenset(["vae_latents"])
|
||||
|
||||
def get_dummy_components(self):
|
||||
unet = UniDiffuserModel.from_pretrained(
|
||||
@@ -64,7 +118,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
subfolder="image_encoder",
|
||||
)
|
||||
# From the Stable Diffusion Image Variation pipeline tests
|
||||
image_processor = CLIPImageProcessor(crop_size=32, size=32)
|
||||
clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
|
||||
# image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
text_tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
@@ -80,7 +134,7 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"image_encoder": image_encoder,
|
||||
"image_processor": image_processor,
|
||||
"clip_image_processor": clip_image_processor,
|
||||
"clip_tokenizer": clip_tokenizer,
|
||||
"text_decoder": text_decoder,
|
||||
"text_tokenizer": text_tokenizer,
|
||||
@@ -619,6 +673,19 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
expected_text_prefix = "An astronaut"
|
||||
assert text[0][: len(expected_text_prefix)] == expected_text_prefix
|
||||
|
||||
@unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.")
|
||||
@require_torch_2
|
||||
def test_unidiffuser_compile(self, seed=0):
|
||||
inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True)
|
||||
# Delete prompt and image for joint inference.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs)
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user