Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Nair
11270912d8 update 2024-12-18 11:31:07 +01:00

View File

@@ -151,6 +151,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
@@ -587,7 +589,13 @@ def infer_diffusers_model_type(checkpoint):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
):
model_type = "flux-dev"
if checkpoint["img_in.weight"].shape[1] == 384:
model_type = "flux-fill"
elif checkpoint["img_in.weight"].shape[1] == 128:
model_type = "flux-depth"
else:
model_type = "flux-dev"
else:
model_type = "flux-schnell"