Compare commits

...

4 Commits

Author       SHA1         Message   Date
Dhruv Nair   840344b817   update    2024-03-12 10:36:41 +00:00
Dhruv Nair   7739271db3   update    2024-03-12 10:31:29 +00:00
Dhruv Nair   7f1ea22c07   update    2024-03-11 14:56:12 +00:00
Dhruv Nair   0de7e023fd   update    2024-03-11 13:18:13 +00:00
3 changed files with 13 additions and 5 deletions

View File

@@ -56,6 +56,8 @@ def build_sub_model_components(
     if component_name == "unet":
         num_in_channels = kwargs.pop("num_in_channels", None)
+        upcast_attention = kwargs.pop("upcast_attention", None)
         unet_components = create_diffusers_unet_model_from_ldm(
             pipeline_class_name,
             original_config,
@@ -64,6 +66,7 @@ def build_sub_model_components(
             image_size=image_size,
             torch_dtype=torch_dtype,
             model_type=model_type,
+            upcast_attention=upcast_attention,
         )
         return unet_components
@@ -300,7 +303,9 @@ class FromSingleFileMixin:
                 continue
             init_kwargs.update(components)
-        additional_components = set_additional_components(class_name, original_config, model_type=model_type)
+        additional_components = set_additional_components(
+            class_name, original_config, checkpoint=checkpoint, model_type=model_type
+        )
         if additional_components:
             init_kwargs.update(additional_components)
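
Taken together, the hunks above thread a new upcast_attention keyword from the from_single_file entry point (where build_sub_model_components pops it out of **kwargs) into create_diffusers_unet_model_from_ldm, and additionally forward the loaded checkpoint to set_additional_components. A minimal sketch of how a caller might supply that keyword, assuming a hypothetical local SDXL checkpoint path and the public StableDiffusionXLImg2ImgPipeline class (which inherits FromSingleFileMixin):

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

# "./sd_xl_refiner_1.0.safetensors" is a hypothetical local path; any checkpoint
# supported by from_single_file would do.
pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(
    "./sd_xl_refiner_1.0.safetensors",
    torch_dtype=torch.float16,
    # Forwarded through **kwargs and popped in build_sub_model_components above;
    # omit it to keep whatever the converted UNet config specifies.
    upcast_attention=True,
)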

View File

@@ -307,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
     return original_config
-def infer_model_type(original_config, checkpoint=None, model_type=None):
+def infer_model_type(original_config, checkpoint, model_type=None):
     if model_type is not None:
         return model_type
@@ -1176,7 +1176,7 @@ def create_diffusers_unet_model_from_ldm(
     original_config,
     checkpoint,
     num_in_channels=None,
-    upcast_attention=False,
+    upcast_attention=None,
     extract_ema=False,
     image_size=None,
     torch_dtype=None,
@@ -1204,7 +1204,8 @@ def create_diffusers_unet_model_from_ldm(
         )
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["in_channels"] = num_in_channels
-    unet_config["upcast_attention"] = upcast_attention
+    if upcast_attention is not None:
+        unet_config["upcast_attention"] = upcast_attention
     diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
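
The hunks above make checkpoint a required argument of infer_model_type and change the upcast_attention default from False to None; with the None sentinel, the converted UNet config is only overridden when the flag is explicitly passed, so a converted config that already enables upcast_attention (as SD 2.1-style checkpoints typically do) is no longer silently forced to False. A small self-contained sketch of that sentinel pattern; the helper name apply_unet_overrides and the config values are made up for illustration:

from typing import Optional

def apply_unet_overrides(unet_config: dict, upcast_attention: Optional[bool] = None) -> dict:
    # Mirrors the logic above: None means "keep whatever the converted config
    # already contains", while an explicit True/False wins.
    if upcast_attention is not None:
        unet_config["upcast_attention"] = upcast_attention
    return unet_config

base = {"upcast_attention": True}  # illustrative converted-config value
print(apply_unet_overrides(dict(base)))                          # {'upcast_attention': True}
print(apply_unet_overrides(dict(base), upcast_attention=False))  # {'upcast_attention': False}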

View File

@@ -838,9 +838,11 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
         for param_name, param_value in single_file_pipe.unet.config.items():
             if param_name in PARAMS_TO_IGNORE:
                 continue
+            if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
+                pipe.unet.config[param_name] = False
             assert (
                 pipe.unet.config[param_name] == param_value
-            ), f"{param_name} differs between single file loading and pretrained loading"
+            ), f"{param_name} is differs between single file loading and pretrained loading"
         for param_name, param_value in single_file_pipe.vae.config.items():
             if param_name in PARAMS_TO_IGNORE: