Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-09 05:54:24 +08:00)

Compare commits: cached-lor...wan-5bi2v (14 commits)
| SHA1 |
|---|
| 5512d70f7f |
| 4fd0333b17 |
| c0af998473 |
| 1ff7c99596 |
| 6bb1677bd4 |
| 97675c7036 |
| e78a4fa30a |
| 0ab2e4fd91 |
| 5709f7e04d |
| 27ce75b984 |
| bf2c6e0a08 |
| 95a55f9dc0 |
| f5da83c4a3 |
| 7ba78d5ce5 |
@@ -278,16 +278,82 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
        }
        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-I2V-14B-720p":
        config = {
            "model_id": "Wan-AI/Wan2.2-I2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 36,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-T2V-A14B":
        config = {
            "model_id": "Wan-AI/Wan2.2-T2V-A14B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "in_channels": 16,
                "num_attention_heads": 40,
                "num_layers": 40,
                "out_channels": 16,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    elif model_type == "Wan2.2-TI2V-5B":
        config = {
            "model_id": "Wan-AI/Wan2.2-TI2V-5B",
            "diffusers_config": {
                "added_kv_proj_dim": None,
                "attention_head_dim": 128,
                "cross_attn_norm": True,
                "eps": 1e-06,
                "ffn_dim": 14336,
                "freq_dim": 256,
                "in_channels": 48,
                "num_attention_heads": 24,
                "num_layers": 30,
                "out_channels": 48,
                "patch_size": [1, 2, 2],
                "qk_norm": "rms_norm_across_heads",
                "text_dim": 4096,
            },
        }
        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
    return config, RENAME_DICT, SPECIAL_KEYS_REMAP


def convert_transformer(model_type: str):
def convert_transformer(model_type: str, stage: str = None):
    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)

    diffusers_config = config["diffusers_config"]
    model_id = config["model_id"]
    model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))

    if stage is not None:
        model_dir = model_dir / stage

    original_state_dict = load_sharded_safetensors(model_dir)

    with init_empty_weights():
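As a quick orientation for the new Wan2.2 branches above, the sketch below shows how one entry would be consumed. It is not part of the diff and assumes the script's existing imports (`WanTransformer3DModel`, `init_empty_weights`, `get_transformer_config`).

```python
# Minimal sketch (not from the diff): instantiate an empty Wan2.2-TI2V-5B transformer
# from the config returned above, the same way convert_transformer() does.
config, rename_dict, special_keys_remap = get_transformer_config("Wan2.2-TI2V-5B")
with init_empty_weights():
    transformer = WanTransformer3DModel.from_config(config["diffusers_config"])
print(transformer.config.num_layers)  # 30
```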
@@ -515,6 +581,310 @@ def convert_vae():
    return vae


vae22_diffusers_config = {
    "base_dim": 160,
    "z_dim": 48,
    "is_residual": True,
    "in_channels": 12,
    "out_channels": 12,
    "decoder_base_dim": 256,
    "scale_factor_temporal": 4,
    "scale_factor_spatial": 16,
    "patch_size": 2,
    "latents_mean": [
        -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
        -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
        -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
        -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
        -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
        0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
    ],
    "latents_std": [
        0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
        0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
        0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
        0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
        0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
        0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
    ],
    "clip_output": False,
}


def convert_vae_22():
    vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Change encoder.downsamples to encoder.down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Handle residual blocks - change downsamples to resnets and rename components
            if "residual" in new_key or "shortcut" in new_key:
                # Change the second downsamples to resnets
                new_key = new_key.replace(".downsamples.", ".resnets.")

                # Rename residual components
                if ".residual.0.gamma" in new_key:
                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
                elif ".residual.2.weight" in new_key:
                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
                elif ".residual.2.bias" in new_key:
                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
                elif ".residual.3.gamma" in new_key:
                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
                elif ".residual.6.weight" in new_key:
                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
                elif ".residual.6.bias" in new_key:
                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
                elif ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")

            # Handle resample blocks - change downsamples to downsampler and remove index
            elif "resample" in new_key or "time_conv" in new_key:
                # Change the second downsamples to downsampler and remove the index
                parts = new_key.split(".")
                # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
                # We want to change it to: encoder.down_blocks.X.downsampler.resample...
                if len(parts) >= 4 and parts[3] == "downsamples":
                    # Remove the index (parts[4]) and change downsamples to downsampler
                    new_parts = parts[:3] + ["downsampler"] + parts[5:]
                    new_key = ".".join(new_parts)

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Change decoder.upsamples to decoder.up_blocks
            new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

            # Handle residual blocks - change upsamples to resnets and rename components
            if "residual" in new_key or "shortcut" in new_key:
                # Change the second upsamples to resnets
                new_key = new_key.replace(".upsamples.", ".resnets.")

                # Rename residual components
                if ".residual.0.gamma" in new_key:
                    new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
                elif ".residual.2.weight" in new_key:
                    new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
                elif ".residual.2.bias" in new_key:
                    new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
                elif ".residual.3.gamma" in new_key:
                    new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
                elif ".residual.6.weight" in new_key:
                    new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
                elif ".residual.6.bias" in new_key:
                    new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
                elif ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")

            # Handle resample blocks - change upsamples to upsampler and remove index
            elif "resample" in new_key or "time_conv" in new_key:
                # Change the second upsamples to upsampler and remove the index
                parts = new_key.split(".")
                # Find the pattern: decoder.up_blocks.X.upsamples.Y.resample...
                # We want to change it to: decoder.up_blocks.X.upsampler.resample...
                if len(parts) >= 4 and parts[3] == "upsamples":
                    # Remove the index (parts[4]) and change upsamples to upsampler
                    new_parts = parts[:3] + ["upsampler"] + parts[5:]
                    new_key = ".".join(new_parts)

            new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan(**vae22_diffusers_config)
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
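To make the renaming rules above concrete, here is a small trace with a hypothetical checkpoint key (the key itself is made up; the logic is exactly the loop above):

```python
# Hypothetical original key, used only to trace the renaming logic above.
key = "encoder.downsamples.2.downsamples.0.resample.1.weight"
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
parts = new_key.split(".")
if len(parts) >= 4 and parts[3] == "downsamples":
    new_key = ".".join(parts[:3] + ["downsampler"] + parts[5:])
print(new_key)  # encoder.down_blocks.2.downsampler.resample.1.weight
```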
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default=None)
@@ -533,11 +903,26 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
    args = get_args()

    transformer = convert_transformer(args.model_type)
    vae = convert_vae()
    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
        transformer = convert_transformer(args.model_type, stage="high_noise_model")
        transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
    else:
        transformer = convert_transformer(args.model_type)
        transformer_2 = None

    if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
        vae = convert_vae_22()
    else:
        vae = convert_vae()

    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
    flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
    if "FLF2V" in args.model_type:
        flow_shift = 16.0
    elif "TI2V" in args.model_type:
        flow_shift = 5.0
    else:
        flow_shift = 3.0
    scheduler = UniPCMultistepScheduler(
        prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
    )
@@ -547,7 +932,36 @@ if __name__ == "__main__":
    dtype = DTYPE_MAPPING[args.dtype]
    transformer.to(dtype)

    if "I2V" in args.model_type or "FLF2V" in args.model_type:
    if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
        pipe = WanImageToVideoPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.9,
        )
    elif "Wan2.2" and "T2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            transformer_2=transformer_2,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            boundary_ratio=0.875,
        )
    elif "Wan2.2" and "TI2V" in args.model_type:
        pipe = WanPipeline(
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
            expand_timesteps=True,
        )
    elif "I2V" in args.model_type or "FLF2V" in args.model_type:
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
        )
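For context on what the assembled two-stage pipeline looks like from the user side, here is a hedged usage sketch; the local checkpoint path and the generation settings are illustrative assumptions, not values from this diff.

```python
import torch
from diffusers import WanPipeline

# Hypothetical output directory of the conversion script above
pipe = WanPipeline.from_pretrained("./Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16).to("cuda")
video = pipe(
    prompt="a red panda climbing a snowy pine tree",
    num_frames=81,
    guidance_scale=4.0,
    guidance_scale_2=3.0,  # only used because the converted pipeline registers boundary_ratio=0.875
).frames[0]
```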
@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
CACHE_T = 2


class AvgDown3D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        pad = (0, 0, 0, 0, pad_t, 0)
        x = F.pad(x, pad)
        B, C, T, H, W = x.shape
        x = x.view(B, C, T // self.factor_t, self.factor_t, H // self.factor_s, self.factor_s, W // self.factor_s, self.factor_s)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.view(B, self.out_channels, self.group_size, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.mean(dim=2)
        return x


class DupUp3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(x.size(0), self.out_channels, self.factor_t, self.factor_s, self.factor_s, x.size(2), x.size(3), x.size(4))
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(x.size(0), self.out_channels, x.size(2) * self.factor_t, x.size(4) * self.factor_s, x.size(6) * self.factor_s)
        if first_chunk:
            x = x[:, :, self.factor_t - 1 :, :, :]
        return x
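A quick shape sketch of how these two helpers behave; the sizes are hypothetical, and `first_chunk` trims the duplicated leading frames.

```python
import torch

down = AvgDown3D(in_channels=12, out_channels=48, factor_t=2, factor_s=2)
up = DupUp3D(in_channels=48, out_channels=12, factor_t=2, factor_s=2)

x = torch.randn(1, 12, 4, 64, 64)
z = down(x)                   # (1, 48, 2, 32, 32): each 2x2x2 block folded into channels, then group-averaged
y = up(z, first_chunk=True)   # (1, 12, 3, 64, 64): duplicated first frame removed by first_chunk
```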
class WanCausalConv3d(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
        - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str) -> None:
    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # default to dim // 2
        if upsample_out_dim is None:
            upsample_out_dim = dim // 2

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
            )
            self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
        return x


class WanResidualDownBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path with residual blocks and downsample
        resnets = []
        for _ in range(num_res_blocks):
            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim
        self.resnets = nn.ModuleList(resnets)

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            self.downsampler = WanResample(out_dim, mode=mode)
        else:
            self.downsampler = None

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        x_copy = x.clone()
        for resnet in self.resnets:
            x = resnet(x, feat_cache, feat_idx)
        if self.downsampler is not None:
            x = self.downsampler(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)
class WanEncoder3d(nn.Module):
    r"""
    A 3D encoder module.
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):

    def __init__(
        self,
        in_channels: int = 3,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
@@ -388,6 +528,7 @@ class WanEncoder3d(nn.Module):
        temperal_downsample=[True, True, False],
        dropout=0.0,
        non_linearity: str = "silu",
        is_residual: bool = False,  # wan 2.2 vae use a residual downblock
    ):
        super().__init__()
        self.dim = dim
@@ -403,23 +544,35 @@ class WanEncoder3d(nn.Module):
        scale = 1.0

        # init block
        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
        self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)

        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(WanAttentionBlock(out_dim))
                in_dim = out_dim
            if is_residual:
                self.down_blocks.append(
                    WanResidualDownBlock(
                        in_dim,
                        out_dim,
                        dropout,
                        num_res_blocks,
                        temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
                        down_flag=i != len(dim_mult) - 1,
                    )
                )
            else:
                for _ in range(num_res_blocks):
                    self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
                    if scale in attn_scales:
                        self.down_blocks.append(WanAttentionBlock(out_dim))
                    in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(WanResample(out_dim, mode=mode))
                scale /= 2.0
                # downsample block
                if i != len(dim_mult) - 1:
                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                    self.down_blocks.append(WanResample(out_dim, mode=mode))
                    scale /= 2.0

        # middle blocks
        self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
        return x
class WanResidualUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        temperal_upsample (bool): Whether to upsample on temporal dimension
        up_flag (bool): Whether to upsample or not
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        temperal_upsample: bool = False,
        up_flag: bool = False,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
        else:
            self.avg_shortcut = None

        # create residual blocks
        resnets = []
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        if up_flag:
            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
        else:
            self.upsampler = None

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management

        Returns:
            torch.Tensor: Output tensor
        """
        x_copy = x.clone()

        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsampler is not None:
            if feat_cache is not None:
                x = self.upsampler(x, feat_cache, feat_idx)
            else:
                x = self.upsampler(x)

        if self.avg_shortcut is not None:
            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)

        return x
class WanUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
        """
        Forward pass through the upsampling block.

@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
        out_channels: int = 3,
        is_residual: bool = False,
    ):
        super().__init__()
        self.dim = dim
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block
        self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i > 0:
            if i > 0 and not is_residual:
                # wan vae 2.1
                in_dim = in_dim // 2

            # Determine if we need upsampling
            # determine if we need upsampling
            up_flag = i != len(dim_mult) - 1
            # determine upsampling mode, if not upsampling, set to None
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            if up_flag and temperal_upsample[i]:
                upsample_mode = "upsample3d"
            elif up_flag:
                upsample_mode = "upsample2d"
            # Create and add the upsampling block
            up_block = WanUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
            if is_residual:
                up_block = WanResidualUpBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_res_blocks=num_res_blocks,
                    dropout=dropout,
                    temperal_upsample=temperal_upsample[i] if up_flag else False,
                    up_flag=up_flag,
                    non_linearity=non_linearity,
                )
            else:
                up_block = WanUpBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_res_blocks=num_res_blocks,
                    dropout=dropout,
                    upsample_mode=upsample_mode,
                    non_linearity=non_linearity,
                )
            self.up_blocks.append(up_block)

            # Update scale for next iteration
            if upsample_mode is not None:
                scale *= 2.0

        # output blocks
        self.norm_out = WanRMS_norm(out_dim, images=False)
        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
        self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)
            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)

        ## head
        x = self.norm_out(x)
@@ -656,6 +909,77 @@ class WanDecoder3d(nn.Module):
        return x
def patchify(x, patch_size):
    if patch_size == 1:
        return x

    if x.dim() == 4:
        # x shape: [batch_size, channels, height, width]
        batch_size, channels, height, width = x.shape

        # Ensure height and width are divisible by patch_size
        if height % patch_size != 0 or width % patch_size != 0:
            raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")

        # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
        x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)

        # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)

    elif x.dim() == 5:
        # x shape: [batch_size, channels, frames, height, width]
        batch_size, channels, frames, height, width = x.shape

        # Ensure height and width are divisible by patch_size
        if height % patch_size != 0 or width % patch_size != 0:
            raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")

        # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
        x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)

        # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
        x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
        x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)

    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size):
    if patch_size == 1:
        return x

    if x.dim() == 4:
        # x shape: [b, (c * patch_size * patch_size), h, w]
        batch_size, c_patches, height, width = x.shape
        channels = c_patches // (patch_size * patch_size)

        # Reshape to [b, c, patch_size, patch_size, h, w]
        x = x.view(batch_size, channels, patch_size, patch_size, height, width)

        # Rearrange to [b, c, h * patch_size, w * patch_size]
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        x = x.view(batch_size, channels, height * patch_size, width * patch_size)

    elif x.dim() == 5:
        # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
        batch_size, c_patches, frames, height, width = x.shape
        channels = c_patches // (patch_size * patch_size)

        # Reshape to [b, c, patch_size, patch_size, f, h, w]
        x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)

        # Rearrange to [b, c, f, h * patch_size, w * patch_size]
        x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
        x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)

    return x
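A quick round-trip check of these two helpers (shapes are illustrative):

```python
import torch

x = torch.randn(1, 48, 9, 64, 64)   # [batch, channels, frames, height, width]
p = patchify(x, patch_size=2)       # (1, 192, 9, 32, 32): 2x2 spatial patches folded into channels
assert unpatchify(p, patch_size=2).shape == x.shape
```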
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -671,6 +995,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    def __init__(
        self,
        base_dim: int = 96,
        decoder_base_dim: Optional[int] = None,
        z_dim: int = 16,
        dim_mult: Tuple[int] = [1, 2, 4, 4],
        num_res_blocks: int = 2,
@@ -713,6 +1038,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            2.8251,
            1.9160,
        ],
        is_residual: bool = False,
        in_channels: int = 3,
        out_channels: int = 3,
        patch_size: Optional[int] = None,
        scale_factor_temporal: Optional[int] = 4,
        scale_factor_spatial: Optional[int] = 8,
        clip_output: bool = True,
    ) -> None:
        super().__init__()

@@ -720,14 +1052,33 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        if decoder_base_dim is None:
            decoder_base_dim = base_dim

        self.encoder = WanEncoder3d(
            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
            in_channels=in_channels,
            dim=base_dim,
            z_dim=z_dim * 2,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            attn_scales=attn_scales,
            temperal_downsample=temperal_downsample,
            dropout=dropout,
            is_residual=is_residual,
        )
        self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)

        self.decoder = WanDecoder3d(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
            dim=decoder_base_dim,
            z_dim=z_dim,
            dim_mult=dim_mult,
            num_res_blocks=num_res_blocks,
            attn_scales=attn_scales,
            temperal_upsample=self.temperal_upsample,
            dropout=dropout,
            out_channels=out_channels,
            is_residual=is_residual,
        )

        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
@@ -827,6 +1178,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            return self.tiled_encode(x)

        self.clear_cache()
        if self.config.patch_size is not None:
            x = patchify(x, patch_size=self.config.patch_size)
        iter_ = 1 + (num_frame - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
@@ -884,12 +1237,17 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        for i in range(num_frame):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = self.decoder(
                    x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
                )
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)

        out = torch.clamp(out, min=-1.0, max=1.0)
        if self.config.clip_output:
            out = torch.clamp(out, min=-1.0, max=1.0)
        if self.config.patch_size is not None:
            out = unpatchify(out, patch_size=self.config.patch_size)
        self.clear_cache()
        if not return_dict:
            return (out,)
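The arithmetic behind `scale_factor_spatial=16` in the Wan2.2 VAE config earlier: with `patch_size=2` the video is patchified before encoding, and the encoder itself downsamples 8x spatially and 4x temporally (from `temperal_downsample=[True, True, False]`). A hedged sketch of the resulting latent shape, using a hypothetical input size:

```python
# Illustrative only; the input size is an assumption, not a value from the diff.
z_dim, patch_size, enc_spatial, enc_temporal = 48, 2, 8, 4
frames, height, width = 121, 704, 1280
latent_shape = (
    z_dim,
    1 + (frames - 1) // enc_temporal,
    height // (patch_size * enc_spatial),
    width // (patch_size * enc_spatial),
)
print(latent_shape)  # (48, 31, 44, 80) -- overall spatial factor 16, temporal factor 4
```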
@@ -170,8 +170,11 @@ class WanTimeTextImageEmbedding(nn.Module):
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        timestep_seq_len: Optional[int] = None,
    ):
        timestep = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            timestep = timestep.unflatten(0, (1, timestep_seq_len))

        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -309,9 +312,23 @@ class WanTransformerBlock(nn.Module):
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb.float()
        ).chunk(6, dim=1)
        if temb.ndim == 4:
            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb.float()
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
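Shape sketch for the new `temb.ndim == 4` branch, assuming the block's `scale_shift_table` is a `(1, 6, dim)` parameter as in Wan 2.1; the sizes below are hypothetical.

```python
import torch

dim, batch, seq = 3072, 1, 300
scale_shift_table = torch.randn(1, 6, dim)
temb = torch.randn(batch, seq, 6, dim)          # per-token modulation (Wan2.2 TI2V)

chunks = (scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
shift_msa = chunks[0].squeeze(2)                # (batch, seq, dim), applied token-wise
print(shift_msa.shape)                          # torch.Size([1, 300, 3072])
```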
@@ -469,10 +486,22 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        timestep_proj = timestep_proj.unflatten(1, (6, -1))
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -488,7 +517,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
            hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
@@ -275,7 +275,6 @@ class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs
    def check_inputs(
        self,
        prompt,

@@ -316,7 +316,6 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixi

        return prompt_embeds, negative_prompt_embeds

    # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLWan`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        transformer_2 ([`WanTransformer3DModel`], *optional*):
            Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
            two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
            stages. If not provided, only `transformer` is used.
        boundary_ratio (`float`, *optional*, defaults to `None`):
            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
    """
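A worked instance of the boundary rule in that docstring (the values are illustrative; 0.875 is what the conversion script above passes for the T2V checkpoint):

```python
boundary_ratio, num_train_timesteps = 0.875, 1000
boundary_timestep = boundary_ratio * num_train_timesteps   # 875.0
# t >= 875 -> `transformer` (high-noise stage); t < 875 -> `transformer_2` (low-noise stage)
```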

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
    _optional_components = ["transformer_2"]

    def __init__(
        self,
@@ -124,6 +134,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        transformer: WanTransformer3DModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
        transformer_2: Optional[WanTransformer3DModel] = None,
        boundary_ratio: Optional[float] = None,
        expand_timesteps: bool = False,  # Wan2.2 ti2v
    ):
        super().__init__()

@@ -133,10 +146,12 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            transformer_2=transformer_2,
        )

        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
        self.register_to_config(boundary_ratio=boundary_ratio)
        self.register_to_config(expand_timesteps=expand_timesteps)
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
@@ -270,6 +285,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        guidance_scale_2=None,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

    def prepare_latents(
        self,
        batch_size: int,
@@ -369,6 +388,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        guidance_scale_2: Optional[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            guidance_scale_2 (`float`, *optional*, defaults to `None`):
                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
                and the pipeline's `boundary_ratio` are not None.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +485,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
            guidance_scale_2,
        )

        if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
            guidance_scale_2 = guidance_scale

        self._guidance_scale = guidance_scale
        self._guidance_scale_2 = guidance_scale_2
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False
@@ -520,21 +549,44 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            latents,
        )

        mask = torch.ones(latents.shape, dtype=torch.float32, device=device)

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        if self.config.boundary_ratio is not None:
            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
        else:
            boundary_timestep = None

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = latents.to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                with self.transformer.cache_context("cond"):
                    noise_pred = self.transformer(
                if boundary_timestep is None or t >= boundary_timestep:
                    # wan2.1 or high-noise stage in wan2.2
                    current_model = self.transformer
                    current_guidance_scale = guidance_scale
                else:
                    # low-noise stage in wan2.2
                    current_model = self.transformer_2
                    current_guidance_scale = guidance_scale_2

                latent_model_input = latents.to(transformer_dtype)
                if self.config.expand_timesteps:
                    # seq_len: num_latent_frames * latent_height//2 * latent_width//2
                    temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
                    # batch_size, seq_len
                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
                else:
                    timestep = t.expand(latents.shape[0])

                with current_model.cache_context("cond"):
                    noise_pred = current_model(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=prompt_embeds,
@@ -543,15 +595,15 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                    )[0]

                if self.do_classifier_free_guidance:
                    with self.transformer.cache_context("uncond"):
                        noise_uncond = self.transformer(
                    with current_model.cache_context("uncond"):
                        noise_uncond = current_model(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=negative_prompt_embeds,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
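A small sketch of the per-token timestep built in the `expand_timesteps` branch above; the 2x2 stride matches the transformer's spatial patch size, and the shapes are illustrative.

```python
import torch

latents = torch.randn(1, 48, 3, 20, 20)                  # (batch, z_dim, frames, height, width)
mask = torch.ones(latents.shape, dtype=torch.float32)
t = torch.tensor(875.0)

temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()        # 3 * 10 * 10 = 300 values, one per latent patch
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
print(timestep.shape)                                    # torch.Size([1, 300])
```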
@@ -149,20 +149,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
||||
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
||||
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
||||
`transformer` is used.
|
||||
boundary_ratio (`float`, *optional*, defaults to `None`):
|
||||
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
||||
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
||||
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
||||
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
_optional_components = ["transformer_2", "image_encoder", "image_processor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
image_encoder: CLIPVisionModel,
|
||||
image_processor: CLIPImageProcessor,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
image_processor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModel = None,
|
||||
transformer_2: WanTransformer3DModel = None,
|
||||
boundary_ratio: Optional[float] = None,
|
||||
expand_timesteps: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -174,10 +187,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_processor=image_processor,
|
||||
transformer_2=transformer_2,
|
||||
)
|
||||
self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
self.image_processor = image_processor
|
||||
|
||||
@@ -325,6 +340,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
negative_prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
guidance_scale_2=None,
|
||||
):
|
||||
if image is not None and image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -368,6 +384,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
||||
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
||||
|
||||
if self.config.boundary_ratio is not None and image_embeds is not None:
|
||||
raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")

    def prepare_latents(
        self,
        image: PipelineImageInput,
@@ -398,8 +420,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        else:
            latents = latents.to(device=device, dtype=dtype)

        image = image.unsqueeze(2)
        if last_image is None:
        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]

        if self.config.expand_timesteps:
            video_condition = image
        elif last_image is None:
            video_condition = torch.cat(
                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
            )
@@ -432,6 +458,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        latent_condition = latent_condition.to(dtype)
        latent_condition = (latent_condition - latents_mean) * latents_std

        if self.config.expand_timesteps:
            first_frame_mask = torch.ones(
                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
            )
            first_frame_mask[:, :, 0] = 0
            return latents, latent_condition, first_frame_mask
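        # In the `expand_timesteps` branch above, `first_frame_mask` is 1 everywhere except the
        # first latent frame (index 0), which is 0; the denoising loop uses it to keep that frame
        # pinned to the clean image latent in `latent_condition`. For example (assumed Wan2.2 TI2V
        # sizes), 81 frames at 704x1280 give a mask of shape [1, 1, 21, 44, 80].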

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)

        if last_image is None:
@@ -483,6 +516,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        num_frames: int = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        guidance_scale_2: Optional[float] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
@@ -527,6 +561,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are
                closely linked to the text `prompt`, usually at the expense of lower image quality.
            guidance_scale_2 (`float`, *optional*, defaults to `None`):
                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
                and the pipeline's `boundary_ratio` are not None.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -589,6 +627,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            negative_prompt_embeds,
            image_embeds,
            callback_on_step_end_tensor_inputs,
            guidance_scale_2,
        )

        if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +637,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
            guidance_scale_2 = guidance_scale

        self._guidance_scale = guidance_scale
        self._guidance_scale_2 = guidance_scale_2
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False
@@ -631,13 +674,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        if image_embeds is None:
            if last_image is None:
                image_embeds = self.encode_image(image, device)
            else:
                image_embeds = self.encode_image([image, last_image], device)
            image_embeds = image_embeds.repeat(batch_size, 1, 1)
            image_embeds = image_embeds.to(transformer_dtype)
        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
            if image_embeds is None:
                if last_image is None:
                    image_embeds = self.encode_image(image, device)
                else:
                    image_embeds = self.encode_image([image, last_image], device)
                image_embeds = image_embeds.repeat(batch_size, 1, 1)
                image_embeds = image_embeds.to(transformer_dtype)
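        # CLIP image embeddings are only computed for Wan2.1-style checkpoints that ship an image
        # encoder; the Wan2.2 paths (`boundary_ratio` set, or `expand_timesteps` for TI2V) condition
        # on the image purely through the VAE latents and skip this block.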

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -650,34 +694,74 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
            last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                device, dtype=torch.float32
            )
        latents, condition = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
            last_image,
        )

        if self.config.expand_timesteps:
            latents, condition, first_frame_mask = self.prepare_latents(
                image,
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                height,
                width,
                num_frames,
                torch.float32,
                device,
                generator,
                latents,
                last_image,
            )
        else:
            latents, condition = self.prepare_latents(
                image,
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                height,
                width,
                num_frames,
                torch.float32,
                device,
                generator,
                latents,
                last_image,
            )

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        if self.config.boundary_ratio is not None:
            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
        else:
            boundary_timestep = None
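        # Illustrative numbers (not from this diff): with `boundary_ratio=0.875` and the scheduler
        # default `num_train_timesteps=1000`, `boundary_timestep` is 875, so steps with t >= 875 run
        # the high-noise `transformer` and the remaining steps run the low-noise `transformer_2`.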

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                noise_pred = self.transformer(
                if boundary_timestep is None or t >= boundary_timestep:
                    # wan2.1 or high-noise stage in wan2.2
                    current_model = self.transformer
                    current_guidance_scale = guidance_scale
                else:
                    # low-noise stage in wan2.2
                    current_model = self.transformer_2
                    current_guidance_scale = guidance_scale_2

                if self.config.expand_timesteps:
                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
                    latent_model_input = latent_model_input.to(transformer_dtype)

                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
                    # batch_size, seq_len
                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
                else:
                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                    timestep = t.expand(latents.shape[0])
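                # In the `expand_timesteps` branch, `timestep` becomes a per-token vector: the
                # `[:, ::2, ::2]` stride matches the transformer's (1, 2, 2) patch size, masked
                # (noisy) tokens receive t, and the pinned first-frame tokens receive 0.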

                noise_pred = current_model(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
@@ -687,7 +771,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                )[0]

                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                    noise_uncond = current_model(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
@@ -695,7 +779,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -719,6 +803,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

        self._current_timestep = None

        if self.config.expand_timesteps:
            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
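        # With `expand_timesteps`, the conditioning latent is copied back into the first latent
        # frame before decoding, so the generated video starts from the input image.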

        if not output_type == "latent":
            latents = latents.to(self.vae.dtype)
            latents_mean = (

@@ -85,12 +85,29 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            rope_max_seq_len=32,
        )

        torch.manual_seed(0)
        transformer_2 = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
        )

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer_2": transformer_2,
        }
        return components

@@ -86,6 +86,23 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            image_dim=4,
        )

        torch.manual_seed(0)
        transformer_2 = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=36,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
            image_dim=4,
        )

        torch.manual_seed(0)
        image_encoder_config = CLIPVisionConfig(
            hidden_size=4,
@@ -109,6 +126,7 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "tokenizer": tokenizer,
            "image_encoder": image_encoder,
            "image_processor": image_processor,
            "transformer_2": transformer_2,
        }
        return components

@@ -164,6 +182,12 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    def test_inference_batch_single_identical(self):
        pass

    @unittest.skip(
        "TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
    )
    def test_save_load_optional_components(self):
        pass

class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = WanImageToVideoPipeline
@@ -218,6 +242,24 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            pos_embed_seq_len=2 * (4 * 4 + 1),
        )

        torch.manual_seed(0)
        transformer_2 = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=36,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
            image_dim=4,
            pos_embed_seq_len=2 * (4 * 4 + 1),
        )

        torch.manual_seed(0)
        image_encoder_config = CLIPVisionConfig(
            hidden_size=4,
@@ -241,6 +283,7 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "tokenizer": tokenizer,
            "image_encoder": image_encoder,
            "image_processor": image_processor,
            "transformer_2": transformer_2,
        }
        return components

@@ -297,3 +340,9 @@ class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
    def test_inference_batch_single_identical(self):
        pass

    @unittest.skip(
        "TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
    )
    def test_save_load_optional_components(self):
        pass