Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-25 13:54:45 +08:00)
Compare commits
22 Commits
release-te...rope-init-
| Author | SHA1 | Date |
|---|---|---|
| | 100166ed53 | |
| | 16704379a0 | |
| | 4bd87a1fe9 | |
| | 484443e0b4 | |
| | 61d96c3ae7 | |
| | 4f495b06dc | |
| | 40c13fe5b4 | |
| | 2a3fbc2cc2 | |
| | 089cf798eb | |
| | cbc2ec8f44 | |
| | b5f591fea8 | |
| | 05b38c3c0d | |
| | 8f7fde5701 | |
| | a59672655b | |
| | 9aca79f2b8 | |
| | bbcf2a8589 | |
| | 4cfb2164fb | |
| | c977966502 | |
| | 1ca0a75567 | |
| | c1e6a32ae4 | |
| | 77b2162817 | |
| | 4e66513a74 | |
@@ -1,4 +1,7 @@
-name: Release Fast GPU Tests on main
+# Duplicate workflow to push_tests.yml that is meant to run on release/patch branches as a final check
+# Creating a duplicate workflow here is simpler than adding complex path/branch parsing logic to push_tests.yml
+# Needs to be updated if push_tests.yml updated
+name: (Release) Fast GPU Tests on main
 
 on:
   push:
@@ -226,6 +226,8 @@
   - sections:
     - local: api/models/controlnet
       title: ControlNetModel
+    - local: api/models/controlnet_flux
+      title: FluxControlNetModel
     - local: api/models/controlnet_hunyuandit
       title: HunyuanDiT2DControlNetModel
     - local: api/models/controlnet_sd3
@@ -320,6 +322,8 @@
       title: Consistency Models
     - local: api/pipelines/controlnet
       title: ControlNet
+    - local: api/pipelines/controlnet_flux
+      title: ControlNet with Flux.1
     - local: api/pipelines/controlnet_hunyuandit
       title: ControlNet with Hunyuan-DiT
     - local: api/pipelines/controlnet_sd3
docs/source/en/api/models/controlnet_flux.md (new file, 45 lines)
@@ -0,0 +1,45 @@
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# FluxControlNetModel

FluxControlNetModel is an implementation of ControlNet for Flux.1.

The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.

The abstract from the paper is:

*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*

## Loading from the original format

By default, [`FluxControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`]. A single ControlNet can be passed to the pipeline directly, or one or more can be wrapped in a [`FluxMultiControlNetModel`]:

```py
from diffusers import FluxControlNetPipeline
from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel

controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny")
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)

controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny")
controlnet = FluxMultiControlNetModel([controlnet])
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)
```

## FluxControlNetModel

[[autodoc]] FluxControlNetModel

## FluxControlNetOutput

[[autodoc]] models.controlnet_flux.FluxControlNetOutput
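Beyond `from_pretrained`, the model added in this PR also exposes a `FluxControlNetModel.from_transformer` classmethod (see the `controlnet_flux.py` changes further down), which initializes a ControlNet from the base Flux transformer and optionally copies the matching weights. A minimal sketch, assuming local access to the `black-forest-labs/FLUX.1-dev` transformer and using the argument names shown in the diff:

```py
import torch

from diffusers import FluxTransformer2DModel
from diffusers.models import FluxControlNetModel

# Load the base Flux transformer (assumes you have access to black-forest-labs/FLUX.1-dev).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Initialize a ControlNet with 4 double blocks and 10 single blocks,
# copying the overlapping weights from the transformer (the default in the diff).
controlnet = FluxControlNetModel.from_transformer(
    transformer,
    num_layers=4,
    num_single_layers=10,
    attention_head_dim=128,
    num_attention_heads=24,
    load_weights_from_transformer=True,
)
print(sum(p.numel() for p in controlnet.parameters()) / 1e9, "B parameters")
```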
docs/source/en/api/pipelines/controlnet_flux.md (new file, 48 lines)
@@ -0,0 +1,48 @@
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ControlNet with Flux.1

FluxControlNetPipeline is an implementation of ControlNet for Flux.1.

ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.

With a ControlNet model, you can provide an additional control image to condition and control the Flux.1 generation. For example, if you provide a depth map, the ControlNet model generates an image that preserves the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.

The abstract from the paper is:

*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*

This ControlNet code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for Flux-ControlNet in the table below:

| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny) |
| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Depth) |
| Union | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union) |

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

## FluxControlNetPipeline
[[autodoc]] FluxControlNetPipeline
  - all
  - __call__

## FluxPipelineOutput
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
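For orientation, a usage sketch (not part of the file added in this PR): it assumes the pipeline call accepts `control_image` and `controlnet_conditioning_scale` like the other ControlNet pipelines, and the Canny control-image path is a placeholder you must supply yourself.

```py
import torch

from diffusers import FluxControlNetPipeline
from diffusers.models import FluxControlNetModel
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Placeholder: any pre-computed Canny edge map at the target resolution works here.
control_image = load_image("path/to/canny_edge_map.png")

image = pipe(
    prompt="a futuristic city street at night, neon lights",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_controlnet_canny.png")
```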
@@ -30,63 +30,64 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
|
||||
| Pipeline | Tasks |
|
||||
|---|---|
|
||||
| [AltDiffusion](alt_diffusion) | image2image |
|
||||
| [aMUSEd](amused) | text2image |
|
||||
| [AnimateDiff](animatediff) | text2video |
|
||||
| [Attend-and-Excite](attend_and_excite) | text2image |
|
||||
| [Audio Diffusion](audio_diffusion) | image2audio |
|
||||
| [AudioLDM](audioldm) | text2audio |
|
||||
| [AudioLDM2](audioldm2) | text2audio |
|
||||
| [AuraFlow](auraflow) | text2image |
|
||||
| [BLIP Diffusion](blip_diffusion) | text2image |
|
||||
| [CogVideoX](cogvideox) | text2video |
|
||||
| [Consistency Models](consistency_models) | unconditional image generation |
|
||||
| [ControlNet](controlnet) | text2image, image2image, inpainting |
|
||||
| [ControlNet with Flux.1](controlnet_flux) | text2image |
|
||||
| [ControlNet with Hunyuan-DiT](controlnet_hunyuandit) | text2image |
|
||||
| [ControlNet with Stable Diffusion 3](controlnet_sd3) | text2image |
|
||||
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
|
||||
| [ControlNet-XS](controlnetxs) | text2image |
|
||||
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
|
||||
| [Cycle Diffusion](cycle_diffusion) | image2image |
|
||||
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
|
||||
| [DDIM](ddim) | unconditional image generation |
|
||||
| [DDPM](ddpm) | unconditional image generation |
|
||||
| [DeepFloyd IF](deepfloyd_if) | text2image, image2image, inpainting, super-resolution |
|
||||
| [DiffEdit](diffedit) | inpainting |
|
||||
| [DiT](dit) | text2image |
|
||||
| [GLIGEN](stable_diffusion/gligen) | text2image |
|
||||
| [Flux](flux) | text2image |
|
||||
| [Hunyuan-DiT](hunyuandit) | text2image |
|
||||
| [I2VGen-XL](i2vgenxl) | text2video |
|
||||
| [InstructPix2Pix](pix2pix) | image editing |
|
||||
| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
|
||||
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
|
||||
| [Kandinsky 3](kandinsky3) | text2image, image2image |
|
||||
| [Kolors](kolors) | text2image |
|
||||
| [Latent Consistency Models](latent_consistency_models) | text2image |
|
||||
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
|
||||
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
|
||||
| [Latte](latte) | text2image |
|
||||
| [LEDITS++](ledits_pp) | image editing |
|
||||
| [Lumina-T2X](lumina) | text2image |
|
||||
| [Marigold](marigold) | depth |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
| [MusicLDM](musicldm) | text2audio |
|
||||
| [PAG](pag) | text2image |
|
||||
| [Paint by Example](paint_by_example) | inpainting |
|
||||
| [ParaDiGMS](paradigms) | text2image |
|
||||
| [Pix2Pix Zero](pix2pix_zero) | image editing |
|
||||
| [PIA](pia) | image2video |
|
||||
| [PixArt-α](pixart) | text2image |
|
||||
| [PNDM](pndm) | unconditional image generation |
|
||||
| [RePaint](repaint) | inpainting |
|
||||
| [Score SDE VE](score_sde_ve) | unconditional image generation |
|
||||
| [PixArt-Σ](pixart_sigma) | text2image |
|
||||
| [Self-Attention Guidance](self_attention_guidance) | text2image |
|
||||
| [Semantic Guidance](semantic_stable_diffusion) | text2image |
|
||||
| [Shap-E](shap_e) | text-to-3D, image-to-3D |
|
||||
| [Spectrogram Diffusion](spectrogram_diffusion) | |
|
||||
| [Stable Audio](stable_audio) | text2audio |
|
||||
| [Stable Cascade](stable_cascade) | text2image |
|
||||
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
|
||||
| [Stable Diffusion Model Editing](model_editing) | model editing |
|
||||
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
|
||||
| [Stable Diffusion XL Turbo](stable_diffusion/sdxl_turbo) | text2image, image2image, inpainting |
|
||||
| [Stable unCLIP](stable_unclip) | text2image, image variation |
|
||||
| [Stochastic Karras VE](stochastic_karras_ve) | unconditional image generation |
|
||||
| [T2I-Adapter](stable_diffusion/adapter) | text2image |
|
||||
| [Text2Video](text_to_video) | text2video, video2video |
|
||||
| [Text2Video-Zero](text_to_video_zero) | text2video |
|
||||
| [unCLIP](unclip) | text2image, image variation |
|
||||
| [Unconditional Latent Diffusion](latent_diffusion_uncond) | unconditional image generation |
|
||||
| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
|
||||
| [Value-guided planning](value_guided_sampling) | value guided sampling |
|
||||
| [Versatile Diffusion](versatile_diffusion) | text2image, image variation |
|
||||
| [VQ Diffusion](vq_diffusion) | text2image |
|
||||
| [Wuerstchen](wuerstchen) | text2image |
|
||||
|
||||
## DiffusionPipeline
|
||||
|
||||
@@ -314,11 +314,12 @@ def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_di
|
||||
for x, y in zip(modifier_token_id, args.modifier_token):
|
||||
learned_embeds_dict = {}
|
||||
learned_embeds_dict[y] = learned_embeds[x]
|
||||
filename = f"{output_dir}/{y}.bin"
|
||||
|
||||
if safe_serialization:
|
||||
filename = f"{output_dir}/{y}.safetensors"
|
||||
safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
|
||||
else:
|
||||
filename = f"{output_dir}/{y}.bin"
|
||||
torch.save(learned_embeds_dict, filename)
|
||||
|
||||
|
||||
@@ -1040,17 +1041,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
@@ -1065,8 +1071,14 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
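The intent of the change above is that the learning-rate scheduler is created before `accelerator.prepare`, so its warmup and total step counts must be expressed as per-process optimizer steps multiplied by `accelerator.num_processes`, computed from the dataloader length after sharding. A small arithmetic sketch with made-up numbers (not taken from the script):

```py
import math

# Hypothetical run configuration.
dataset_batches = 10_000          # len(train_dataloader) before sharding
num_processes = 8                 # accelerator.num_processes
gradient_accumulation_steps = 4
num_train_epochs = 3
lr_warmup_steps = 500

# After accelerator.prepare, each process sees roughly 1/num_processes of the batches.
len_after_sharding = math.ceil(dataset_batches / num_processes)                   # 1250
updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)   # 313

# The scheduler is stepped once per process per optimizer step, so both counts
# are multiplied by num_processes when the scheduler is built ahead of prepare().
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes                          # 4000
num_training_steps_for_scheduler = num_train_epochs * updates_per_epoch * num_processes   # 7512

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)
```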
|
||||
|
||||
|
||||
@@ -842,7 +842,7 @@ class PromptDataset(Dataset):
|
||||
return example
|
||||
|
||||
|
||||
def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
|
||||
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -863,20 +863,26 @@ def _encode_prompt_with_t5(
|
||||
prompt=None,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
text_input_ids=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
if tokenizer is not None:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
else:
|
||||
if text_input_ids is None:
|
||||
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
@@ -896,22 +902,28 @@ def _encode_prompt_with_clip(
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
device=None,
|
||||
text_input_ids=None,
|
||||
num_images_per_prompt: int = 1,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_length=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if tokenizer is not None:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_length=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
else:
|
||||
if text_input_ids is None:
|
||||
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
# Use pooled output of CLIPTextModel
|
||||
@@ -932,17 +944,19 @@ def encode_prompt(
|
||||
max_sequence_length,
|
||||
device=None,
|
||||
num_images_per_prompt: int = 1,
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
device = device if device is not None else text_encoders[1].device
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoders[0],
|
||||
tokenizer=tokenizers[0],
|
||||
prompt=prompt,
|
||||
device=device if device is not None else text_encoders[0].device,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
prompt_embeds = _encode_prompt_with_t5(
|
||||
@@ -951,7 +965,8 @@ def encode_prompt(
|
||||
max_sequence_length=max_sequence_length,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device if device is not None else text_encoders[1].device,
|
||||
device=device,
|
||||
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
@@ -1499,7 +1514,25 @@ def main(args):
|
||||
)
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
|
||||
tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
|
||||
tokens_two = tokenize_prompt(
|
||||
tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
|
||||
)
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
prompt=prompts,
|
||||
)
|
||||
else:
|
||||
if args.train_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
prompt=args.instance_prompt,
|
||||
)
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
@@ -1553,41 +1586,22 @@ def main(args):
|
||||
guidance = None
|
||||
|
||||
# Predict the noise residual
|
||||
if not args.train_text_encoder:
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
prompt=None,
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
)
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
|
||||
timestep=timesteps / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
|
||||
model_pred = FluxPipeline._unpack_latents(
|
||||
model_pred,
|
||||
height=int(model_input.shape[2]),
|
||||
width=int(model_input.shape[3]),
|
||||
height=int(model_input.shape[2] * vae_scale_factor / 2),
|
||||
width=int(model_input.shape[3] * vae_scale_factor / 2),
|
||||
vae_scale_factor=vae_scale_factor,
|
||||
)
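The corrected `height`/`width` above are pixel-space sizes: `FluxPipeline._unpack_latents` expects image dimensions, while `model_input` is a VAE latent. Assuming the script follows the Flux convention `vae_scale_factor = 2 ** len(vae.config.block_out_channels)` (16 for the Flux VAE, twice the actual 8x spatial downsampling), the pixel size is `latent_size * vae_scale_factor / 2`. A small arithmetic sketch (the resolution and names are illustrative, not from the script):

```py
# Shape bookkeeping behind the height/width fix above, for a 1024x1024 training resolution.
pixel_height = pixel_width = 1024
latent_height = pixel_height // 8      # the VAE downsamples by 8 -> 128
latent_width = pixel_width // 8        # -> 128
vae_scale_factor = 16                  # 2 ** len(vae.config.block_out_channels)

# _pack_latents groups 2x2 latent patches into tokens, so the packed sequence
# length is (latent_height / 2) * (latent_width / 2) = 64 * 64 = 4096.
seq_len = (latent_height // 2) * (latent_width // 2)

# _unpack_latents wants pixel-space height/width; recover them from the latent shape.
height = int(latent_height * vae_scale_factor / 2)   # 128 * 16 / 2 = 1024
width = int(latent_width * vae_scale_factor / 2)     # 1024
assert (height, width) == (pixel_height, pixel_width)
print(seq_len, height, width)
```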
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ else:
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -554,6 +555,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
|
||||
@@ -208,6 +208,8 @@ class IPAdapterMixin:
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import is_peft_version, logging
|
||||
|
||||
|
||||
@@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
||||
prefix = "text_encoder_2."
|
||||
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
||||
return {new_name: alpha}
|
||||
|
||||
|
||||
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
|
||||
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
# All credits go to `kohya-ss`.
|
||||
def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
return
|
||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||
|
||||
# scale weight by alpha and dim
|
||||
rank = down_weight.shape[0]
|
||||
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
|
||||
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
|
||||
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
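The rescaling above bakes kohya's runtime `alpha / rank` factor into the stored matrices; however the factor is distributed between `lora_A` and `lora_B`, the effective delta-weight `lora_B @ lora_A` must equal `scale * (up @ down)`. A standalone sanity-check sketch with random weights (not an actual checkpoint); note that with the loop as written, the whole factor ends up on `lora_A`:

```py
import torch

rank, in_features, out_features = 16, 64, 64
alpha = 8.0
scale = alpha / rank                    # 0.5, applied by kohya in the forward pass

down = torch.randn(rank, in_features)   # lora_down.weight
up = torch.randn(out_features, rank)    # lora_up.weight

# Same split as _convert_to_ai_toolkit: here scale_down == scale and scale_up == 1.0.
scale_down, scale_up = scale, 1.0
lora_A = down * scale_down
lora_B = up * scale_up

# The effective delta-W is identical to kohya's scaled product.
assert torch.allclose(lora_B @ lora_A, scale * (up @ down), atol=1e-5)
```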
|
||||
|
||||
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
return
|
||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
|
||||
sd_lora_rank = down_weight.shape[0]
|
||||
|
||||
# scale weight by alpha and dim
|
||||
alpha = sds_sd.pop(sds_key + ".alpha")
|
||||
scale = alpha / sd_lora_rank
|
||||
|
||||
# calculate scale_down and scale_up
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
down_weight = down_weight * scale_down
|
||||
up_weight = up_weight * scale_up
|
||||
|
||||
# calculate dims if not provided
|
||||
num_splits = len(ait_keys)
|
||||
if dims is None:
|
||||
dims = [up_weight.shape[0] // num_splits] * num_splits
|
||||
else:
|
||||
assert sum(dims) == up_weight.shape[0]
|
||||
|
||||
# check upweight is sparse or not
|
||||
is_sparse = False
|
||||
if sd_lora_rank % num_splits == 0:
|
||||
ait_rank = sd_lora_rank // num_splits
|
||||
is_sparse = True
|
||||
i = 0
|
||||
for j in range(len(dims)):
|
||||
for k in range(len(dims)):
|
||||
if j == k:
|
||||
continue
|
||||
is_sparse = is_sparse and torch.all(
|
||||
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
|
||||
)
|
||||
i += dims[j]
|
||||
if is_sparse:
|
||||
logger.info(f"weight is sparse: {sds_key}")
|
||||
|
||||
# make ai-toolkit weight
|
||||
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
if not is_sparse:
|
||||
# down_weight is copied to each split
|
||||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
else:
|
||||
# down_weight is chunked to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
|
||||
|
||||
# up_weight is sparse: only non-zero values are copied to each split
|
||||
i = 0
|
||||
for j in range(len(dims)):
|
||||
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
|
||||
i += dims[j]
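For the fused projections, kohya stores a single LoRA whose `lora_up` rows cover q, k, v (and, for single blocks, the MLP) back to back; in the non-sparse branch the shared `lora_down` is duplicated and `lora_up` is split row-wise. A toy sketch with made-up dimensions and hypothetical key names:

```py
import torch

rank = 4
dims = [8, 8, 8]                                 # output sizes of to_q, to_k, to_v in this toy example
down_weight = torch.randn(rank, 16)              # shared lora_A of the fused projection
up_weight = torch.randn(sum(dims), rank)         # fused lora_B, rows ordered q | k | v

ait_keys = ["attn.to_q", "attn.to_k", "attn.to_v"]
ait_sd = {}

# Non-sparse branch: lora_A is duplicated, lora_B is split row-wise per projection.
ait_sd.update({k + ".lora_A.weight": down_weight for k in ait_keys})
ait_sd.update(
    {k + ".lora_B.weight": v for k, v in zip(ait_keys, torch.split(up_weight, dims, dim=0))}
)

for k in ait_keys:
    print(k, ait_sd[k + ".lora_A.weight"].shape, ait_sd[k + ".lora_B.weight"].shape)
```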
|
||||
|
||||
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
|
||||
ait_sd = {}
|
||||
for i in range(19):
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_attn_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_out.0",
|
||||
)
|
||||
_convert_to_ai_toolkit_cat(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_attn_qkv",
|
||||
[
|
||||
f"transformer.transformer_blocks.{i}.attn.to_q",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_v",
|
||||
],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mlp_0",
|
||||
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mlp_2",
|
||||
f"transformer.transformer_blocks.{i}.ff.net.2",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mod_lin",
|
||||
f"transformer.transformer_blocks.{i}.norm1.linear",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_attn_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_add_out",
|
||||
)
|
||||
_convert_to_ai_toolkit_cat(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
|
||||
[
|
||||
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
|
||||
],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mlp_0",
|
||||
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mlp_2",
|
||||
f"transformer.transformer_blocks.{i}.ff_context.net.2",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mod_lin",
|
||||
f"transformer.transformer_blocks.{i}.norm1_context.linear",
|
||||
)
|
||||
|
||||
for i in range(38):
|
||||
_convert_to_ai_toolkit_cat(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_linear1",
|
||||
[
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_q",
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_v",
|
||||
f"transformer.single_transformer_blocks.{i}.proj_mlp",
|
||||
],
|
||||
dims=[3072, 3072, 3072, 12288],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_linear2",
|
||||
f"transformer.single_transformer_blocks.{i}.proj_out",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_modulation_lin",
|
||||
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
||||
)
|
||||
|
||||
if len(sds_sd) > 0:
|
||||
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
|
||||
|
||||
return ait_sd
|
||||
|
||||
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
||||
|
||||
|
||||
# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
|
||||
# Some utilities were reused from
|
||||
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
new_state_dict = {}
|
||||
orig_keys = list(old_state_dict.keys())
|
||||
|
||||
def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
||||
down_weight = sds_sd.pop(sds_key)
|
||||
up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
|
||||
|
||||
# calculate dims if not provided
|
||||
num_splits = len(ait_keys)
|
||||
if dims is None:
|
||||
dims = [up_weight.shape[0] // num_splits] * num_splits
|
||||
else:
|
||||
assert sum(dims) == up_weight.shape[0]
|
||||
|
||||
# make ai-toolkit weight
|
||||
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
|
||||
# down_weight is copied to each split
|
||||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
|
||||
for old_key in orig_keys:
|
||||
# Handle double_blocks
|
||||
if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
|
||||
block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
|
||||
new_key = f"transformer.transformer_blocks.{block_num}"
|
||||
|
||||
if "processor.proj_lora1" in old_key:
|
||||
new_key += ".attn.to_out.0"
|
||||
elif "processor.proj_lora2" in old_key:
|
||||
new_key += ".attn.to_add_out"
|
||||
elif "processor.qkv_lora1" in old_key and "up" not in old_key:
|
||||
handle_qkv(
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[
|
||||
f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
|
||||
],
|
||||
)
|
||||
# continue
|
||||
elif "processor.qkv_lora2" in old_key and "up" not in old_key:
|
||||
handle_qkv(
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[
|
||||
f"transformer.transformer_blocks.{block_num}.attn.to_q",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.to_k",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.to_v",
|
||||
],
|
||||
)
|
||||
# continue
|
||||
|
||||
if "down" in old_key:
|
||||
new_key += ".lora_A.weight"
|
||||
elif "up" in old_key:
|
||||
new_key += ".lora_B.weight"
|
||||
|
||||
# Handle single_blocks
|
||||
elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
|
||||
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
|
||||
new_key = f"transformer.single_transformer_blocks.{block_num}"
|
||||
|
||||
if "proj_lora1" in old_key or "proj_lora2" in old_key:
|
||||
new_key += ".proj_out"
|
||||
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
|
||||
new_key += ".norm.linear"
|
||||
|
||||
if "down" in old_key:
|
||||
new_key += ".lora_A.weight"
|
||||
elif "up" in old_key:
|
||||
new_key += ".lora_B.weight"
|
||||
|
||||
else:
|
||||
# Handle other potential key patterns here
|
||||
new_key = old_key
|
||||
|
||||
# Since we already handle qkv above.
|
||||
if "qkv" not in old_key:
|
||||
new_state_dict[new_key] = old_state_dict.pop(old_key)
|
||||
|
||||
if len(old_state_dict) > 0:
|
||||
raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@@ -31,7 +31,12 @@ from ..utils import (
|
||||
scale_lora_layers,
|
||||
)
|
||||
from .lora_base import LoraBaseMixin
|
||||
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
||||
from .lora_conversion_utils import (
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -1583,6 +1588,20 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
|
||||
|
||||
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
||||
if is_kohya:
|
||||
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
|
||||
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
||||
return (state_dict, None) if return_alphas else state_dict
|
||||
|
||||
is_xlabs = any("processor" in k for k in state_dict)
|
||||
if is_xlabs:
|
||||
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
|
||||
# xlabs doesn't use `alpha`.
|
||||
return (state_dict, None) if return_alphas else state_dict
|
||||
|
||||
# For state dicts like
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
@@ -35,7 +35,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_flux"] = ["FluxControlNetModel"]
|
||||
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
|
||||
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
||||
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
@@ -88,7 +88,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_flux import FluxControlNetModel
|
||||
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from .controlnet_sparsectrl import SparseControlNetModel
|
||||
|
||||
@@ -972,15 +972,32 @@ class FreeNoiseTransformerBlock(nn.Module):
|
||||
return frame_indices
|
||||
|
||||
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
|
||||
if weighting_scheme == "pyramid":
|
||||
if weighting_scheme == "flat":
|
||||
weights = [1.0] * num_frames
|
||||
|
||||
elif weighting_scheme == "pyramid":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [1, 2, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
mid = num_frames // 2
|
||||
weights = list(range(1, mid + 1))
|
||||
weights = weights + weights[::-1]
|
||||
else:
|
||||
# num_frames = 5 => [1, 2, 3, 2, 1]
|
||||
weights = list(range(1, num_frames // 2 + 1))
|
||||
weights = weights + [num_frames // 2 + 1] + weights[::-1]
|
||||
mid = (num_frames + 1) // 2
|
||||
weights = list(range(1, mid))
|
||||
weights = weights + [mid] + weights[::-1]
|
||||
|
||||
elif weighting_scheme == "delayed_reverse_sawtooth":
|
||||
if num_frames % 2 == 0:
|
||||
# num_frames = 4 => [0.01, 2, 2, 1]
|
||||
mid = num_frames // 2
|
||||
weights = [0.01] * (mid - 1) + [mid]
|
||||
weights = weights + list(range(mid, 0, -1))
|
||||
else:
|
||||
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
|
||||
mid = (num_frames + 1) // 2
|
||||
weights = [0.01] * mid
|
||||
weights = weights + list(range(mid, 0, -1))
|
||||
else:
|
||||
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
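A minimal standalone sketch of the `flat` and `pyramid` branches above (the `delayed_reverse_sawtooth` branch is omitted here), handy for sanity-checking the patterns given in the comments:

```py
from typing import List


def get_frame_weights(num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
    # Mirrors the "flat" and "pyramid" branches of _get_frame_weights above.
    if weighting_scheme == "flat":
        return [1.0] * num_frames
    if weighting_scheme == "pyramid":
        if num_frames % 2 == 0:
            mid = num_frames // 2
            half = list(range(1, mid + 1))
            return half + half[::-1]
        mid = (num_frames + 1) // 2
        half = list(range(1, mid))
        return half + [mid] + half[::-1]
    raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")


print(get_frame_weights(4))  # [1, 2, 2, 1]
print(get_frame_weights(5))  # [1, 2, 3, 2, 1]
```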
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
@@ -101,6 +102,10 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.union = num_mode is not None
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
@@ -173,8 +178,8 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers=4,
|
||||
num_single_layers=10,
|
||||
num_layers: int = 4,
|
||||
num_single_layers: int = 10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
@@ -205,6 +210,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
controlnet_mode: torch.Tensor = None,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
@@ -221,6 +227,12 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
controlnet_mode (`torch.Tensor`):
|
||||
The mode tensor of shape `(batch_size, 1)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
@@ -272,6 +284,15 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if self.union:
|
||||
# union mode
|
||||
if controlnet_mode is None:
|
||||
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
|
||||
# union mode emb
|
||||
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
|
||||
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
|
||||
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
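In union mode the ControlNet prepends one learned mode token to the text stream, so the rotary position ids for the text tokens must also gain one row. A toy shape sketch with made-up sizes:

```py
import torch
import torch.nn as nn

batch, seq_len, inner_dim, num_mode = 2, 512, 3072, 10

controlnet_mode_embedder = nn.Embedding(num_mode, inner_dim)
encoder_hidden_states = torch.randn(batch, seq_len, inner_dim)
txt_ids = torch.zeros(seq_len, 3)              # 2D ids, as the deprecation warning above requests
controlnet_mode = torch.tensor([[3], [3]])     # shape (batch, 1)

mode_emb = controlnet_mode_embedder(controlnet_mode)                         # (batch, 1, inner_dim)
encoder_hidden_states = torch.cat([mode_emb, encoder_hidden_states], dim=1)  # (batch, seq_len + 1, inner_dim)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)                           # (seq_len + 1, 3)

print(encoder_hidden_states.shape, txt_ids.shape)
```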
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
@@ -367,7 +388,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
#
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
@@ -384,3 +404,114 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(ModelMixin):
|
||||
r"""
|
||||
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
|
||||
|
||||
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
|
||||
compatible with `FluxControlNetModel`.
|
||||
|
||||
Args:
|
||||
controlnets (`List[FluxControlNetModel]`):
|
||||
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
||||
`FluxControlNetModel` as a list.
|
||||
"""
|
||||
|
||||
def __init__(self, controlnets):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
controlnet_mode: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FluxControlNetOutput, Tuple]:
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only one ControlNet is loaded, to save memory
|
||||
if len(self.nets) == 1 and self.nets[0].union:
|
||||
controlnet = self.nets[0]
|
||||
|
||||
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
# Regular Multi-ControlNets
|
||||
# load all ControlNets into memory
|
||||
else:
|
||||
for i, (image, mode, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
|
||||
):
|
||||
block_samples, single_block_samples = controlnet(
|
||||
hidden_states=hidden_states,
|
||||
controlnet_cond=image,
|
||||
controlnet_mode=mode[:, None],
|
||||
conditioning_scale=scale,
|
||||
timestep=timestep,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_projections,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
txt_ids=txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
|
||||
@@ -691,7 +691,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
emb = emb.repeat_interleave(sample_num_frames, dim=0)
|
||||
encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
|
||||
|
||||
# 2. pre-process
|
||||
batch_size, channels, num_frames, height, width = sample.shape
|
||||
|
||||
@@ -391,15 +391,16 @@ def get_3d_rotary_pos_embed(
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
if use_real is not True:
|
||||
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
@@ -408,54 +409,37 @@ def get_3d_rotary_pos_embed(
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
|
||||
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
|
||||
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
|
||||
freqs_t = freqs_t[:, None, None, :].expand(
|
||||
-1, grid_size_h, grid_size_w, -1
|
||||
) # temporal_size, grid_size_h, grid_size_w, dim_t
|
||||
freqs_h = freqs_h[None, :, None, :].expand(
|
||||
temporal_size, -1, grid_size_w, -1
|
||||
) # temporal_size, grid_size_h, grid_size_w, dim_h
|
||||
freqs_w = freqs_w[None, None, :, :].expand(
|
||||
temporal_size, grid_size_h, -1, -1
|
||||
) # temporal_size, grid_size_h, grid_size_w, dim_w
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
freqs = torch.cat(
|
||||
[freqs_t, freqs_h, freqs_w], dim=-1
|
||||
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
|
||||
freqs = freqs.view(
|
||||
temporal_size * grid_size_h * grid_size_w, -1
|
||||
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
|
||||
return freqs
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
|
||||
h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
|
||||
w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
|
||||
cos = combine_time_height_width(t_cos, h_cos, w_cos)
|
||||
sin = combine_time_height_width(t_sin, h_sin, w_sin)
|
||||
return cos, sin
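The refactor builds the 3D rotary table from three 1D tables, and `combine_time_height_width` is pure shape bookkeeping: broadcast each per-axis table over the other two axes, concatenate along the feature dimension, then flatten the grid. A toy sketch with small made-up sizes:

```py
import torch

temporal_size, grid_h, grid_w = 3, 4, 5
dim_t, dim_h, dim_w = 8, 6, 6   # toy sizes; the real code splits embed_dim across the three axes

t_cos = torch.randn(temporal_size, dim_t)   # per-frame table
h_cos = torch.randn(grid_h, dim_h)          # per-row table
w_cos = torch.randn(grid_w, dim_w)          # per-column table

# Broadcast each table over the other two axes, then concatenate along the feature dim.
t = t_cos[:, None, None, :].expand(-1, grid_h, grid_w, -1)
h = h_cos[None, :, None, :].expand(temporal_size, -1, grid_w, -1)
w = w_cos[None, None, :, :].expand(temporal_size, grid_h, -1, -1)

cos = torch.cat([t, h, w], dim=-1).view(temporal_size * grid_h * grid_w, -1)
print(cos.shape)  # torch.Size([60, 20]) -> (T * H * W, dim_t + dim_h + dim_w)
```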
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
@@ -530,7 +514,7 @@ def get_1d_rotary_pos_embed(
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
|
||||
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
@@ -561,21 +545,30 @@ def get_1d_rotary_pos_embed(
|
||||
assert dim % 2 == 0
|
||||
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
pos = torch.arange(pos)
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos) # type: ignore # [S]
|
||||
|
||||
theta = theta * ntk_factor
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
||||
if use_real and repeat_interleave_real:
|
||||
# flux, hunyuan-dit, cogvideox
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# stable audio
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
|
||||
# lumina
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
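For reference, the heart of `get_1d_rotary_pos_embed` is an outer product of positions with inverse frequencies; the three return layouts only differ in how the half-size table is expanded. A condensed sketch with default arguments (no NTK or linear scaling):

```py
import torch

dim, positions = 8, torch.arange(6, dtype=torch.float32)   # toy sizes
theta = 10000.0

inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))  # [dim/2]
freqs = torch.outer(positions, inv_freq)                                          # [S, dim/2]

# Layout used by flux, hunyuan-dit, cogvideox: interleave each frequency twice.
cos_interleaved = freqs.cos().repeat_interleave(2, dim=1).float()   # [S, dim]

# Layout used by stable audio: concatenate the two half-tables instead.
cos_concat = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, dim]

# Complex form used by lumina.
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)              # complex, [S, dim/2]

print(cos_interleaved.shape, cos_concat.shape, freqs_cis.shape)
```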


@@ -606,11 +599,11 @@ def apply_rotary_emb(
cos, sin = cos.to(x.device), sin.to(x.device)

if use_real_unbind_dim == -1:
# Use for example in Lumina
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Use for example in Stable Audio
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
@@ -620,6 +613,7 @@ def apply_rotary_emb(

return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
@@ -638,7 +632,7 @@ class FluxPosEmbed(nn.Module):
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.squeeze().float().cpu().numpy()
pos = ids.squeeze().float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):

@@ -38,6 +38,7 @@ from ..modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
from torch.profiler import record_function


@maybe_allow_in_graph
@@ -439,109 +440,114 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)

for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
with record_function(" x_embedder"):
hidden_states = self.x_embedder(hidden_states)

with record_function(" time_text_embed"):
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

with record_function(" pos_embeds (rotary)"):
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)

with record_function(" blocks"):
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:

# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
with record_function(" single blocks"):
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward
return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)

else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)

# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)

hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
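Illustration of the `interval_control` arithmetic used when spreading ControlNet residuals across the transformer blocks above (block and sample counts are hypothetical, chosen only to show the index mapping):

```python
import numpy as np

# Hypothetical sizes, for illustration only.
num_transformer_blocks = 19      # number of double-stream blocks
num_controlnet_samples = 4       # residual tensors returned by the ControlNet

interval_control = int(np.ceil(num_transformer_blocks / num_controlnet_samples))  # -> 5
mapping = {i: i // interval_control for i in range(num_transformer_blocks)}
# blocks 0-4 reuse sample 0, blocks 5-9 sample 1, 10-14 sample 2, 15-18 sample 3
print(interval_control, mapping)
```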


@@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module):

self.in_channels = in_channels

self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)

# 3. Define transformers blocks
@@ -2178,7 +2178,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft

emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:

@@ -432,7 +432,6 @@ class AnimateDiffPipeline(
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -470,8 +469,8 @@ class AnimateDiffPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")

if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -557,11 +556,15 @@ class AnimateDiffPipeline(
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt: Optional[Union[str, List[str]]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
@@ -701,9 +704,10 @@ class AnimateDiffPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -716,22 +720,39 @@ class AnimateDiffPipeline(
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ class AnimateDiffPipeline(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -505,8 +505,8 @@ class AnimateDiffControlNetPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -699,6 +699,10 @@ class AnimateDiffControlNetPipeline(
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
def __call__(
self,
@@ -858,9 +862,10 @@ class AnimateDiffControlNetPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -883,22 +888,39 @@ class AnimateDiffControlNetPipeline(
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ class AnimateDiffControlNetPipeline(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ class AnimateDiffControlNetPipeline(
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]

@@ -1143,6 +1143,8 @@ class AnimateDiffSDXLPipeline(
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)

@@ -878,6 +878,8 @@ class AnimateDiffSparseControlNetPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

# 4. Prepare IP-Adapter embeddings
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(

@@ -246,7 +246,6 @@ class AnimateDiffVideoToVideoPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
@@ -299,7 +298,7 @@ class AnimateDiffVideoToVideoPipeline(
else:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -582,8 +581,8 @@ class AnimateDiffVideoToVideoPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")

if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -628,23 +627,20 @@ class AnimateDiffVideoToVideoPipeline(

def prepare_latents(
self,
video,
height,
width,
num_channels_latents,
batch_size,
timestep,
dtype,
device,
generator,
latents=None,
video: Optional[torch.Tensor] = None,
height: int = 64,
width: int = 64,
num_channels_latents: int = 4,
batch_size: int = 1,
timestep: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
decode_chunk_size: int = 16,
):
if latents is None:
num_frames = video.shape[1]
else:
num_frames = latents.shape[2]

add_noise: bool = False,
) -> torch.Tensor:
num_frames = video.shape[1] if latents is None else latents.shape[2]
shape = (
batch_size,
num_channels_latents,
@@ -708,8 +704,13 @@ class AnimateDiffVideoToVideoPipeline(
if shape != latents.shape:
# [B, C, F, H, W]
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")

latents = latents.to(device, dtype=dtype)

if add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(latents, noise, timestep)

return latents

@property
@@ -735,6 +736,10 @@ class AnimateDiffVideoToVideoPipeline(
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
def __call__(
self,
@@ -743,6 +748,7 @@ class AnimateDiffVideoToVideoPipeline(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
enforce_inference_steps: bool = False,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.5,
@@ -874,9 +880,10 @@ class AnimateDiffVideoToVideoPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -884,29 +891,85 @@ class AnimateDiffVideoToVideoPipeline(
batch_size = prompt_embeds.shape[0]

device = self._execution_device
dtype = self.dtype

# 3. Encode input prompt
# 3. Prepare timesteps
if not enforce_inference_steps:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
else:
denoising_inference_steps = int(num_inference_steps / strength)
timesteps, denoising_inference_steps = retrieve_timesteps(
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
)
timesteps = timesteps[-num_inference_steps:]
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
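A small worked example of the `enforce_inference_steps` branch above (numbers are illustrative only, not from the diff): the scheduler is over-provisioned by a factor of `1 / strength` and only the tail of the timestep list is kept, so exactly `num_inference_steps` iterations run.

```python
# Illustrative arithmetic for enforce_inference_steps=True.
num_inference_steps = 20
strength = 0.5

denoising_inference_steps = int(num_inference_steps / strength)   # 40 scheduler steps
# stand-in for the scheduler's timestep list, just to show the slicing
all_timesteps = list(range(denoising_inference_steps))
kept_timesteps = all_timesteps[-num_inference_steps:]              # last 20 steps only
print(denoising_inference_steps, len(kept_timesteps))              # 40 20
```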

# 4. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
height=height,
width=width,
num_channels_latents=num_channels_latents,
batch_size=batch_size * num_videos_per_prompt,
timestep=latent_timestep,
dtype=dtype,
device=device,
generator=generator,
latents=latents,
decode_chunk_size=decode_chunk_size,
add_noise=enforce_inference_steps,
)

# 5. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
num_frames = latents.shape[2]
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

# 6. Prepare IP-Adapter embeddings
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
@@ -916,38 +979,10 @@ class AnimateDiffVideoToVideoPipeline(
self.do_classifier_free_guidance,
)

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

# 5. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
height=height,
width=width,
num_channels_latents=num_channels_latents,
batch_size=batch_size * num_videos_per_prompt,
timestep=latent_timestep,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
decode_chunk_size=decode_chunk_size,
)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 7. Add image embeds for IP-Adapter
# 8. Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
@@ -967,9 +1002,12 @@ class AnimateDiffVideoToVideoPipeline(
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

# 8. Denoising loop
# 9. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1005,14 +1043,14 @@ class AnimateDiffVideoToVideoPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

# 9. Post-processing
# 10. Post-processing
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents, decode_chunk_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
# 11. Offload all models
self.maybe_free_model_hooks()

if not return_dict:

@@ -463,7 +463,6 @@ class CogVideoXPipeline(DiffusionPipeline):
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)

freqs_cos = freqs_cos.to(device=device)

@@ -36,6 +36,8 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import FluxPipelineOutput

from torch.profiler import record_function


if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -280,7 +282,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

return prompt_embeds
@@ -536,7 +538,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -716,21 +718,24 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
with record_function(f"transformer iter_{i}"):

noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
with record_function(f"scheduler.step (iter_{i})"):
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
@@ -757,10 +762,11 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
image = latents

else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
with record_function(f"decode latent"):
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()

@@ -13,7 +13,7 @@
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -27,7 +27,7 @@ from transformers import (
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -61,7 +61,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel

>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
...     base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
@@ -195,7 +195,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
controlnet: FluxControlNetModel,
controlnet: Union[
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
],
):
super().__init__()

@@ -300,7 +302,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

return prompt_embeds
@@ -571,6 +573,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -611,6 +614,20 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
control_mode (`int` or `List[int]`,, *optional*, defaults to None):
The control mode when applying ControlNet-Union.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -730,6 +747,55 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
width_control_image,
)

# set control mode
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])

elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []

for control_image_ in control_image:
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
)
height, width = control_image_.shape[-2:]

# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)

control_images.append(control_image_)

control_image = control_images

# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
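A hypothetical usage sketch for the multi-ControlNet path prepared in the hunk above; the checkpoint names, control images, and scales are placeholders rather than values from the diff:

```python
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models.controlnet_flux import FluxMultiControlNetModel  # import path shown in the diff
from diffusers.utils import load_image

canny_cn = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.bfloat16
)
controlnet = FluxMultiControlNetModel([canny_cn, canny_cn])  # two conditions, same weights here

pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a futuristic city at dusk",
    control_image=[load_image("canny_1.png"), load_image("canny_2.png")],  # one per ControlNet
    control_mode=[None, None],               # None is mapped to -1 by the code above
    controlnet_conditioning_scale=[0.6, 0.4],
).images[0]
```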
@@ -785,6 +851,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union
from typing import Callable, Dict, Optional, Union

import torch

@@ -22,6 +22,7 @@ from ..models.unets.unet_motion_model import (
DownBlockMotion,
UpBlockMotion,
)
from ..pipelines.pipeline_utils import DiffusionPipeline
from ..utils import logging
from ..utils.torch_utils import randn_tensor

@@ -98,6 +99,142 @@ class AnimateDiffFreeNoiseMixin:
free_noise_transfomer_block.state_dict(), strict=True
)

def _check_inputs_free_noise(
self,
prompt,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
num_frames,
) -> None:
if not isinstance(prompt, (str, dict)):
raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")

if negative_prompt is not None:
if not isinstance(negative_prompt, (str, dict)):
raise ValueError(
f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
)

if prompt_embeds is not None or negative_prompt_embeds is not None:
raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")

frame_indices = [isinstance(x, int) for x in prompt.keys()]
frame_prompts = [isinstance(x, str) for x in prompt.values()]
min_frame = min(list(prompt.keys()))
max_frame = max(list(prompt.keys()))

if not all(frame_indices):
raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
if not all(frame_prompts):
raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
if min_frame != 0:
raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
if max_frame >= num_frames:
raise ValueError(
f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
)

def _encode_prompt_free_noise(
self,
prompt: Union[str, Dict[int, str]],
num_frames: int,
device: torch.device,
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
) -> torch.Tensor:
if negative_prompt is None:
negative_prompt = ""

# Ensure that we have a dictionary of prompts
if isinstance(prompt, str):
prompt = {0: prompt}
if isinstance(negative_prompt, str):
negative_prompt = {0: negative_prompt}

self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)

# Sort the prompts based on frame indices
prompt = dict(sorted(prompt.items()))
negative_prompt = dict(sorted(negative_prompt.items()))

# Ensure that we have a prompt for the last frame index
prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]

frame_indices = list(prompt.keys())
frame_prompts = list(prompt.values())
frame_negative_indices = list(negative_prompt.keys())
frame_negative_prompts = list(negative_prompt.values())

# Generate and interpolate positive prompts
prompt_embeds, _ = self.encode_prompt(
prompt=frame_prompts,
device=device,
num_images_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=False,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=lora_scale,
clip_skip=clip_skip,
)

shape = (num_frames, *prompt_embeds.shape[1:])
prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)

for i in range(len(frame_indices) - 1):
start_frame = frame_indices[i]
end_frame = frame_indices[i + 1]
start_tensor = prompt_embeds[i].unsqueeze(0)
end_tensor = prompt_embeds[i + 1].unsqueeze(0)

prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
start_frame, end_frame, start_tensor, end_tensor
)

# Generate and interpolate negative prompts
negative_prompt_embeds = None
negative_prompt_interpolation_embeds = None

if do_classifier_free_guidance:
_, negative_prompt_embeds = self.encode_prompt(
prompt=[""] * len(frame_negative_prompts),
device=device,
num_images_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=True,
negative_prompt=frame_negative_prompts,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=lora_scale,
clip_skip=clip_skip,
)

negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)

for i in range(len(frame_negative_indices) - 1):
start_frame = frame_negative_indices[i]
end_frame = frame_negative_indices[i + 1]
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)

negative_prompt_interpolation_embeds[
start_frame : end_frame + 1
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)

prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds

if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

return prompt_embeds, negative_prompt_embeds

def _prepare_latents_free_noise(
self,
batch_size: int,
@@ -172,12 +309,29 @@ class AnimateDiffFreeNoiseMixin:
latents = latents[:, :, :num_frames]
return latents

def _lerp(
self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
) -> torch.Tensor:
num_indices = end_index - start_index + 1
interpolated_tensors = []

for i in range(num_indices):
alpha = i / (num_indices - 1)
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
interpolated_tensors.append(interpolated_tensor)

interpolated_tensors = torch.cat(interpolated_tensors)
return interpolated_tensors
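A hypothetical sketch of the FreeNoise prompt-interpolation hooks defined above: a dict of frame-indexed prompts plus a custom interpolation callback with the same call signature as `_lerp` (start index, end index, start tensor, end tensor). Checkpoint ids are placeholders, not taken from the diff.

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

def smoothstep_interp(start_index, end_index, start_tensor, end_tensor):
    # ease-in/ease-out blend between the two prompt embeddings instead of plain lerp
    n = end_index - start_index + 1
    out = []
    for i in range(n):
        t = i / (n - 1)
        t = t * t * (3 - 2 * t)
        out.append((1 - t) * start_tensor + t * end_tensor)
    return torch.cat(out)

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)
pipe.enable_free_noise(context_length=16, context_stride=4, prompt_interpolation_callback=smoothstep_interp)

# Per-frame prompts: integer keys, frame 0 required, highest key < num_frames.
output = pipe(prompt={0: "a sunrise over the sea", 63: "a starry night sky"}, num_frames=64)
# output.frames holds the generated video frames
```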

def enable_free_noise(
self,
context_length: Optional[int] = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
noise_type: str = "shuffle_context",
prompt_interpolation_callback: Optional[
Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
] = None,
) -> None:
r"""
Enable long video generation using FreeNoise.
@@ -195,13 +349,27 @@ class AnimateDiffFreeNoiseMixin:
weighting_scheme (`str`, defaults to `pyramid`):
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
schemes are supported currently:
- "flat"
Performs weighting averaging with a flat weight pattern: [1, 1, 1, 1, 1].
- "pyramid"
Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
- "delayed_reverse_sawtooth"
Performs weighted averaging with low weights for earlier frames and high-to-low weights for
later frames: [0.01, 0.01, 3, 2, 1].
noise_type (`str`, defaults to "shuffle_context"):
TODO
Must be one of ["shuffle_context", "repeat_context", "random"].
- "shuffle_context"
Shuffles a fixed batch of `context_length` latents to create a final latent of size
`num_frames`. This is usually the best setting for most generation scenarious. However, there
might be visible repetition noticeable in the kinds of motion/animation generated.
- "repeated_context"
Repeats a fixed batch of `context_length` latents to create a final latent of size
`num_frames`.
- "random"
The final latents are random without any repetition.
"""

allowed_weighting_scheme = ["pyramid"]
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]

if context_length > self.motion_adapter.config.motion_max_seq_length:
@@ -219,14 +387,25 @@ class AnimateDiffFreeNoiseMixin:
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp

if hasattr(self.unet.mid_block, "motion_modules"):
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
else:
blocks = [*self.unet.down_blocks, *self.unet.up_blocks]

blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._enable_free_noise_in_block(block)

def disable_free_noise(self) -> None:
r"""Disable the FreeNoise sampling mechanism."""
self._free_noise_context_length = None

if hasattr(self.unet.mid_block, "motion_modules"):
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
else:
blocks = [*self.unet.down_blocks, *self.unet.up_blocks]

blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._disable_free_noise_in_block(block)

@@ -547,7 +547,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
negative_image_embeds = prior_outputs[1]

prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(prompt, PIL.Image.Image) else image
image = [image] if isinstance(image, PIL.Image.Image) else image

if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
prompt = (image_embeds.shape[0] // len(prompt)) * prompt
@@ -813,7 +813,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
negative_image_embeds = prior_outputs[1]

prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(prompt, PIL.Image.Image) else image
image = [image] if isinstance(image, PIL.Image.Image) else image
mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image

if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:

@@ -734,6 +734,8 @@ class AnimateDiffPAGPipeline(
elif self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
@@ -805,7 +807,9 @@ class AnimateDiffPAGPipeline(
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
latent_model_input = torch.cat(
[latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0])
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual

@@ -824,6 +824,8 @@ class PIAPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,

@@ -22,7 +22,7 @@ import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .scheduling_utils import SchedulerMixin

from torch.profiler import record_function

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -284,20 +284,22 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
" one of the `scheduler.timesteps` as a timestep."
),
)

if self.step_index is None:
self._init_step_index(timestep)
with record_function(" _init_step_index"):
if self.step_index is None:
self._init_step_index(timestep)

# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)

with record_function(" get sigma and sigma_next"):
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]

sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]

prev_sample = sample + (sigma_next - sigma) * model_output
with record_function(" get prev_sample"):
prev_sample = sample + (sigma_next - sigma) * model_output

# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
prev_sample = prev_sample.to(model_output.dtype)

# upon completion increase step index by one
self._step_index += 1
|
||||
|
||||
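This hunk wraps the hot parts of `FlowMatchEulerDiscreteScheduler.step` in `torch.profiler.record_function` spans, which reads as profiling instrumentation rather than a behavioral change. A hedged sketch of how such named spans surface in a profile; the step function below is a standalone stand-in, not the scheduler itself:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

def toy_step(sample, model_output, sigma, sigma_next):
    # Mirrors one of the span names used in the hunk above.
    with record_function(" get prev_sample"):
        return sample + (sigma_next - sigma) * model_output

sample = torch.randn(1, 4, 64, 64)
model_output = torch.randn_like(sample)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    toy_step(sample, model_output, sigma=1.0, sigma_next=0.9)

# Each record_function name shows up as its own row in the profiler table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```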
@@ -418,11 +418,11 @@ class EMAModel:
        one_minus_decay = 1 - decay

        context_manager = contextlib.nullcontext
        if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
        if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
            import deepspeed

        if self.foreach:
            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager():
@@ -444,7 +444,7 @@ class EMAModel:

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager():
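Recent `transformers` releases expose the DeepSpeed helpers under `transformers.integrations.deepspeed`, and the old top-level `transformers.deepspeed` path has been deprecated, which is why the EMA update switches the lookup. If code has to run against both old and new `transformers`, a small fallback shim along these lines should work; this is a sketch, not part of the diff:

```python
def deepspeed_zero3_enabled() -> bool:
    """Return True if DeepSpeed ZeRO-3 is active, across old and new transformers layouts."""
    try:
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
    except ImportError:  # older transformers still ship the top-level transformers.deepspeed module
        from transformers.deepspeed import is_deepspeed_zero3_enabled
    return is_deepspeed_zero3_enabled()
```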
@@ -197,6 +197,21 @@ class FluxControlNetModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class FluxMultiControlNetModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class FluxTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]
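These are the auto-generated placeholder classes that let `diffusers` be imported without PyTorch installed: when the backend is missing, the real class is swapped for a dummy whose entry points call `requires_backends` and raise a helpful error only at use time. A stripped-down sketch of the same pattern; the real `DummyObject` and `requires_backends` live in `diffusers.utils` and do more than this:

```python
class DummyObject(type):
    """Metaclass that turns instantiation into a clear missing-backend error."""

    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires the following backends: {cls._backends}. Install them to use it."
        )


class FluxMultiControlNetModel(metaclass=DummyObject):
    _backends = ["torch"]


try:
    FluxMultiControlNetModel()
except ImportError as err:
    print(err)  # FluxMultiControlNetModel requires the following backends: ['torch']. ...
```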
@@ -51,7 +51,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)

        noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size * num_frames, 4, 16)).to(torch_device)

        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@@ -460,6 +460,29 @@ class AnimateDiffPipelineFastTests(
            "Disabling of FreeNoise should lead to results similar to the default pipeline results",
        )

    def test_free_noise_multi_prompt(self):
        components = self.get_dummy_components()
        pipe: AnimateDiffPipeline = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        context_length = 8
        context_stride = 4
        pipe.enable_free_noise(context_length, context_stride)

        # Make sure that pipeline works when prompt indices are within num_frames bounds
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
        inputs["num_frames"] = 16
        pipe(**inputs).frames[0]

        with self.assertRaises(ValueError):
            # Ensure that prompt indices are within bounds
            inputs = self.get_dummy_inputs(torch_device)
            inputs["num_frames"] = 16
            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
            pipe(**inputs).frames[0]

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
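The new test exercises FreeNoise's frame-indexed prompt dictionaries: keys are frame indices and, per the `ValueError` branch, they have to stay within the generated clip. A self-contained sketch of that bound check; the helper and the exact `[0, num_frames)` range are illustrative assumptions, not the pipeline's own validation code:

```python
def check_prompt_indices(prompt: dict, num_frames: int) -> None:
    # Reject frame indices that fall outside the generated clip.
    for frame_index in prompt:
        if not 0 <= frame_index < num_frames:
            raise ValueError(f"Prompt frame index {frame_index} is outside [0, {num_frames}).")

check_prompt_indices({0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}, num_frames=16)  # passes
try:
    check_prompt_indices({0: "Caterpillar on a leaf", 42: "Error on a leaf"}, num_frames=16)
except ValueError as err:
    print(err)  # mirrors the assertRaises(ValueError) branch in the tests
```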
@@ -476,6 +476,27 @@ class AnimateDiffControlNetPipelineFastTests(
            "Disabling of FreeNoise should lead to results similar to the default pipeline results",
        )

    def test_free_noise_multi_prompt(self):
        components = self.get_dummy_components()
        pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        context_length = 8
        context_stride = 4
        pipe.enable_free_noise(context_length, context_stride)

        # Make sure that pipeline works when prompt indices are within num_frames bounds
        inputs = self.get_dummy_inputs(torch_device, num_frames=16)
        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
        pipe(**inputs).frames[0]

        with self.assertRaises(ValueError):
            # Ensure that prompt indices are within bounds
            inputs = self.get_dummy_inputs(torch_device, num_frames=16)
            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
            pipe(**inputs).frames[0]

    def test_vae_slicing(self, video_count=2):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
@@ -491,3 +491,28 @@ class AnimateDiffVideoToVideoPipelineFastTests(
            1e-4,
            "Disabling of FreeNoise should lead to results similar to the default pipeline results",
        )

    def test_free_noise_multi_prompt(self):
        components = self.get_dummy_components()
        pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        pipe.to(torch_device)

        context_length = 8
        context_stride = 4
        pipe.enable_free_noise(context_length, context_stride)

        # Make sure that pipeline works when prompt indices are within num_frames bounds
        inputs = self.get_dummy_inputs(torch_device, num_frames=16)
        inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
        inputs["num_inference_steps"] = 2
        inputs["strength"] = 0.5
        pipe(**inputs).frames[0]

        with self.assertRaises(ValueError):
            # Ensure that prompt indices are within bounds
            inputs = self.get_dummy_inputs(torch_device, num_frames=16)
            inputs["num_inference_steps"] = 2
            inputs["strength"] = 0.5
            inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
            pipe(**inputs).frames[0]