Compare commits

...

1 Commit

Author SHA1 Message Date
Aryan
8d88935f4a fix pipelines 2024-09-19 09:07:52 +02:00
3 changed files with 17 additions and 28 deletions

View File

@@ -577,19 +577,17 @@ class CogVideoXPipeline(DiffusionPipeline):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1 num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
prompt, prompt=prompt,
height, height=height,
width, width=width,
negative_prompt, negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._interrupt = False self._interrupt = False

View File

@@ -438,8 +438,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
width, width,
negative_prompt, negative_prompt,
callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs,
video=None,
latents=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
): ):
@@ -494,9 +492,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
def fuse_qkv_projections(self) -> None: def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections.""" r"""Enables fused QKV projections."""
@@ -657,28 +652,26 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline):
`tuple`. When returning a tuple, the first element is a list with the generated images. `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
if num_frames > 49: if num_frames != 49:
raise ValueError( raise ValueError(
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." "The number of frames must be 49 for now due to static learned positional embeddings. This will be updated in the future to remove this limitation."
) )
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1 num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs( self.check_inputs(
image, image=image,
prompt, prompt=prompt,
height, height=height,
width, width=width,
negative_prompt, negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._interrupt = False self._interrupt = False

View File

@@ -651,8 +651,6 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1 num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct