Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Nair
969d0f252c update 2024-03-14 10:21:06 +00:00
Dhruv Nair
343f7c5c8a update 2024-03-14 10:20:34 +00:00
3 changed files with 9 additions and 10 deletions

View File

@@ -454,8 +454,7 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
model_type = infer_model_type(original_config, checkpoint, model_type)
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
return image_size
return 512
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
image_size = 1024

View File

@@ -758,6 +758,7 @@ class StableDiffusionImg2ImgPipeline(
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
print(shape)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents

View File

@@ -522,7 +522,7 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
)
single_file_pipe = StableDiffusionUpscalePipeline.from_single_file(ckpt_path, load_safety_checker=True)
single_file_pipe = StableDiffusionUpscalePipeline.from_single_file(ckpt_path)
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
@@ -540,13 +540,12 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
# The sample_size parameter for the VAE is incorrectly configured on the hub
# It must be 512, but it is 256 on the hub
if param_name == "sample_size":
pipe.vae.config[param_name] = param_value
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.safety_checker.config.to_dict()[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"