Compare commits

...

4 Commits

Author SHA1 Message Date
Sayak Paul
a46f06a7ef Merge branch 'main' into tru-cfg-hunyuanvideo 2025-01-24 16:00:23 +05:30
Sayak Paul
74f68d9b3a Merge branch 'main' into tru-cfg-hunyuanvideo 2025-01-22 18:16:04 +05:30
Sayak Paul
4f31393d77 Merge branch 'main' into tru-cfg-hunyuanvideo 2025-01-16 17:56:22 +05:30
sayakpaul
ba988355f7 feat: enable true cfg in hunyuanvideo. 2025-01-14 10:44:20 +05:30

View File

@@ -466,12 +466,15 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Union[str, List[str]] = None,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
num_inference_steps: int = 50,
sigmas: List[float] = None,
guidance_scale: float = 6.0,
true_cfg_scale: float = 1.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -590,6 +593,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
do_true_cfg = true_cfg_scale > 1.0 and negative_prompt is not None
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
@@ -601,12 +605,29 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
device=device,
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_template=prompt_template,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=None,
pooled_prompt_embeds=None,
prompt_attention_mask=None,
device=device,
max_sequence_length=max_sequence_length,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
if pooled_prompt_embeds is not None:
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
if do_true_cfg:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
if negative_pooled_prompt_embeds is not None:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
@@ -658,6 +679,18 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]