mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-21 03:44:49 +08:00
Compare commits
7 Commits
cp-fixes-a
...
modular-di
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39f688a24e | ||
|
|
2ff808d26c | ||
|
|
22f3273a82 | ||
|
|
03b03a91f1 | ||
|
|
7ab5ecc774 | ||
|
|
7ae44420c3 | ||
|
|
183bcd5c79 |
@@ -25,12 +25,14 @@ else:
|
|||||||
_import_structure["modular_blocks"] = [
|
_import_structure["modular_blocks"] = [
|
||||||
"ALL_BLOCKS",
|
"ALL_BLOCKS",
|
||||||
"AUTO_BLOCKS",
|
"AUTO_BLOCKS",
|
||||||
|
"IMAGE2VIDEO_BLOCKS",
|
||||||
"TEXT2VIDEO_BLOCKS",
|
"TEXT2VIDEO_BLOCKS",
|
||||||
"WanAutoBeforeDenoiseStep",
|
"WanAutoBeforeDenoiseStep",
|
||||||
"WanAutoBlocks",
|
"WanAutoBlocks",
|
||||||
"WanAutoBlocks",
|
"WanAutoBlocks",
|
||||||
"WanAutoDecodeStep",
|
"WanAutoDecodeStep",
|
||||||
"WanAutoDenoiseStep",
|
"WanAutoDenoiseStep",
|
||||||
|
"WanAutoVaeEncoderStep",
|
||||||
]
|
]
|
||||||
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
||||||
|
|
||||||
@@ -45,11 +47,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .modular_blocks import (
|
from .modular_blocks import (
|
||||||
ALL_BLOCKS,
|
ALL_BLOCKS,
|
||||||
AUTO_BLOCKS,
|
AUTO_BLOCKS,
|
||||||
|
IMAGE2VIDEO_BLOCKS,
|
||||||
TEXT2VIDEO_BLOCKS,
|
TEXT2VIDEO_BLOCKS,
|
||||||
WanAutoBeforeDenoiseStep,
|
WanAutoBeforeDenoiseStep,
|
||||||
WanAutoBlocks,
|
WanAutoBlocks,
|
||||||
WanAutoDecodeStep,
|
WanAutoDecodeStep,
|
||||||
WanAutoDenoiseStep,
|
WanAutoDenoiseStep,
|
||||||
|
WanAutoVaeEncoderStep,
|
||||||
)
|
)
|
||||||
from .modular_pipeline import WanModularPipeline
|
from .modular_pipeline import WanModularPipeline
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -282,7 +282,10 @@ class WanPrepareLatentsStep(PipelineBlock):
|
|||||||
return [
|
return [
|
||||||
OutputParam(
|
OutputParam(
|
||||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||||
)
|
),
|
||||||
|
OutputParam("height", type_hint=int),
|
||||||
|
OutputParam("width", type_hint=int),
|
||||||
|
OutputParam("num_frames", type_hint=int),
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -34,6 +34,56 @@ from .modular_pipeline import WanModularPipeline
|
|||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class WanI2VLoopBeforeDenoiser(PipelineBlock):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", UniPCMultistepScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Step within the denoising loop that prepares the latent input for the denoiser. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `WanI2VDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The initial latents to use for the denoising process.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"latent_condition",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The latent condition to use for the denoising process.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"latent_model_inputs",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The concatenated noisy and conditioning latents to use for the denoising process.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
|
||||||
|
block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
class WanLoopDenoiser(PipelineBlock):
|
class WanLoopDenoiser(PipelineBlock):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@@ -102,7 +152,7 @@ class WanLoopDenoiser(PipelineBlock):
|
|||||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||||
|
|
||||||
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||||
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
|
# Each guider_state_batch will have .prompt_embeds.
|
||||||
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||||
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||||
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||||
@@ -120,7 +170,112 @@ class WanLoopDenoiser(PipelineBlock):
|
|||||||
guider_state_batch.noise_pred = components.transformer(
|
guider_state_batch.noise_pred = components.transformer(
|
||||||
hidden_states=block_state.latents.to(transformer_dtype),
|
hidden_states=block_state.latents.to(transformer_dtype),
|
||||||
timestep=t.flatten(),
|
timestep=t.flatten(),
|
||||||
encoder_hidden_states=prompt_embeds,
|
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
|
||||||
|
attention_kwargs=block_state.attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
components.guider.cleanup_models(components.transformer)
|
||||||
|
|
||||||
|
# Perform guidance
|
||||||
|
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class WanI2VLoopDenoiser(PipelineBlock):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 5.0}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
ComponentSpec("transformer", WanTransformer3DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Step within the denoising loop that denoise the latents with guidance. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("attention_kwargs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"latent_model_inputs",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The initial latents to use for the denoising process.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"image_embeds",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The encoder hidden states for the image inputs.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of inference steps to use for the denoising process.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
kwargs_type="guider_input_fields",
|
||||||
|
description=(
|
||||||
|
"All conditional model inputs that need to be prepared with guider. "
|
||||||
|
"It should contain prompt_embeds/negative_prompt_embeds. "
|
||||||
|
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||||
|
) -> PipelineState:
|
||||||
|
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
|
||||||
|
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
|
||||||
|
guider_input_fields = {
|
||||||
|
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
|
||||||
|
}
|
||||||
|
transformer_dtype = components.transformer.dtype
|
||||||
|
|
||||||
|
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||||
|
|
||||||
|
# Prepare mini‐batches according to guidance method and `guider_input_fields`
|
||||||
|
# Each guider_state_batch will have .prompt_embeds.
|
||||||
|
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
|
||||||
|
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
|
||||||
|
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
|
||||||
|
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
|
||||||
|
|
||||||
|
# run the denoiser for each guidance batch
|
||||||
|
for guider_state_batch in guider_state:
|
||||||
|
components.guider.prepare_models(components.transformer)
|
||||||
|
cond_kwargs = guider_state_batch.as_dict()
|
||||||
|
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
|
||||||
|
prompt_embeds = cond_kwargs.pop("prompt_embeds")
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||||
|
guider_state_batch.noise_pred = components.transformer(
|
||||||
|
hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
|
||||||
|
timestep=t.flatten(),
|
||||||
|
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
|
||||||
|
encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
|
||||||
attention_kwargs=block_state.attention_kwargs,
|
attention_kwargs=block_state.attention_kwargs,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
|
|||||||
WanLoopDenoiser,
|
WanLoopDenoiser,
|
||||||
WanLoopAfterDenoiser,
|
WanLoopAfterDenoiser,
|
||||||
]
|
]
|
||||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
block_names = ["denoiser", "after_denoiser"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
@@ -257,5 +412,26 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
|
|||||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||||
" - `WanLoopDenoiser`\n"
|
" - `WanLoopDenoiser`\n"
|
||||||
" - `WanLoopAfterDenoiser`\n"
|
" - `WanLoopAfterDenoiser`\n"
|
||||||
"This block supports both text2vid tasks."
|
"This block supports the text2vid task."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
|
||||||
|
block_classes = [
|
||||||
|
WanI2VLoopBeforeDenoiser,
|
||||||
|
WanI2VLoopDenoiser,
|
||||||
|
WanLoopAfterDenoiser,
|
||||||
|
]
|
||||||
|
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
|
||||||
|
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||||
|
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||||
|
" - `WanI2VLoopBeforeDenoiser`\n"
|
||||||
|
" - `WanI2VLoopDenoiser`\n"
|
||||||
|
" - `WanI2VLoopAfterDenoiser`\n"
|
||||||
|
"This block supports the image-to-video and first-last-frame-to-video tasks."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,11 +17,14 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
||||||
|
|
||||||
from ...configuration_utils import FrozenDict
|
from ...configuration_utils import FrozenDict
|
||||||
from ...guiders import ClassifierFreeGuidance
|
from ...guiders import ClassifierFreeGuidance
|
||||||
|
from ...image_processor import PipelineImageInput
|
||||||
|
from ...models import AutoencoderKLWan
|
||||||
from ...utils import is_ftfy_available, logging
|
from ...utils import is_ftfy_available, logging
|
||||||
|
from ...video_processor import VideoProcessor
|
||||||
from ..modular_pipeline import PipelineBlock, PipelineState
|
from ..modular_pipeline import PipelineBlock, PipelineState
|
||||||
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
||||||
from .modular_pipeline import WanModularPipeline
|
from .modular_pipeline import WanModularPipeline
|
||||||
@@ -51,6 +54,20 @@ def prompt_clean(text):
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
class WanTextEncoderStep(PipelineBlock):
|
class WanTextEncoderStep(PipelineBlock):
|
||||||
model_name = "wan"
|
model_name = "wan"
|
||||||
|
|
||||||
@@ -240,3 +257,238 @@ class WanTextEncoderStep(PipelineBlock):
|
|||||||
# Add outputs
|
# Add outputs
|
||||||
self.set_block_state(state, block_state)
|
self.set_block_state(state, block_state)
|
||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class WanImageEncoderStep(PipelineBlock):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Image Encoder step to compute image embeddings to guide the video generation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("image_encoder", CLIPVisionModel),
|
||||||
|
ComponentSpec("image_processor", CLIPImageProcessor),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_configs(self) -> List[ConfigSpec]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"image",
|
||||||
|
required=True,
|
||||||
|
description="The input image to condition the generation on for first-frame conditioned video generation.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"last_image",
|
||||||
|
required=False,
|
||||||
|
description="The last image to condition the generation on for last-frame conditioned video generation.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"image_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="image embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(block_state):
|
||||||
|
if not isinstance(block_state.image, PipelineImageInput):
|
||||||
|
raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
|
||||||
|
if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
|
||||||
|
raise ValueError(
|
||||||
|
f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode_image(
|
||||||
|
components,
|
||||||
|
image: Union[PipelineImageInput, List[PipelineImageInput]],
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
image = components.image_processor(images=image, return_tensors="pt").to(device)
|
||||||
|
image_embeds = components.image_encoder(**image, output_hidden_states=True)
|
||||||
|
return image_embeds.hidden_states[-2]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
# Get inputs and intermediates
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(block_state)
|
||||||
|
|
||||||
|
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
# Encode input images
|
||||||
|
image = block_state.image
|
||||||
|
if block_state.last_image is not None:
|
||||||
|
image = [block_state.image, block_state.last_image]
|
||||||
|
|
||||||
|
block_state.image_embeds = self.encode_image(components, image, block_state.device)
|
||||||
|
|
||||||
|
# Add outputs
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class WanVaeEncoderStep(PipelineBlock):
|
||||||
|
model_name = "wan"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("vae", AutoencoderKLWan),
|
||||||
|
ComponentSpec(
|
||||||
|
"video_processor",
|
||||||
|
VideoProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 8}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("image", required=True),
|
||||||
|
InputParam("last_image", required=False),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("height", type_hint=int),
|
||||||
|
InputParam("width", type_hint=int),
|
||||||
|
InputParam("num_frames", type_hint=int),
|
||||||
|
InputParam("batch_size", type_hint=int),
|
||||||
|
InputParam("generator"),
|
||||||
|
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"latent_condition",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
|
||||||
|
),
|
||||||
|
OutputParam("num_channels_latents", type_hint=int),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encode_vae_image(
|
||||||
|
components: WanModularPipeline,
|
||||||
|
batch_size: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
num_frames: int,
|
||||||
|
image: torch.Tensor,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
last_image: Optional[torch.Tensor] = None,
|
||||||
|
generator: Optional[torch.Generator] = None,
|
||||||
|
):
|
||||||
|
latent_height = height // components.vae_scale_factor_spatial
|
||||||
|
latent_width = width // components.vae_scale_factor_spatial
|
||||||
|
|
||||||
|
latents_mean = (
|
||||||
|
torch.tensor(components.vae.config.latents_mean)
|
||||||
|
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||||
|
.to(device, dtype)
|
||||||
|
)
|
||||||
|
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||||
|
1, components.vae.config.z_dim, 1, 1, 1
|
||||||
|
).to(device, dtype)
|
||||||
|
|
||||||
|
image = image.unsqueeze(2)
|
||||||
|
if last_image is None:
|
||||||
|
video_condition = torch.cat(
|
||||||
|
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
last_image = last_image.unsqueeze(2)
|
||||||
|
video_condition = torch.cat(
|
||||||
|
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
video_condition = video_condition.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if isinstance(generator, list):
|
||||||
|
latent_condition = [
|
||||||
|
retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
||||||
|
]
|
||||||
|
latent_condition = torch.cat(latent_condition)
|
||||||
|
else:
|
||||||
|
latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
|
||||||
|
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||||
|
|
||||||
|
latent_condition = latent_condition.to(dtype)
|
||||||
|
latent_condition = (latent_condition - latents_mean) * latents_std
|
||||||
|
|
||||||
|
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||||
|
if last_image is None:
|
||||||
|
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||||
|
else:
|
||||||
|
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
|
||||||
|
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||||
|
first_frame_mask = torch.repeat_interleave(
|
||||||
|
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||||
|
)
|
||||||
|
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||||
|
mask_lat_size = mask_lat_size.view(
|
||||||
|
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||||
|
)
|
||||||
|
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||||
|
mask_lat_size = mask_lat_size.to(latent_condition.device)
|
||||||
|
latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
|
||||||
|
|
||||||
|
return latent_condition
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
block_state.num_channels_latents = components.vae.config.z_dim
|
||||||
|
|
||||||
|
block_state.image = components.video_processor.preprocess(
|
||||||
|
block_state.image, height=block_state.height, width=block_state.width
|
||||||
|
).to(block_state.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
if block_state.last_image is not None:
|
||||||
|
block_state.last_image = components.video_processor.preprocess(
|
||||||
|
block_state.last_image, height=block_state.height, width=block_state.width
|
||||||
|
).to(block_state.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
block_state.latent_condition = self._encode_vae_image(
|
||||||
|
components,
|
||||||
|
block_state.batch_size,
|
||||||
|
block_state.height,
|
||||||
|
block_state.width,
|
||||||
|
block_state.num_frames,
|
||||||
|
block_state.image,
|
||||||
|
block_state.device,
|
||||||
|
block_state.dtype,
|
||||||
|
block_state.last_image,
|
||||||
|
block_state.generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|||||||
@@ -21,13 +21,43 @@ from .before_denoise import (
|
|||||||
WanSetTimestepsStep,
|
WanSetTimestepsStep,
|
||||||
)
|
)
|
||||||
from .decoders import WanDecodeStep
|
from .decoders import WanDecodeStep
|
||||||
from .denoise import WanDenoiseStep
|
from .denoise import WanDenoiseStep, WanI2VDenoiseStep
|
||||||
from .encoders import WanTextEncoderStep
|
from .encoders import WanImageEncoderStep, WanTextEncoderStep, WanVaeEncoderStep
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [WanImageEncoderStep]
|
||||||
|
block_names = ["image_encoder"]
|
||||||
|
block_trigger_inputs = ["image"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Image encoder step that encodes the image inputs into a conditioning embedding.\n"
|
||||||
|
+ "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
|
||||||
|
+ " - `WanImageEncoderStep` (image_encoder) is used when `image`, and possibly `last_image` is provided."
|
||||||
|
+ " - if `image` is not provided, this step will be skipped."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [WanVaeEncoderStep]
|
||||||
|
block_names = ["img2vid"]
|
||||||
|
block_trigger_inputs = ["image"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||||
|
+ "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
|
||||||
|
+ " - `WanVaeEncoderStep` (img2vid) is used when `image`, and possibly `last_image` is provided."
|
||||||
|
+ " - if `image` is not provided, this step will be skipped."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# before_denoise: text2vid
|
# before_denoise: text2vid
|
||||||
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
|
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||||
block_classes = [
|
block_classes = [
|
||||||
@@ -48,44 +78,72 @@ class WanBeforeDenoiseStep(SequentialPipelineBlocks):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# before_denoise: all task (text2vid,)
|
# before_denoise: img2vid
|
||||||
|
class WanI2VBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||||
|
block_classes = [
|
||||||
|
WanInputStep,
|
||||||
|
WanSetTimestepsStep,
|
||||||
|
WanPrepareLatentsStep,
|
||||||
|
WanImageEncoderStep,
|
||||||
|
WanVaeEncoderStep,
|
||||||
|
]
|
||||||
|
block_names = ["input", "set_timesteps", "prepare_latents", "image_encoder", "vae_encoder"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Before denoise step that prepare the inputs for the denoise step for image-to-video and first-last-frame-to-video tasks.\n"
|
||||||
|
+ "This is a sequential pipeline blocks:\n"
|
||||||
|
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
|
||||||
|
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||||
|
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||||
|
+ " - `WanImageEncoderStep` is used to encode the image inputs into a conditioning embedding\n"
|
||||||
|
+ " - `WanVaeEncoderStep` is used to encode the image/last-image inputs into their latent representations\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# before_denoise: all task (text2vid, img2vid)
|
||||||
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||||
block_classes = [
|
block_classes = [
|
||||||
WanBeforeDenoiseStep,
|
WanBeforeDenoiseStep,
|
||||||
|
WanI2VBeforeDenoiseStep,
|
||||||
]
|
]
|
||||||
block_names = ["text2vid"]
|
block_names = ["text2vid", "img2vid"]
|
||||||
block_trigger_inputs = [None]
|
block_trigger_inputs = [None, "image"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self):
|
def description(self):
|
||||||
return (
|
return (
|
||||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||||
+ "This is an auto pipeline block that works for text2vid.\n"
|
+ "This is an auto pipeline block that works for text2vid, img2vid, first-last-frame2vid.\n"
|
||||||
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
|
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
|
||||||
|
+ " - `WanI2VBeforeDenoiseStep` (img2vid) is used when `image` is provided.\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# denoise: text2vid
|
# denoise: text2vid, img2vid
|
||||||
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||||
block_classes = [
|
block_classes = [
|
||||||
WanDenoiseStep,
|
WanDenoiseStep,
|
||||||
|
WanI2VDenoiseStep,
|
||||||
]
|
]
|
||||||
block_names = ["denoise"]
|
block_names = ["denoise", "denoise_i2v"]
|
||||||
block_trigger_inputs = [None]
|
block_trigger_inputs = [None, "image"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Denoise step that iteratively denoise the latents. "
|
"Denoise step that iteratively denoise the latents. "
|
||||||
"This is a auto pipeline block that works for text2vid tasks.."
|
"This is a auto pipeline block that works for text2vid and img2vid tasks..."
|
||||||
" - `WanDenoiseStep` (denoise) for text2vid tasks."
|
" - `WanDenoiseStep` (denoise) for text2vid task."
|
||||||
|
" - `WanI2VDenoiseStep` (denoise_i2v) for img2vid task, which is used when `image` is provided.\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# decode: all task (text2img, img2img, inpainting)
|
# decode: all task (text2img, img2img, inpainting)
|
||||||
class WanAutoDecodeStep(AutoPipelineBlocks):
|
class WanAutoDecodeStep(AutoPipelineBlocks):
|
||||||
block_classes = [WanDecodeStep]
|
block_classes = [WanDecodeStep]
|
||||||
block_names = ["non-inpaint"]
|
block_names = ["decode"]
|
||||||
block_trigger_inputs = [None]
|
block_trigger_inputs = [None]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -116,6 +174,33 @@ class WanAutoBlocks(SequentialPipelineBlocks):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# img2vid and first-last-frame2vid
|
||||||
|
class WanI2VAutoBlocks(SequentialPipelineBlocks):
|
||||||
|
block_classes = [
|
||||||
|
WanTextEncoderStep,
|
||||||
|
WanAutoBeforeDenoiseStep,
|
||||||
|
WanImageEncoderStep,
|
||||||
|
WanAutoVaeEncoderStep,
|
||||||
|
WanAutoDenoiseStep,
|
||||||
|
WanAutoDecodeStep,
|
||||||
|
]
|
||||||
|
block_names = [
|
||||||
|
"text_encoder",
|
||||||
|
"before_denoise",
|
||||||
|
"image_encoder",
|
||||||
|
"vae_encoder",
|
||||||
|
"denoise",
|
||||||
|
"decoder",
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Auto Modular pipeline for text-to-video using Wan.\n"
|
||||||
|
+ "- for image-to-video and first-last-frame-to-video generation, you need to provide is `image`, and possibly `last_image`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TEXT2VIDEO_BLOCKS = InsertableDict(
|
TEXT2VIDEO_BLOCKS = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", WanTextEncoderStep),
|
("text_encoder", WanTextEncoderStep),
|
||||||
@@ -128,9 +213,25 @@ TEXT2VIDEO_BLOCKS = InsertableDict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", WanTextEncoderStep),
|
||||||
|
("input", WanInputStep),
|
||||||
|
("set_timesteps", WanSetTimestepsStep),
|
||||||
|
("prepare_latents", WanPrepareLatentsStep),
|
||||||
|
("image_encoder", WanImageEncoderStep),
|
||||||
|
("vae_encoder", WanVaeEncoderStep),
|
||||||
|
("denoise", WanI2VDenoiseStep),
|
||||||
|
("decode", WanDecodeStep),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
AUTO_BLOCKS = InsertableDict(
|
AUTO_BLOCKS = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", WanTextEncoderStep),
|
("text_encoder", WanTextEncoderStep),
|
||||||
|
("image_encoder", WanAutoImageEncoderStep),
|
||||||
|
("vae_encoder", WanAutoVaeEncoderStep),
|
||||||
("before_denoise", WanAutoBeforeDenoiseStep),
|
("before_denoise", WanAutoBeforeDenoiseStep),
|
||||||
("denoise", WanAutoDenoiseStep),
|
("denoise", WanAutoDenoiseStep),
|
||||||
("decode", WanAutoDecodeStep),
|
("decode", WanAutoDecodeStep),
|
||||||
@@ -140,5 +241,6 @@ AUTO_BLOCKS = InsertableDict(
|
|||||||
|
|
||||||
ALL_BLOCKS = {
|
ALL_BLOCKS = {
|
||||||
"text2video": TEXT2VIDEO_BLOCKS,
|
"text2video": TEXT2VIDEO_BLOCKS,
|
||||||
|
"image2video": IMAGE2VIDEO_BLOCKS,
|
||||||
"auto": AUTO_BLOCKS,
|
"auto": AUTO_BLOCKS,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user