Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Nair
0e0d986533 update 2025-06-19 11:31:49 +02:00
2 changed files with 4 additions and 2 deletions

View File

@@ -30,6 +30,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -151,7 +152,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
else:
generator = torch.Generator(device=device).manual_seed(seed)
control_image = torch.randn(1, 3, 32, 32, generator=generator)
control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"prompt": "",
"negative_prompt": "",

View File

@@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -137,7 +138,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.randn(1, 3, 32, 32, generator=generator)
image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"prompt": "",
"image": image,