Compare commits

...

7 Commits

Author  SHA1        Message                                                    Date
Aryan   39f688a24e  Merge branch 'main' into modular-diffusers-wan-i2v-flf2v  2025-08-05 11:54:42 +02:00
Aryan   2ff808d26c  address review comments                                   2025-08-01 06:49:06 +02:00
Aryan   22f3273a82  update                                                    2025-07-29 03:36:41 +02:00
Aryan   03b03a91f1  update                                                    2025-07-29 00:29:46 +02:00
Aryan   7ab5ecc774  Merge branch 'main' into modular-diffusers-wan-i2v-flf2v  2025-07-29 03:19:01 +05:30
Aryan   7ae44420c3  update                                                    2025-07-28 07:20:46 +02:00
Aryan   183bcd5c79  update                                                    2025-07-27 23:58:27 +02:00
5 changed files with 555 additions and 18 deletions

View File

@@ -25,12 +25,14 @@ else:
_import_structure["modular_blocks"] = [ _import_structure["modular_blocks"] = [
"ALL_BLOCKS", "ALL_BLOCKS",
"AUTO_BLOCKS", "AUTO_BLOCKS",
"IMAGE2VIDEO_BLOCKS",
"TEXT2VIDEO_BLOCKS", "TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep", "WanAutoBeforeDenoiseStep",
"WanAutoBlocks", "WanAutoBlocks",
"WanAutoBlocks", "WanAutoBlocks",
"WanAutoDecodeStep", "WanAutoDecodeStep",
"WanAutoDenoiseStep", "WanAutoDenoiseStep",
"WanAutoVaeEncoderStep",
] ]
_import_structure["modular_pipeline"] = ["WanModularPipeline"] _import_structure["modular_pipeline"] = ["WanModularPipeline"]
@@ -45,11 +47,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE2VIDEO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
WanAutoVaeEncoderStep,
)
from .modular_pipeline import WanModularPipeline
else:
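
With these exports in place, the new preset and auto block become importable from the Wan modular subpackage. A minimal sketch, assuming the subpackage path `diffusers.modular_pipelines.wan` and an editable install of this branch:

```python
# Sketch only: the import path and the dict-like behaviour of InsertableDict are assumptions
# based on the surrounding code, not something this hunk guarantees.
from diffusers.modular_pipelines.wan import (
    IMAGE2VIDEO_BLOCKS,     # insertable preset for image-to-video / first-last-frame-to-video
    WanAutoVaeEncoderStep,  # auto block that only runs when `image` is passed
)

print(list(IMAGE2VIDEO_BLOCKS.keys()))
# e.g. ['text_encoder', 'input', 'set_timesteps', 'prepare_latents', 'image_encoder',
#       'vae_encoder', 'denoise', 'decode']
```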

View File

@@ -282,7 +282,10 @@ class WanPrepareLatentsStep(PipelineBlock):
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam("height", type_hint=int),
OutputParam("width", type_hint=int),
OutputParam("num_frames", type_hint=int),
]
@staticmethod

View File

@@ -34,6 +34,56 @@ from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanI2VLoopBeforeDenoiser(PipelineBlock):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanI2VDenoiseLoopWrapper`)"
)
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"latent_condition",
required=True,
type_hint=torch.Tensor,
description="The latent condition to use for the denoising process.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latent_model_inputs",
type_hint=torch.Tensor,
description="The concatenated noisy and conditioning latents to use for the denoising process.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
return components, block_state
class WanLoopDenoiser(PipelineBlock):
model_name = "wan"
@@ -102,7 +152,7 @@ class WanLoopDenoiser(PipelineBlock):
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ class WanLoopDenoiser(PipelineBlock):
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.transformer)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
class WanI2VLoopDenoiser(PipelineBlock):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latent_model_inputs",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"image_embeds",
required=True,
type_hint=torch.Tensor,
description="The encoder hidden states for the image inputs.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
WanLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
@@ -257,5 +412,26 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanLoopDenoiser`\n" " - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n" " - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks." "This block supports the text2vid task."
)
class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanI2VLoopBeforeDenoiser,
WanI2VLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanI2VLoopBeforeDenoiser`\n"
" - `WanI2VLoopDenoiser`\n"
" - `WanI2VLoopAfterDenoiser`\n"
"This block supports the image-to-video and first-last-frame-to-video tasks."
)
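
For orientation, the I2V denoise path differs from the text2vid one only in what it feeds the transformer: `WanI2VLoopBeforeDenoiser` concatenates the noisy latents with the mask-plus-latent condition along the channel axis, and `WanI2VLoopDenoiser` additionally passes `image_embeds` via `encoder_hidden_states_image`. A rough shape sketch with illustrative numbers (16-channel Wan VAE latents, a 20-channel condition as built by `WanVaeEncoderStep` below; the concrete sizes assume an 832x480, 81-frame generation and are not fixed by this diff):

```python
import torch

# Illustrative shapes: batch=1, 21 latent frames ((81 - 1) // 4 + 1), 60x104 latent grid.
latents = torch.randn(1, 16, 21, 60, 104)               # noisy video latents
latent_condition = torch.randn(1, 4 + 16, 21, 60, 104)  # 4-channel temporal mask + 16-channel encoded frames

# WanI2VLoopBeforeDenoiser: build the per-step model input.
latent_model_inputs = torch.cat([latents, latent_condition], dim=1)
print(latent_model_inputs.shape)  # torch.Size([1, 36, 21, 60, 104])
# 36 = 16 noisy + 20 condition channels; the Wan I2V transformer checkpoints expect this
# wider input, while the text2vid transformer only sees the 16 noisy channels.
```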

View File

@@ -17,11 +17,14 @@ from typing import List, Optional, Union
import regex as re
import torch
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan
from ...utils import is_ftfy_available, logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class WanTextEncoderStep(PipelineBlock):
model_name = "wan"
@@ -240,3 +257,238 @@ class WanTextEncoderStep(PipelineBlock):
# Add outputs
self.set_block_state(state, block_state)
return components, state
class WanImageEncoderStep(PipelineBlock):
model_name = "wan"
@property
def description(self) -> str:
return "Image Encoder step to compute image embeddings to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModel),
ComponentSpec("image_processor", CLIPImageProcessor),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
required=True,
description="The input image to condition the generation on for first-frame conditioned video generation.",
),
InputParam(
"last_image",
required=False,
description="The last image to condition the generation on for last-frame conditioned video generation.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_embeds",
type_hint=torch.Tensor,
description="image embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if not isinstance(block_state.image, PipelineImageInput):
raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
raise ValueError(
f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
)
@staticmethod
def encode_image(
components,
image: Union[PipelineImageInput, List[PipelineImageInput]],
device: torch.device,
):
image = components.image_processor(images=image, return_tensors="pt").to(device)
image_embeds = components.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input images
image = block_state.image
if block_state.last_image is not None:
image = [block_state.image, block_state.last_image]
block_state.image_embeds = self.encode_image(components, image, block_state.device)
# Add outputs
self.set_block_state(state, block_state)
return components, state
class WanVaeEncoderStep(PipelineBlock):
model_name = "wan"
@property
def description(self) -> str:
return (
"VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLWan),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("last_image", required=False),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("num_frames", type_hint=int),
InputParam("batch_size", type_hint=int),
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latent_condition",
type_hint=torch.Tensor,
description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
),
OutputParam("num_channels_latents", type_hint=int),
]
@staticmethod
def _encode_vae_image(
components: WanModularPipeline,
batch_size: int,
height: int,
width: int,
num_frames: int,
image: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
last_image: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
):
latent_height = height // components.vae_scale_factor_spatial
latent_width = width // components.vae_scale_factor_spatial
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(device, dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(device, dtype)
image = image.unsqueeze(2)
if last_image is None:
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
else:
last_image = last_image.unsqueeze(2)
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
dim=2,
)
video_condition = video_condition.to(device=device, dtype=dtype)
if isinstance(generator, list):
latent_condition = [
retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
]
latent_condition = torch.cat(latent_condition)
else:
latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
if last_image is None:
mask_lat_size[:, :, list(range(1, num_frames))] = 0
else:
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(latent_condition.device)
latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
return latent_condition
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
block_state.num_channels_latents = components.vae.config.z_dim
block_state.image = components.video_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width
).to(block_state.device, dtype=torch.float32)
if block_state.last_image is not None:
block_state.last_image = components.video_processor.preprocess(
block_state.last_image, height=block_state.height, width=block_state.width
).to(block_state.device, dtype=torch.float32)
block_state.latent_condition = self._encode_vae_image(
components,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.num_frames,
block_state.image,
block_state.device,
block_state.dtype,
block_state.last_image,
block_state.generator,
)
self.set_block_state(state, block_state)
return components, state
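
The conditioning tensor built by `_encode_vae_image` is worth unpacking: the first frame (and optionally the last frame) is zero-padded to the full frame count, VAE-encoded, normalized with the VAE's latent mean/std, and prefixed with a binary frame mask whose temporal axis is folded down by `vae_scale_factor_temporal` so it lines up with the latent frame count. A standalone sketch of just the mask bookkeeping, with illustrative numbers (81 frames, temporal scale 4, 60x104 latent grid are assumptions, not values fixed by this diff):

```python
import torch

batch_size, num_frames = 1, 81
latent_height, latent_width = 60, 104
vae_scale_factor_temporal = 4  # the Wan VAE folds 4 pixel frames into 1 latent frame

# 1 where a real conditioning frame exists (frame 0 only for plain image-to-video), 0 elsewhere.
mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0

# Repeat the first-frame entry so the temporal length becomes divisible by the scale factor,
# then fold the temporal axis into `vae_scale_factor_temporal` mask channels per latent frame.
first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)                        # (1, 1, 84, 60, 104)
mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask = mask.transpose(1, 2)                                             # (1, 4, 21, 60, 104)

print(mask.shape)  # 21 latent frames, matching (81 - 1) // 4 + 1
# Concatenated with the 16-channel encoded video condition along dim=1, this gives the
# 20-channel `latent_condition` consumed later by WanI2VLoopBeforeDenoiser.
```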

View File

@@ -21,13 +21,43 @@ from .before_denoise import (
WanSetTimestepsStep,
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep, WanI2VDenoiseStep
from .encoders import WanImageEncoderStep, WanTextEncoderStep, WanVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanImageEncoderStep]
block_names = ["image_encoder"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Image encoder step that encodes the image inputs into a conditioning embedding.\n"
+ "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+ " - `WanImageEncoderStep` (image_encoder) is used when `image`, and possibly `last_image` is provided."
+ " - if `image` is not provided, this step will be skipped."
)
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [WanVaeEncoderStep]
block_names = ["img2vid"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+ " - `WanVaeEncoderStep` (img2vid) is used when `image`, and possibly `last_image` is provided."
+ " - if `image` is not provided, this step will be skipped."
)
# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
@@ -48,44 +78,72 @@ class WanBeforeDenoiseStep(SequentialPipelineBlocks):
)
# before_denoise: img2vid
class WanI2VBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanImageEncoderStep,
WanVaeEncoderStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "image_encoder", "vae_encoder"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for image-to-video and first-last-frame-to-video tasks.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanImageEncoderStep` is used to encode the image inputs into a conditioning embedding\n"
+ " - `WanVaeEncoderStep` is used to encode the image/last-image inputs into their latent representations\n"
)
# before_denoise: all task (text2vid, img2vid)
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanBeforeDenoiseStep,
WanI2VBeforeDenoiseStep,
]
block_names = ["text2vid", "img2vid"]
block_trigger_inputs = [None, "image"]
@property
def description(self):
return (
"Before denoise step that prepares the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2vid, img2vid, first-last-frame2vid.\n"
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
+ " - `WanI2VBeforeDenoiseStep` (img2vid) is used when `image` is provided.\n"
)
# denoise: text2vid, img2vid
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanDenoiseStep,
WanI2VDenoiseStep,
]
block_names = ["denoise", "denoise_i2v"]
block_trigger_inputs = [None, "image"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents. "
"This is an auto pipeline block that works for text2vid and img2vid tasks."
" - `WanDenoiseStep` (denoise) is used for the text2vid task."
" - `WanI2VDenoiseStep` (denoise_i2v) is used for the img2vid task when `image` is provided.\n"
)
# decode: all task (text2img, img2img, inpainting)
class WanAutoDecodeStep(AutoPipelineBlocks):
block_classes = [WanDecodeStep]
block_names = ["decode"]
block_trigger_inputs = [None]
@property
@@ -116,6 +174,33 @@ class WanAutoBlocks(SequentialPipelineBlocks):
)
# img2vid and first-last-frame2vid
class WanI2VAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoBeforeDenoiseStep,
WanImageEncoderStep,
WanAutoVaeEncoderStep,
WanAutoDenoiseStep,
WanAutoDecodeStep,
]
block_names = [
"text_encoder",
"before_denoise",
"image_encoder",
"vae_encoder",
"denoise",
"decoder",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan.\n"
+ "- for image-to-video and first-last-frame-to-video generation, you need to provide is `image`, and possibly `last_image`"
)
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
@@ -128,9 +213,25 @@ TEXT2VIDEO_BLOCKS = InsertableDict(
)
IMAGE2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("image_encoder", WanImageEncoderStep),
("vae_encoder", WanVaeEncoderStep),
("denoise", WanI2VDenoiseStep),
("decode", WanDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("image_encoder", WanAutoImageEncoderStep),
("vae_encoder", WanAutoVaeEncoderStep),
("before_denoise", WanAutoBeforeDenoiseStep),
("denoise", WanAutoDenoiseStep),
("decode", WanAutoDecodeStep),
@@ -140,5 +241,6 @@ AUTO_BLOCKS = InsertableDict(
ALL_BLOCKS = {
"text2video": TEXT2VIDEO_BLOCKS,
"image2video": IMAGE2VIDEO_BLOCKS,
"auto": AUTO_BLOCKS,
}
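
Taken together, the new presets plug into the modular pipeline flow the same way the existing text2video preset does. A rough usage sketch; the `from_blocks_dict`/`init_pipeline`/`load_default_components`/`to` calls, the repo id, and the `output="videos"` key follow the generic modular-pipelines pattern and are assumptions, not something this diff defines:

```python
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.wan import IMAGE2VIDEO_BLOCKS
from diffusers.utils import load_image

# Assemble the image-to-video preset added in this PR into a single runnable pipeline.
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2VIDEO_BLOCKS)

# Hypothetical repo id; any Wan I2V checkpoint laid out for the modular loader should work.
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")
pipe.load_default_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

first_frame = load_image("first_frame.png")  # placeholder path
result = pipe(
    prompt="a cat walking through a garden",
    image=first_frame,   # `image` is the trigger input that routes to the I2V blocks
    # last_image=...,    # optionally add this for first-last-frame-to-video
    num_frames=81,
    output="videos",
)
```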