Compare commits

...

1 Commits

Author SHA1 Message Date
DN6
327d8e5a6a update 2024-03-07 16:38:16 +05:30

View File

@@ -19,7 +19,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import logging, replace_example_docstring
from ...utils import is_torch_version, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
@@ -361,6 +361,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
device = self._execution_device
dtype = self.decoder.dtype
self._guidance_scale = guidance_scale
if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16:
raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(