Compare commits

...

1 Commits

Author SHA1 Message Date
Patrick von Platen
993907f561 [Torch Compile] Fix torch compile for svd vae 2023-12-18 14:36:53 +00:00

View File

@@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging from ...utils import BaseOutput, logging
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
@@ -211,7 +211,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
latents = 1 / self.vae.config.scaling_factor * latents latents = 1 / self.vae.config.scaling_factor * latents
accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
# decode decode_chunk_size frames at a time to avoid OOM # decode decode_chunk_size frames at a time to avoid OOM
frames = [] frames = []