Compare commits

...

14 Commits

Author SHA1 Message Date
yiyixuxu
5512d70f7f add 5b ti2v 2025-07-29 02:27:15 +02:00
yiyixuxu
4fd0333b17 style 2025-07-28 21:57:52 +02:00
yiyixuxu
c0af998473 Merge branch 'wan2.2' of github.com:huggingface/diffusers into wan2.2 2025-07-28 21:48:47 +02:00
yiyixuxu
1ff7c99596 fix fast tests 2025-07-28 21:48:19 +02:00
YiYi Xu
6bb1677bd4 Update src/diffusers/models/transformers/transformer_wan.py 2025-07-28 08:13:59 -10:00
YiYi Xu
97675c7036 Apply suggestions from code review
Co-authored-by: bagheera <59658056+bghira@users.noreply.github.com>
2025-07-28 08:13:16 -10:00
yiyixuxu
e78a4fa30a remove a copied from in skyreels 2025-07-28 20:09:10 +02:00
yiyixuxu
0ab2e4fd91 refactor out reearrange 2025-07-28 20:07:31 +02:00
yiyixuxu
5709f7e04d conversion script 2025-07-28 12:28:42 +02:00
yiyixuxu
27ce75b984 add 5b t2v 2025-07-28 12:28:23 +02:00
yiyixuxu
bf2c6e0a08 add 2025-07-28 12:22:09 +02:00
yiyixuxu
95a55f9dc0 add conversion script for vae 2.2 2025-07-27 08:50:32 +02:00
yiyixuxu
f5da83c4a3 add t2v + vae2.2 2025-07-27 08:42:56 +02:00
yiyixuxu
7ba78d5ce5 support wan 2.2 i2v 2025-07-25 13:29:11 +02:00
9 changed files with 1107 additions and 96 deletions

View File

@@ -278,16 +278,82 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-I2V-14B-720p":
config = {
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-T2V-A14B":
config = {
"model_id": "Wan-AI/Wan2.2-T2V-A14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-TI2V-5B":
config = {
"model_id": "Wan-AI/Wan2.2-TI2V-5B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 14336,
"freq_dim": 256,
"in_channels": 48,
"num_attention_heads": 24,
"num_layers": 30,
"out_channels": 48,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
def convert_transformer(model_type: str):
def convert_transformer(model_type: str, stage: str = None):
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
if stage is not None:
model_dir = model_dir / stage
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
@@ -515,6 +581,310 @@ def convert_vae():
return vae
vae22_diffusers_config = {
"base_dim": 160,
"z_dim": 48,
"is_residual": True,
"in_channels": 12,
"out_channels": 12,
"decoder_base_dim": 256,
"scale_factor_temporal": 4,
"scale_factor_spatial": 16,
"patch_size": 2,
"latents_mean": [
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
],
"latents_std": [
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
],
"clip_output": False,
}
def convert_vae_22():
vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
new_state_dict = {}
# Create mappings for specific components
middle_key_mapping = {
# Encoder middle block
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
# Decoder middle block
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
}
# Create a mapping for attention blocks
attention_mapping = {
# Encoder middle attention
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
# Decoder middle attention
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
}
# Create a mapping for the head components
head_mapping = {
# Encoder head
"encoder.head.0.gamma": "encoder.norm_out.gamma",
"encoder.head.2.bias": "encoder.conv_out.bias",
"encoder.head.2.weight": "encoder.conv_out.weight",
# Decoder head
"decoder.head.0.gamma": "decoder.norm_out.gamma",
"decoder.head.2.bias": "decoder.conv_out.bias",
"decoder.head.2.weight": "decoder.conv_out.weight",
}
# Create a mapping for the quant components
quant_mapping = {
"conv1.weight": "quant_conv.weight",
"conv1.bias": "quant_conv.bias",
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}
# Process each key in the state dict
for key, value in old_state_dict.items():
# Handle middle block keys using the mapping
if key in middle_key_mapping:
new_key = middle_key_mapping[key]
new_state_dict[new_key] = value
# Handle attention blocks using the mapping
elif key in attention_mapping:
new_key = attention_mapping[key]
new_state_dict[new_key] = value
# Handle head keys using the mapping
elif key in head_mapping:
new_key = head_mapping[key]
new_state_dict[new_key] = value
# Handle quant keys using the mapping
elif key in quant_mapping:
new_key = quant_mapping[key]
new_state_dict[new_key] = value
# Handle encoder conv1
elif key == "encoder.conv1.weight":
new_state_dict["encoder.conv_in.weight"] = value
elif key == "encoder.conv1.bias":
new_state_dict["encoder.conv_in.bias"] = value
# Handle decoder conv1
elif key == "decoder.conv1.weight":
new_state_dict["decoder.conv_in.weight"] = value
elif key == "decoder.conv1.bias":
new_state_dict["decoder.conv_in.bias"] = value
# Handle encoder downsamples
elif key.startswith("encoder.downsamples."):
# Change encoder.downsamples to encoder.down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
# Handle residual blocks - change downsamples to resnets and rename components
if "residual" in new_key or "shortcut" in new_key:
# Change the second downsamples to resnets
new_key = new_key.replace(".downsamples.", ".resnets.")
# Rename residual components
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
# Handle resample blocks - change downsamples to downsampler and remove index
elif "resample" in new_key or "time_conv" in new_key:
# Change the second downsamples to downsampler and remove the index
parts = new_key.split(".")
# Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
# We want to change it to: encoder.down_blocks.X.downsampler.resample...
if len(parts) >= 4 and parts[3] == "downsamples":
# Remove the index (parts[4]) and change downsamples to downsampler
new_parts = parts[:3] + ["downsampler"] + parts[5:]
new_key = ".".join(new_parts)
new_state_dict[new_key] = value
# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Change decoder.upsamples to decoder.up_blocks
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
# Handle residual blocks - change upsamples to resnets and rename components
if "residual" in new_key or "shortcut" in new_key:
# Change the second upsamples to resnets
new_key = new_key.replace(".upsamples.", ".resnets.")
# Rename residual components
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
# Handle resample blocks - change upsamples to upsampler and remove index
elif "resample" in new_key or "time_conv" in new_key:
# Change the second upsamples to upsampler and remove the index
parts = new_key.split(".")
# Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
# We want to change it to: encoder.down_blocks.X.downsampler.resample...
if len(parts) >= 4 and parts[3] == "upsamples":
# Remove the index (parts[4]) and change upsamples to upsampler
new_parts = parts[:3] + ["upsampler"] + parts[5:]
new_key = ".".join(new_parts)
new_state_dict[new_key] = value
else:
# Keep other keys unchanged
new_state_dict[key] = value
with init_empty_weights():
vae = AutoencoderKLWan(**vae22_diffusers_config)
vae.load_state_dict(new_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
@@ -533,11 +903,26 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
transformer = convert_transformer(args.model_type)
vae = convert_vae()
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
transformer = convert_transformer(args.model_type, stage="high_noise_model")
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
else:
transformer = convert_transformer(args.model_type)
transformer_2 = None
if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
vae = convert_vae_22()
else:
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
if "FLF2V" in args.model_type:
flow_shift = 16.0
elif "TI2V" in args.model_type:
flow_shift = 5.0
else:
flow_shift = 3.0
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
@@ -547,7 +932,36 @@ if __name__ == "__main__":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
if "I2V" in args.model_type or "FLF2V" in args.model_type:
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
pipe = WanImageToVideoPipeline(
transformer=transformer,
transformer_2=transformer_2,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
boundary_ratio=0.9,
)
elif "Wan2.2" and "T2V" in args.model_type:
pipe = WanPipeline(
transformer=transformer,
transformer_2=transformer_2,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
boundary_ratio=0.875,
)
elif "Wan2.2" and "TI2V" in args.model_type:
pipe = WanPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
expand_timesteps=True,
)
elif "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)

View File

@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
if first_chunk:
x = x[:, :, self.factor_t - 1 :, :, :]
return x
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# default to dim //2
if upsample_out_dim is None:
upsample_out_dim = dim // 2
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
return x
class WanResidualDownBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
resnets = []
for _ in range(num_res_blocks):
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
self.downsampler = WanResample(out_dim, mode=mode)
else:
self.downsampler = None
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x, feat_cache, feat_idx)
if self.downsampler is not None:
x = self.downsampler(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy)
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):
def __init__(
self,
in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
@@ -388,6 +528,7 @@ class WanEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
is_residual: bool = False, # wan 2.2 vae use a residual downblock
):
super().__init__()
self.dim = dim
@@ -403,23 +544,35 @@ class WanEncoder3d(nn.Module):
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
if is_residual:
self.down_blocks.append(
WanResidualDownBlock(
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
down_flag=i != len(dim_mult) - 1,
)
)
else:
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
return x
class WanResidualUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
temperal_upsample (bool): Whether to upsample on temporal dimension
up_flag (bool): Whether to upsample or not
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
temperal_upsample: bool = False,
up_flag: bool = False,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2,
)
else:
self.avg_shortcut = None
# create residual blocks
resnets = []
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
if up_flag:
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
else:
self.upsampler = None
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
x_copy = x.clone()
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
x = self.upsampler(x, feat_cache, feat_idx)
else:
x = self.upsampler(x)
if self.avg_shortcut is not None:
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
return x
class WanUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
"""
Forward pass through the upsampling block.
@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
out_channels: int = 3,
is_residual: bool = False,
):
super().__init__()
self.dim = dim
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
if i > 0 and not is_residual:
# wan vae 2.1
in_dim = in_dim // 2
# Determine if we need upsampling
# determine if we need upsampling
up_flag = i != len(dim_mult) - 1
# determine upsampling mode, if not upsampling, set to None
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
if up_flag and temperal_upsample[i]:
upsample_mode = "upsample3d"
elif up_flag:
upsample_mode = "upsample2d"
# Create and add the upsampling block
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
if is_residual:
up_block = WanResidualUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
temperal_upsample=temperal_upsample[i] if up_flag else False,
up_flag=up_flag,
non_linearity=non_linearity,
)
else:
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -656,6 +909,77 @@ class WanDecoder3d(nn.Module):
return x
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
# x shape: [batch_size, channels, height, width]
batch_size, channels, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
elif x.dim() == 5:
# x shape: [batch_size, channels, frames, height, width]
batch_size, channels, frames, height, width = x.shape
# Ensure height and width are divisible by patch_size
if height % patch_size != 0 or width % patch_size != 0:
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
# x shape: [b, (c * patch_size * patch_size), h, w]
batch_size, c_patches, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, height, width)
# Rearrange to [b, c, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
x = x.view(batch_size, channels, height * patch_size, width * patch_size)
elif x.dim() == 5:
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
batch_size, c_patches, frames, height, width = x.shape
channels = c_patches // (patch_size * patch_size)
# Reshape to [b, c, patch_size, patch_size, f, h, w]
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -671,6 +995,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def __init__(
self,
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
@@ -713,6 +1038,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
2.8251,
1.9160,
],
is_residual: bool = False,
in_channels: int = 3,
out_channels: int = 3,
patch_size: Optional[int] = None,
scale_factor_temporal: Optional[int] = 4,
scale_factor_spatial: Optional[int] = 8,
clip_output: bool = True,
) -> None:
super().__init__()
@@ -720,14 +1052,33 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
if decoder_base_dim is None:
decoder_base_dim = base_dim
self.encoder = WanEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
in_channels=in_channels,
dim=base_dim,
z_dim=z_dim * 2,
dim_mult=dim_mult,
num_res_blocks=num_res_blocks,
attn_scales=attn_scales,
temperal_downsample=temperal_downsample,
dropout=dropout,
is_residual=is_residual,
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
dim=decoder_base_dim,
z_dim=z_dim,
dim_mult=dim_mult,
num_res_blocks=num_res_blocks,
attn_scales=attn_scales,
temperal_upsample=self.temperal_upsample,
dropout=dropout,
out_channels=out_channels,
is_residual=is_residual,
)
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
@@ -827,6 +1178,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return self.tiled_encode(x)
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
@@ -884,12 +1237,17 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = self.decoder(
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.clip_output:
out = torch.clamp(out, min=-1.0, max=1.0)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
self.clear_cache()
if not return_dict:
return (out,)

View File

@@ -170,8 +170,11 @@ class WanTimeTextImageEmbedding(nn.Module):
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
timestep_seq_len: Optional[int] = None,
):
timestep = self.timesteps_proj(timestep)
if timestep_seq_len is not None:
timestep = timestep.unflatten(0, (1, timestep_seq_len))
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -309,9 +312,23 @@ class WanTransformerBlock(nn.Module):
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
# batch_size, seq_len, 1, inner_dim
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
gate_msa = gate_msa.squeeze(2)
c_shift_msa = c_shift_msa.squeeze(2)
c_scale_msa = c_scale_msa.squeeze(2)
c_gate_msa = c_gate_msa.squeeze(2)
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
@@ -469,10 +486,22 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
# timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
if timestep.ndim == 2:
ts_seq_len = timestep.shape[1]
timestep = timestep.flatten() # batch_size * seq_len
else:
ts_seq_len = None
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if ts_seq_len is not None:
# batch_size, seq_len, 6, inner_dim
timestep_proj = timestep_proj.unflatten(2, (6, -1))
else:
# batch_size, 6, inner_dim
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -488,7 +517,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the

View File

@@ -275,7 +275,6 @@ class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs
def check_inputs(
self,
prompt,

View File

@@ -316,7 +316,6 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixi
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
def check_inputs(
self,
prompt,

View File

@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer_2 ([`WanTransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
stages. If not provided, only `transformer` is used.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
def __init__(
self,
@@ -124,6 +134,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: Optional[WanTransformer3DModel] = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False, # Wan2.2 ti2v
):
super().__init__()
@@ -133,10 +146,12 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
transformer_2=transformer_2,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.register_to_config(boundary_ratio=boundary_ratio)
self.register_to_config(expand_timesteps=expand_timesteps)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
@@ -270,6 +285,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale_2=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
def prepare_latents(
self,
batch_size: int,
@@ -369,6 +388,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +485,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -520,21 +549,44 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents,
)
mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
latent_model_input = latents.to(transformer_dtype)
if self.config.expand_timesteps:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
timestep = t.expand(latents.shape[0])
with current_model.cache_context("cond"):
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
@@ -543,15 +595,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

View File

@@ -149,20 +149,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer_2 ([`WanTransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
image_encoder: CLIPVisionModel,
image_processor: CLIPImageProcessor,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
image_processor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModel = None,
transformer_2: WanTransformer3DModel = None,
boundary_ratio: Optional[float] = None,
expand_timesteps: bool = False,
):
super().__init__()
@@ -174,10 +187,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer=transformer,
scheduler=scheduler,
image_processor=image_processor,
transformer_2=transformer_2,
)
self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.image_processor = image_processor
@@ -325,6 +340,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
negative_prompt_embeds=None,
image_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale_2=None,
):
if image is not None and image_embeds is not None:
raise ValueError(
@@ -368,6 +384,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
if self.config.boundary_ratio is not None and image_embeds is not None:
raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
def prepare_latents(
self,
image: PipelineImageInput,
@@ -398,8 +420,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
else:
latents = latents.to(device=device, dtype=dtype)
image = image.unsqueeze(2)
if last_image is None:
image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
if self.config.expand_timesteps:
video_condition = image
elif last_image is None:
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
@@ -432,6 +458,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
if self.config.expand_timesteps:
first_frame_mask = torch.ones(
1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
)
first_frame_mask[:, :, 0] = 0
return latents, latent_condition, first_frame_mask
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
if last_image is None:
@@ -483,6 +516,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -527,6 +561,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -589,6 +627,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
negative_prompt_embeds,
image_embeds,
callback_on_step_end_tensor_inputs,
guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +637,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -631,13 +674,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
if image_embeds is None:
if last_image is None:
image_embeds = self.encode_image(image, device)
else:
image_embeds = self.encode_image([image, last_image], device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
if self.config.boundary_ratio is None and not self.config.expand_timesteps:
if image_embeds is None:
if last_image is None:
image_embeds = self.encode_image(image, device)
else:
image_embeds = self.encode_image([image, last_image], device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -650,34 +694,74 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
device, dtype=torch.float32
)
latents, condition = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
last_image,
)
if self.config.expand_timesteps:
latents, condition, first_frame_mask = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
last_image,
)
else:
latents, condition = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
last_image,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
if self.config.expand_timesteps:
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
latent_model_input = latent_model_input.to(transformer_dtype)
# seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
# batch_size, seq_len
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
else:
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
@@ -687,7 +771,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
@@ -695,7 +779,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -719,6 +803,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
self._current_timestep = None
if self.config.expand_timesteps:
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (

View File

@@ -85,12 +85,29 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
rope_max_seq_len=32,
)
torch.manual_seed(0)
transformer_2 = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=16,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer_2": transformer_2,
}
return components

View File

@@ -86,6 +86,23 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_dim=4,
)
torch.manual_seed(0)
transformer_2 = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
image_dim=4,
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
@@ -109,6 +126,7 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
"transformer_2": transformer_2,
}
return components
@@ -164,6 +182,12 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_inference_batch_single_identical(self):
pass
@unittest.skip(
"TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
)
def test_save_load_optional_components(self):
pass
class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WanImageToVideoPipeline
@@ -218,6 +242,24 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pos_embed_seq_len=2 * (4 * 4 + 1),
)
torch.manual_seed(0)
transformer_2 = WanTransformer3DModel(
patch_size=(1, 2, 2),
num_attention_heads=2,
attention_head_dim=12,
in_channels=36,
out_channels=16,
text_dim=32,
freq_dim=256,
ffn_dim=32,
num_layers=2,
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
image_dim=4,
pos_embed_seq_len=2 * (4 * 4 + 1),
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
@@ -241,6 +283,7 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
"transformer_2": transformer_2,
}
return components
@@ -297,3 +340,9 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
def test_inference_batch_single_identical(self):
pass
@unittest.skip(
"TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
)
def test_save_load_optional_components(self):
pass