Compare commits

...

12 Commits

Author SHA1 Message Date
apolinário
ff1012f8cf Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 23:34:30 +02:00
apolinário
25bc77d8f8 Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 23:34:24 +02:00
apolinário
9c35a89921 Swap order 2025-04-21 23:31:04 +02:00
github-actions[bot]
32d9aef997 Apply style fixes 2025-04-21 15:42:31 +00:00
apolinário
9edc5beddc Use config value directly 2025-04-21 17:40:13 +02:00
github-actions[bot]
f87956e9cf Apply style fixes 2025-04-19 19:33:27 +00:00
apolinário
690adb5bd9 Add stochastic sampling to FlowMatchEulerDiscreteScheduler
This PR adds stochastic sampling to FlowMatchEulerDiscreteScheduler, based on ltx_video/schedulers/rf.py at b1aeddd7cc.
2025-04-19 19:48:16 +02:00
YiYi Xu
5873377a66 [Wan2.1-FLF2V] update conversion script (#11365)
update scheduler config in conversion script
2025-04-18 14:08:44 -10:00
YiYi Xu
5a2e0f715c update output for Hidream transformer (#11366)
up
2025-04-18 14:07:21 -10:00
Kazuki Yoda
ef47726e2d Fix: StableDiffusionXLControlNetAdapterInpaintPipeline incorrectly inherited StableDiffusionLoraLoaderMixin (#11357)
Fix: Inherit `StableDiffusionXLLoraLoaderMixin`

`StableDiffusionXLControlNetAdapterInpaintPipeline`
used to incorrectly inherit
`StableDiffusionLoraLoaderMixin`
instead of `StableDiffusionXLLoraLoaderMixin`
2025-04-18 12:46:06 -10:00
YiYi Xu
0021bfa1e1 support Wan-FLF2V (#11353)
* update transformer

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-04-18 10:27:50 -10:00
Marc Sun
bbd0c161b5 [BNB] Fix test_moving_to_cpu_throws_warning (#11356)
fix

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-18 09:44:51 +05:30
9 changed files with 255 additions and 19 deletions

View File

@@ -133,6 +133,60 @@ output = pipe(
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### First and Last Frame Interpolation
```python
import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
return image, height, width
def center_crop_resize(image, height, width):
# Calculate resize ratio to match first frame dimensions
resize_ratio = max(width / image.width, height / image.height)
# Resize the image
width = round(image.width * resize_ratio)
height = round(image.height * resize_ratio)
size = [width, height]
image = TF.center_crop(image, size)
return image, height, width
first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
if last_frame.size != first_frame.size:
last_frame, _, _ = center_crop_resize(last_frame, height, width)
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipe(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```
### Video to Video Generation
```python

View File

@@ -33,7 +33,6 @@ from diffusers import DiffusionPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
-StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
@@ -300,7 +299,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class StableDiffusionXLControlNetAdapterInpaintPipeline(
-DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionLoraLoaderMixin
+DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter

View File

@@ -39,6 +39,24 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+# for the FLF2V model
+"img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+# Add attention component mappings
+"self_attn.q": "attn1.to_q",
+"self_attn.k": "attn1.to_k",
+"self_attn.v": "attn1.to_v",
+"self_attn.o": "attn1.to_out.0",
+"self_attn.norm_q": "attn1.norm_q",
+"self_attn.norm_k": "attn1.norm_k",
+"cross_attn.q": "attn2.to_q",
+"cross_attn.k": "attn2.to_k",
+"cross_attn.v": "attn2.to_v",
+"cross_attn.o": "attn2.to_out.0",
+"cross_attn.norm_q": "attn2.norm_q",
+"cross_attn.norm_k": "attn2.norm_k",
+"attn2.to_k_img": "attn2.add_k_proj",
+"attn2.to_v_img": "attn2.add_v_proj",
+"attn2.norm_k_img": "attn2.norm_added_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
@@ -135,6 +153,28 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
+elif model_type == "Wan-FLF2V-14B-720P":
+config = {
+"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
+"diffusers_config": {
+"image_dim": 1280,
+"added_kv_proj_dim": 5120,
+"attention_head_dim": 128,
+"cross_attn_norm": True,
+"eps": 1e-06,
+"ffn_dim": 13824,
+"freq_dim": 256,
+"in_channels": 36,
+"num_attention_heads": 40,
+"num_layers": 40,
+"out_channels": 16,
+"patch_size": [1, 2, 2],
+"qk_norm": "rms_norm_across_heads",
+"text_dim": 4096,
+"rope_max_seq_len": 1024,
+"pos_embed_seq_len": 257 * 2,
+},
+}
return config
@@ -393,11 +433,12 @@ if __name__ == "__main__":
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
scheduler = UniPCMultistepScheduler(
-prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
-if "I2V" in args.model_type:
+if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
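The new `TRANSFORMER_KEYS_RENAME_DICT` entries above are consumed the same way as the existing ones: each original checkpoint key is rewritten by substring replacement before being loaded into the diffusers transformer. A minimal sketch of that pattern (the helper and the sample keys below are illustrative, not copied from the script):

```python
# Illustrative sketch of substring-based key renaming, as used by checkpoint
# conversion scripts of this kind. The dict below is a small hypothetical subset.
RENAME = {
    "self_attn.q": "attn1.to_q",
    "cross_attn.k": "attn2.to_k",
    "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
}


def rename_key(key: str) -> str:
    # Apply every mapping as a plain substring replacement.
    for old, new in RENAME.items():
        key = key.replace(old, new)
    return key


print(rename_key("blocks.0.self_attn.q.weight"))  # -> blocks.0.attn1.to_q.weight
print(rename_key("img_emb.emb_pos"))              # -> condition_embedder.image_embedder.pos_embed
```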

View File

@@ -918,5 +918,5 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
unscale_lora_layers(self, lora_scale)
if not return_dict:
-return (output, hidden_states_masks)
-return Transformer2DModelOutput(sample=output, mask=hidden_states_masks)
+return (output,)
+return Transformer2DModelOutput(sample=output)

View File

@@ -49,8 +49,10 @@ class WanAttnProcessor2_0:
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
-encoder_hidden_states_img = encoder_hidden_states[:, :257]
-encoder_hidden_states = encoder_hidden_states[:, 257:]
+# 512 is the context length of the text encoder, hardcoded for now
+image_context_length = encoder_hidden_states.shape[1] - 512
+encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -108,14 +110,23 @@ class WanAttnProcessor2_0:
class WanImageEmbedding(torch.nn.Module):
-def __init__(self, in_features: int, out_features: int):
+def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
+if pos_embed_seq_len is not None:
+self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+else:
+self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+if self.pos_embed is not None:
+batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
@@ -130,6 +141,7 @@ class WanTimeTextImageEmbedding(nn.Module):
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
+pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
@@ -141,7 +153,7 @@ class WanTimeTextImageEmbedding(nn.Module):
self.image_embedder = None
if image_embed_dim is not None:
-self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
@@ -350,6 +362,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
+pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()
@@ -368,6 +381,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
+pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
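For the FLF2V checkpoints, `WanImageEmbedding` gains a learned positional embedding over the concatenated first- and last-frame CLIP sequences; the `view(-1, 2 * seq_len, embed_dim)` call folds each first/last pair into one sequence before the addition. A shape-only sketch of that reshaping (the batch layout is an assumption, inferred from the pipeline change later in this diff that encodes `[image, last_image]` together):

```python
# Shape-only sketch (illustrative): how the FLF2V positional embedding is applied
# to paired first/last-frame CLIP features. Sizes follow the converted config
# (257 CLIP tokens per frame, pos_embed_seq_len = 257 * 2, image_dim = 1280).
import torch

batch_size, seq_len, embed_dim = 2, 257, 1280
pos_embed = torch.zeros(1, 2 * seq_len, embed_dim)

# Assumption: first- and last-frame embeddings arrive stacked along the batch axis,
# i.e. 2 * batch_size sequences of length seq_len each.
encoder_hidden_states_image = torch.randn(2 * batch_size, seq_len, embed_dim)

# view(-1, 2 * seq_len, embed_dim) folds each first/last pair into one sequence of
# length 514, to which the learned pos_embed is added (broadcast over the batch).
folded = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + pos_embed
print(folded.shape)  # torch.Size([2, 514, 1280])
```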

View File

@@ -404,6 +404,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
if is_loaded_in_8bit_bnb:
return False
return hasattr(module, "_hf_hook") and (
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
or hasattr(module._hf_hook, "hooks")

View File

@@ -380,6 +380,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
+last_image: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
@@ -398,9 +399,16 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents = latents.to(device=device, dtype=dtype)
image = image.unsqueeze(2)
-video_condition = torch.cat(
-[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
-)
+if last_image is None:
+video_condition = torch.cat(
+[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+)
+else:
+last_image = last_image.unsqueeze(2)
+video_condition = torch.cat(
+[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+dim=2,
+)
video_condition = video_condition.to(device=device, dtype=dtype)
latents_mean = (
@@ -424,7 +432,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
-mask_lat_size[:, :, list(range(1, num_frames))] = 0
+if last_image is None:
+mask_lat_size[:, :, list(range(1, num_frames))] = 0
+else:
+mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
@@ -476,6 +488,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
+last_image: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -620,7 +633,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
if image_embeds is None:
-image_embeds = self.encode_image(image, device)
+if last_image is None:
+image_embeds = self.encode_image(image, device)
+else:
+image_embeds = self.encode_image([image, last_image], device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
@@ -631,6 +647,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
# 5. Prepare latent variables
num_channels_latents = self.vae.config.z_dim
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+if last_image is not None:
+last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+device, dtype=torch.float32
+)
latents, condition = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
@@ -642,6 +662,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device,
generator,
latents,
+last_image,
)
# 6. Denoising loop
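In the FLF2V path, `prepare_latents` places the first frame at the start and the last frame at the end of the conditioning video, pads the middle with zeros, and the latent mask marks only those two frames as given. A toy shape check of that layout (all sizes are arbitrary illustration values, not taken from the pipeline defaults):

```python
# Toy reproduction of the FLF2V conditioning layout built in prepare_latents:
# [first_frame, zeros, last_frame] along the time axis, plus a mask that keeps
# only the first and last frames.
import torch

batch, channels, num_frames, height, width = 1, 3, 9, 16, 16
image = torch.randn(batch, channels, 1, height, width)       # first frame, unsqueezed on dim=2
last_image = torch.randn(batch, channels, 1, height, width)  # last frame, unsqueezed on dim=2

video_condition = torch.cat(
    [image, image.new_zeros(batch, channels, num_frames - 2, height, width), last_image],
    dim=2,
)
print(video_condition.shape)  # torch.Size([1, 3, 9, 16, 16])

# Only frames 0 and num_frames - 1 stay marked as conditioning frames.
mask_lat_size = torch.ones(batch, 1, num_frames, height, width)
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
print(mask_lat_size[0, 0, :, 0, 0])  # tensor([1., 0., 0., 0., 0., 0., 0., 0., 1.])
```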

View File

@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+stochastic_sampling (`bool`, defaults to False):
+Whether to use stochastic sampling.
"""
_compatibles = []
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
+stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
-dt = (per_token_sigmas - lower_sigmas)[..., None]
+current_sigma = per_token_sigmas[..., None]
+next_sigma = lower_sigmas[..., None]
+dt = current_sigma - next_sigma
else:
-sigma = self.sigmas[self.step_index]
-sigma_next = self.sigmas[self.step_index + 1]
+sigma_idx = self.step_index
+sigma = self.sigmas[sigma_idx]
+sigma_next = self.sigmas[sigma_idx + 1]
+current_sigma = sigma
+next_sigma = sigma_next
dt = sigma_next - sigma
-prev_sample = sample + dt * model_output
+if self.config.stochastic_sampling:
+x0 = sample - current_sigma * model_output
+noise = torch.randn_like(sample)
+prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+else:
+prev_sample = sample + dt * model_output
# upon completion increase step index by one
self._step_index += 1
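The stochastic branch first reconstructs the clean sample implied by the flow-matching prediction (`x0 = sample - sigma * model_output`) and then re-noises it to the next sigma level, instead of taking the deterministic Euler step. A small numeric sketch of the two updates (tensors and sigma values are arbitrary):

```python
# Side-by-side sketch of the deterministic and stochastic updates shown above.
# This only mirrors the arithmetic; it is not the scheduler itself.
import torch

sample = torch.randn(1, 4, 8, 8)        # current noisy latents
model_output = torch.randn(1, 4, 8, 8)  # flow/velocity prediction from the model
current_sigma, next_sigma = 0.75, 0.50
dt = next_sigma - current_sigma

# stochastic_sampling=False: plain Euler step along the predicted flow.
prev_deterministic = sample + dt * model_output

# stochastic_sampling=True: reconstruct x0, then re-noise it to the next sigma level.
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_stochastic = (1.0 - next_sigma) * x0 + next_sigma * noise
```

Since `stochastic_sampling` is registered to the scheduler config, it can presumably also be switched on for an existing pipeline via `FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, stochastic_sampling=True)`.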

View File

@@ -160,3 +160,90 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
def test_inference_batch_single_identical(self):
pass
class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
# TODO: impl FlowDPMSolverMultistepScheduler
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
transformer = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
image_dim=4,
pos_embed_seq_len=2 * (4 * 4 + 1),
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
projection_dim=4,
num_hidden_layers=2,
num_attention_heads=2,
image_size=4,
intermediate_size=16,
patch_size=1,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
torch.manual_seed(0)
image_processor = CLIPImageProcessor(crop_size=4, size=4)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 16
image_width = 16
image = Image.new("RGB", (image_width, image_height))
last_image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"last_image": last_image,
"prompt": "dance monkey",
"negative_prompt": "negative",
"height": image_height,
"width": image_width,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
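The dummy components above can also be assembled by hand to exercise the new `last_image` path outside the test harness. A rough sketch of what `PipelineTesterMixin` effectively does (assumes network access for the tiny hub checkpoints; not part of the PR):

```python
# Rough manual equivalent of the fast test above: build the dummy FLF2V pipeline
# and run the two configured inference steps on CPU. Output is noise, since all
# weights are tiny and random; this only checks that the last_image path runs.
from diffusers import WanImageToVideoPipeline

tests = WanFLFToVideoPipelineFastTests()  # bare TestCase instantiation is allowed
pipe = WanImageToVideoPipeline(**tests.get_dummy_components())
pipe.to("cpu")

inputs = tests.get_dummy_inputs("cpu")
frames = pipe(**inputs).frames
print(frames.shape)
```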