Compare commits

...

6 Commits

Author      SHA1        Message                            Date
DN6         cb69798b3d  update                             2025-10-27 18:11:28 +05:30
DN6         0229976ab5  update                             2025-10-23 16:08:35 +05:30
Dhruv Nair  8f1b207ffd  Merge branch 'main' into vace-fix  2025-10-23 15:11:28 +05:30
DN6         99308efb55  update                             2025-10-03 16:48:43 +05:30
DN6         5015ce4fc7  update                             2025-10-03 16:44:23 +05:30
DN6         5ed984cc47  update                             2025-10-03 14:42:58 +05:30
2 changed files with 137 additions and 25 deletions

View File

@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanVACETransformer3DModel`]):
-            Conditional Transformer to denoise the input latents.
-        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-            `transformer` is used.
-        scheduler ([`UniPCMultistepScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        scheduler ([`UniPCMultistepScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        transformer ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
         boundary_ratio (`float`, *optional*, defaults to `None`):
             Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
             The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
             `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+            boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
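
The reordered docstring above states the dispatch rule: `transformer` covers timesteps at or above `boundary_ratio * num_train_timesteps`, `transformer_2` covers the rest, and either model may now be absent. A minimal sketch of that selection logic as a standalone helper (the function name and signature are illustrative, not part of the pipeline):

    # Illustrative helper, not the pipeline's actual implementation.
    from typing import Optional


    def select_transformer(t: float, boundary_ratio: Optional[float], num_train_timesteps: int,
                           transformer, transformer_2):
        """Pick which denoising model handles timestep `t` under two-stage denoising."""
        if boundary_ratio is None:
            # Single-stage: use whichever model was provided.
            return transformer if transformer is not None else transformer_2
        boundary_timestep = boundary_ratio * num_train_timesteps
        if t >= boundary_timestep and transformer is not None:
            # High-noise stage.
            return transformer
        # Low-noise stage, or no high-noise model available.
        return transformer_2 if transformer_2 is not None else transformer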
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
+
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])

             if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )

-        transformer_patch_size = self.transformer.config.patch_size[1]
+        transformer_patch_size = (
+            self.transformer.config.patch_size[1]
+            if self.transformer is not None
+            else self.transformer_2.config.patch_size[1]
+        )

         mask_list = []
         for mask_, reference_images_batch in zip(mask, reference_images):
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             batch_size = prompt_embeds.shape[0]

         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+        vace_layers = (
+            self.transformer.config.vace_layers
+            if self.transformer is not None
+            else self.transformer_2.config.vace_layers
+        )

         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)

         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)

         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
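
The block above accepts a scalar, list, or tensor `conditioning_scale` and normalizes it to one value per VACE layer, reading the layer count from whichever transformer is present. A compact sketch of that normalization, independent of the pipeline (the helper name is illustrative):

    # Illustrative normalization helper, mirroring the validation above.
    import torch


    def normalize_conditioning_scale(conditioning_scale, num_vace_layers: int) -> torch.Tensor:
        if isinstance(conditioning_scale, (int, float)):
            conditioning_scale = [conditioning_scale] * num_vace_layers
        if isinstance(conditioning_scale, list):
            conditioning_scale = torch.tensor(conditioning_scale)
        if conditioning_scale.size(0) != num_vace_layers:
            raise ValueError(
                f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {num_vace_layers}."
            )
        return conditioning_scale


    print(normalize_conditioning_scale(1.0, num_vace_layers=2))  # tensor([1., 1.])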
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)

-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
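
The switch from `guidance_scale` to `current_guidance_scale` ties classifier-free guidance to the active stage: presumably `guidance_scale` applies while `transformer` is active and `guidance_scale_2` (now accepted by `check_inputs` above) while `transformer_2` is active. A hedged sketch of that per-stage guidance step (the stage-to-scale mapping and the fallback are assumptions, and the names are illustrative):

    # Sketch only: per-stage classifier-free guidance, assuming `guidance_scale_2`
    # falls back to `guidance_scale` when it is not set.
    from typing import Optional

    import torch


    def guided_noise(noise_pred: torch.Tensor, noise_uncond: torch.Tensor,
                     in_high_noise_stage: bool, guidance_scale: float,
                     guidance_scale_2: Optional[float] = None) -> torch.Tensor:
        current_guidance_scale = (
            guidance_scale if in_high_noise_stage or guidance_scale_2 is None else guidance_scale_2
        )
        return noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)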

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import tempfile
 import unittest

 import numpy as np
@@ -19,9 +20,15 @@ import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)

-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     def test_save_load_float16(self):
         pass
+
+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+
+        # FlowMatchEulerDiscreteScheduler doesn't support a low-noise-only schedule
+        # because the starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = ["transformer"]
+
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        # FlowMatchEulerDiscreteScheduler doesn't support a low-noise-only schedule
+        # because the starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        for component in optional_component:
+            components[component] = None
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+        for component in pipe_loaded.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe_loaded.to(torch_device)
+        pipe_loaded.set_progress_bar_config(disable=None)
+
+        for component in optional_component:
+            assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"