Compare commits


1 commit

Author      SHA1        Message  Date
Dhruv Nair  5c7a251862  update   2024-12-18 09:44:36 +01:00


```diff
@@ -188,6 +188,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
+        force_zeros_for_empty_prompt: bool = False,
     ):
         super().__init__()

```
```diff
@@ -205,10 +206,11 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):

         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
         )
         self.default_height = 480
         self.default_width = 848
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

     def _get_t5_prompt_embeds(
         self,
```
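The new flag defaults to `False` and is registered to the pipeline config, so it can be flipped at load time. The fallback `tokenizer_max_length` also moves from 77 (the CLIP-era default) to 256, matching the `max_sequence_length` Mochi uses with its T5 text encoder. Below is a minimal sketch of enabling the flag, assuming `from_pretrained` forwards keyword arguments that match the `__init__` signature (standard `DiffusionPipeline` behavior); the checkpoint name is an assumption for illustration:

```python
# Minimal sketch (not part of this commit): overriding the new config flag at
# load time. The checkpoint name is an assumption for illustration.
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview",
    force_zeros_for_empty_prompt=True,  # opt back into the original Mochi zeroing behavior
    torch_dtype=torch.bfloat16,
)
print(pipe.config.force_zeros_for_empty_prompt)  # True
```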
```diff
@@ -236,7 +238,11 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
         text_input_ids = text_inputs.input_ids
         prompt_attention_mask = text_inputs.attention_mask
         prompt_attention_mask = prompt_attention_mask.bool().to(device)
-        if prompt == "" or prompt[-1] == "":
+
+        # The original Mochi implementation zeros out empty negative prompts
+        # but this can lead to overflow when placing the entire pipeline under the autocast context
+        # adding this here so that we can enable zeroing prompts if necessary
+        if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
             text_input_ids = torch.zeros_like(text_input_ids, device=device)
             prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)

```
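For context, here is a sketch of the scenario the `False` default guards against: running the whole pipeline under autocast with an empty negative prompt. With zeroing disabled, the empty string is encoded by the T5 encoder like any other prompt instead of being replaced with zero embeddings. The prompts and settings below are placeholders, not a verified reproduction.

```python
# Sketch, not a verified repro: bf16 autocast over the full pipeline with an
# empty negative prompt. With the default force_zeros_for_empty_prompt=False,
# the branch above is skipped and the empty prompt is encoded normally.
import torch
from diffusers import MochiPipeline

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.to("cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    frames = pipe(
        prompt="A close-up of a cat walking through tall grass",
        negative_prompt="",  # empty: no longer zeroed unless the flag is set
        num_inference_steps=64,
    ).frames[0]
```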