mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
1 Commits
find-modul
...
sf-comfy-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0256560f32 |
@@ -79,7 +79,10 @@ CHECKPOINT_KEY_NAMES = {
|
|||||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||||
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
||||||
"animatediff_rgb": "controlnet_cond_embedding.weight",
|
"animatediff_rgb": "controlnet_cond_embedding.weight",
|
||||||
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
|
"flux": [
|
||||||
|
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||||
|
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||||
@@ -258,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
|
|||||||
"timestep_spacing": "leading",
|
"timestep_spacing": "leading",
|
||||||
}
|
}
|
||||||
|
|
||||||
LDM_VAE_KEY = "first_stage_model."
|
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
|
||||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||||
LDM_UNET_KEY = "model.diffusion_model."
|
LDM_UNET_KEY = "model.diffusion_model."
|
||||||
@@ -267,7 +270,6 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
|
|||||||
"cond_stage_model.transformer.",
|
"cond_stage_model.transformer.",
|
||||||
"conditioner.embedders.0.transformer.",
|
"conditioner.embedders.0.transformer.",
|
||||||
]
|
]
|
||||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
|
||||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||||
|
|
||||||
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||||
@@ -518,8 +520,10 @@ def infer_diffusers_model_type(checkpoint):
|
|||||||
else:
|
else:
|
||||||
model_type = "animatediff_v3"
|
model_type = "animatediff_v3"
|
||||||
|
|
||||||
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
|
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
|
||||||
if "guidance_in.in_layer.bias" in checkpoint:
|
if any(
|
||||||
|
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||||
|
):
|
||||||
model_type = "flux-dev"
|
model_type = "flux-dev"
|
||||||
else:
|
else:
|
||||||
model_type = "flux-schnell"
|
model_type = "flux-schnell"
|
||||||
@@ -1178,7 +1182,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|||||||
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
|
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
|
||||||
vae_state_dict = {}
|
vae_state_dict = {}
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
|
vae_key = ""
|
||||||
|
for ldm_vae_key in LDM_VAE_KEYS:
|
||||||
|
if any(k.startswith(ldm_vae_key) for k in keys):
|
||||||
|
vae_key = ldm_vae_key
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith(vae_key):
|
if key.startswith(vae_key):
|
||||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||||
@@ -1883,6 +1891,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
|||||||
|
|
||||||
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||||
converted_state_dict = {}
|
converted_state_dict = {}
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
for k in keys:
|
||||||
|
if "model.diffusion_model." in k:
|
||||||
|
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||||
|
|
||||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
||||||
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
||||||
|
|||||||
Reference in New Issue
Block a user