Compare commits

...

7 Commits

Author SHA1 Message Date
Aryan
39f688a24e Merge branch 'main' into modular-diffusers-wan-i2v-flf2v 2025-08-05 11:54:42 +02:00
Aryan
2ff808d26c address review comments 2025-08-01 06:49:06 +02:00
Aryan
22f3273a82 update 2025-07-29 03:36:41 +02:00
Aryan
03b03a91f1 update 2025-07-29 00:29:46 +02:00
Aryan
7ab5ecc774 Merge branch 'main' into modular-diffusers-wan-i2v-flf2v 2025-07-29 03:19:01 +05:30
Aryan
7ae44420c3 update 2025-07-28 07:20:46 +02:00
Aryan
183bcd5c79 update 2025-07-27 23:58:27 +02:00
5 changed files with 555 additions and 18 deletions

View File

@@ -25,12 +25,14 @@ else:
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"IMAGE2VIDEO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"WanAutoBlocks",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
"WanAutoVaeEncoderStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
@@ -45,11 +47,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE2VIDEO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
WanAutoVaeEncoderStep,
)
from .modular_pipeline import WanModularPipeline
else:

View File

@@ -282,7 +282,10 @@ class WanPrepareLatentsStep(PipelineBlock):
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
),
OutputParam("height", type_hint=int),
OutputParam("width", type_hint=int),
OutputParam("num_frames", type_hint=int),
]
@staticmethod

View File

@@ -34,6 +34,56 @@ from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanI2VLoopBeforeDenoiser(PipelineBlock):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanI2VDenoiseLoopWrapper`)"
)
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"latent_condition",
required=True,
type_hint=torch.Tensor,
description="The latent condition to use for the denoising process.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latent_model_inputs",
type_hint=torch.Tensor,
description="The concatenated noisy and conditioning latents to use for the denoising process.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
return components, block_state
class WanLoopDenoiser(PipelineBlock):
model_name = "wan"
@@ -102,7 +152,7 @@ class WanLoopDenoiser(PipelineBlock):
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ class WanLoopDenoiser(PipelineBlock):
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.transformer)
# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
return components, block_state
class WanI2VLoopDenoiser(PipelineBlock):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latent_model_inputs",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"image_embeds",
required=True,
type_hint=torch.Tensor,
description="The encoder hidden states for the image inputs.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# Prepare minibatches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
WanLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
@@ -257,5 +412,26 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
"This block supports the text2vid task."
)
class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanI2VLoopBeforeDenoiser,
WanI2VLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanI2VLoopBeforeDenoiser`\n"
" - `WanI2VLoopDenoiser`\n"
" - `WanI2VLoopAfterDenoiser`\n"
"This block supports the image-to-video and first-last-frame-to-video tasks."
)

View File

@@ -17,11 +17,14 @@ from typing import List, Optional, Union
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan
from ...utils import is_ftfy_available, logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class WanTextEncoderStep(PipelineBlock):
model_name = "wan"
@@ -240,3 +257,238 @@ class WanTextEncoderStep(PipelineBlock):
# Add outputs
self.set_block_state(state, block_state)
return components, state
class WanImageEncoderStep(PipelineBlock):
model_name = "wan"
@property
def description(self) -> str:
return "Image Encoder step to compute image embeddings to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModel),
ComponentSpec("image_processor", CLIPImageProcessor),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"image",
required=True,
description="The input image to condition the generation on for first-frame conditioned video generation.",
),
InputParam(
"last_image",
required=False,
description="The last image to condition the generation on for last-frame conditioned video generation.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_embeds",
type_hint=torch.Tensor,
description="image embeddings used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
if not isinstance(block_state.image, PipelineImageInput):
raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
raise ValueError(
f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
)
@staticmethod
def encode_image(
components,
image: Union[PipelineImageInput, List[PipelineImageInput]],
device: torch.device,
):
image = components.image_processor(images=image, return_tensors="pt").to(device)
image_embeds = components.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input images
image = block_state.image
if block_state.last_image is not None:
image = [block_state.image, block_state.last_image]
block_state.image_embeds = self.encode_image(components, image, block_state.device)
# Add outputs
self.set_block_state(state, block_state)
return components, state
class WanVaeEncoderStep(PipelineBlock):
model_name = "wan"
@property
def description(self) -> str:
return (
"VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLWan),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("last_image", required=False),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("num_frames", type_hint=int),
InputParam("batch_size", type_hint=int),
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latent_condition",
type_hint=torch.Tensor,
description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
),
OutputParam("num_channels_latents", type_hint=int),
]
@staticmethod
def _encode_vae_image(
components: WanModularPipeline,
batch_size: int,
height: int,
width: int,
num_frames: int,
image: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
last_image: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
):
latent_height = height // components.vae_scale_factor_spatial
latent_width = width // components.vae_scale_factor_spatial
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(device, dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(device, dtype)
image = image.unsqueeze(2)
if last_image is None:
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
else:
last_image = last_image.unsqueeze(2)
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
dim=2,
)
video_condition = video_condition.to(device=device, dtype=dtype)
if isinstance(generator, list):
latent_condition = [
retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
]
latent_condition = torch.cat(latent_condition)
else:
latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
if last_image is None:
mask_lat_size[:, :, list(range(1, num_frames))] = 0
else:
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(latent_condition.device)
latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
return latent_condition
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
block_state.num_channels_latents = components.vae.config.z_dim
block_state.image = components.video_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width
).to(block_state.device, dtype=torch.float32)
if block_state.last_image is not None:
block_state.last_image = components.video_processor.preprocess(
block_state.last_image, height=block_state.height, width=block_state.width
).to(block_state.device, dtype=torch.float32)
block_state.latent_condition = self._encode_vae_image(
components,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.num_frames,
block_state.image,
block_state.device,
block_state.dtype,
block_state.last_image,
block_state.generator,
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -21,13 +21,43 @@ from .before_denoise import (
WanSetTimestepsStep,
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep
from .encoders import WanTextEncoderStep
from .denoise import WanDenoiseStep, WanI2VDenoiseStep
from .encoders import WanImageEncoderStep, WanTextEncoderStep, WanVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanImageEncoderStep]
block_names = ["image_encoder"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Image encoder step that encodes the image inputs into a conditioning embedding.\n"
+ "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+ " - `WanImageEncoderStep` (image_encoder) is used when `image`, and possibly `last_image` is provided."
+ " - if `image` is not provided, this step will be skipped."
)
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [WanVaeEncoderStep]
block_names = ["img2vid"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+ " - `WanVaeEncoderStep` (img2vid) is used when `image`, and possibly `last_image` is provided."
+ " - if `image` is not provided, this step will be skipped."
)
# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
@@ -48,44 +78,72 @@ class WanBeforeDenoiseStep(SequentialPipelineBlocks):
)
# before_denoise: all task (text2vid,)
# before_denoise: img2vid
class WanI2VBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanImageEncoderStep,
WanVaeEncoderStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "image_encoder", "vae_encoder"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for image-to-video and first-last-frame-to-video tasks.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanImageEncoderStep` is used to encode the image inputs into a conditioning embedding\n"
+ " - `WanVaeEncoderStep` is used to encode the image/last-image inputs into their latent representations\n"
)
# before_denoise: all task (text2vid, img2vid)
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanBeforeDenoiseStep,
WanI2VBeforeDenoiseStep,
]
block_names = ["text2vid"]
block_trigger_inputs = [None]
block_names = ["text2vid", "img2vid"]
block_trigger_inputs = [None, "image"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2vid.\n"
+ "This is an auto pipeline block that works for text2vid, img2vid, first-last-frame2vid.\n"
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
+ " - `WanI2VBeforeDenoiseStep` (img2vid) is used when `image` is provided.\n"
)
# denoise: text2vid
# denoise: text2vid, img2vid
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanDenoiseStep,
WanI2VDenoiseStep,
]
block_names = ["denoise"]
block_trigger_inputs = [None]
block_names = ["denoise", "denoise_i2v"]
block_trigger_inputs = [None, "image"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2vid tasks.."
" - `WanDenoiseStep` (denoise) for text2vid tasks."
"This is a auto pipeline block that works for text2vid and img2vid tasks..."
" - `WanDenoiseStep` (denoise) for text2vid task."
" - `WanI2VDenoiseStep` (denoise_i2v) for img2vid task, which is used when `image` is provided.\n"
)
# decode: all task (text2img, img2img, inpainting)
class WanAutoDecodeStep(AutoPipelineBlocks):
block_classes = [WanDecodeStep]
block_names = ["non-inpaint"]
block_names = ["decode"]
block_trigger_inputs = [None]
@property
@@ -116,6 +174,33 @@ class WanAutoBlocks(SequentialPipelineBlocks):
)
# img2vid and first-last-frame2vid
class WanI2VAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoBeforeDenoiseStep,
WanImageEncoderStep,
WanAutoVaeEncoderStep,
WanAutoDenoiseStep,
WanAutoDecodeStep,
]
block_names = [
"text_encoder",
"before_denoise",
"image_encoder",
"vae_encoder",
"denoise",
"decoder",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan.\n"
+ "- for image-to-video and first-last-frame-to-video generation, you need to provide is `image`, and possibly `last_image`"
)
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
@@ -128,9 +213,25 @@ TEXT2VIDEO_BLOCKS = InsertableDict(
)
IMAGE2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("image_encoder", WanImageEncoderStep),
("vae_encoder", WanVaeEncoderStep),
("denoise", WanI2VDenoiseStep),
("decode", WanDecodeStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("image_encoder", WanAutoImageEncoderStep),
("vae_encoder", WanAutoVaeEncoderStep),
("before_denoise", WanAutoBeforeDenoiseStep),
("denoise", WanAutoDenoiseStep),
("decode", WanAutoDecodeStep),
@@ -140,5 +241,6 @@ AUTO_BLOCKS = InsertableDict(
ALL_BLOCKS = {
"text2video": TEXT2VIDEO_BLOCKS,
"image2video": IMAGE2VIDEO_BLOCKS,
"auto": AUTO_BLOCKS,
}