Compare commits

...

1 Commits

Author SHA1 Message Date
Patrick von Platen
993907f561 [Torch Compile] Fix torch compile for svd vae 2023-12-18 14:36:53 +00:00

View File

@@ -25,7 +25,7 @@ from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging
from ...utils.torch_utils import randn_tensor
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -211,7 +211,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
latents = 1 / self.vae.config.scaling_factor * latents
accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
# decode decode_chunk_size frames at a time to avoid OOM
frames = []