mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
4 Commits
pipe-fetch
...
single-fil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
840344b817 | ||
|
|
7739271db3 | ||
|
|
7f1ea22c07 | ||
|
|
0de7e023fd |
@@ -56,6 +56,8 @@ def build_sub_model_components(
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
@@ -64,6 +66,7 @@ def build_sub_model_components(
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
@@ -300,7 +303,9 @@ class FromSingleFileMixin:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
additional_components = set_additional_components(
|
||||
class_name, original_config, checkpoint=checkpoint, model_type=model_type
|
||||
)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
|
||||
@@ -307,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
|
||||
return original_config
|
||||
|
||||
|
||||
def infer_model_type(original_config, checkpoint=None, model_type=None):
|
||||
def infer_model_type(original_config, checkpoint, model_type=None):
|
||||
if model_type is not None:
|
||||
return model_type
|
||||
|
||||
@@ -1176,7 +1176,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
num_in_channels=None,
|
||||
upcast_attention=False,
|
||||
upcast_attention=None,
|
||||
extract_ema=False,
|
||||
image_size=None,
|
||||
torch_dtype=None,
|
||||
@@ -1204,7 +1204,8 @@ def create_diffusers_unet_model_from_ldm(
|
||||
)
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["in_channels"] = num_in_channels
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
if upcast_attention is not None:
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
@@ -838,9 +838,11 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
|
||||
for param_name, param_value in single_file_pipe.unet.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
|
||||
pipe.unet.config[param_name] = False
|
||||
assert (
|
||||
pipe.unet.config[param_name] == param_value
|
||||
), f"{param_name} differs between single file loading and pretrained loading"
|
||||
), f"{param_name} is differs between single file loading and pretrained loading"
|
||||
|
||||
for param_name, param_value in single_file_pipe.vae.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
|
||||
Reference in New Issue
Block a user