mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-07 04:54:47 +08:00
Compare commits
15 Commits
chroma-fix
...
fix/single
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed91e8b3e6 | ||
|
|
a77e426877 | ||
|
|
c0a0ef5deb | ||
|
|
9e35a12587 | ||
|
|
49b0b516ea | ||
|
|
52ba8061d3 | ||
|
|
2be231cce5 | ||
|
|
4b315f16a8 | ||
|
|
29e6b873c4 | ||
|
|
6d3e82c9cd | ||
|
|
1f358e1331 | ||
|
|
c1d0e091af | ||
|
|
9d90d60753 | ||
|
|
a4e00abb68 | ||
|
|
ce4f4f4545 |
@@ -63,13 +63,20 @@ def build_sub_model_components(
|
||||
num_in_channels=num_in_channels,
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size, scaling_factor, torch_dtype
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size,
|
||||
scaling_factor,
|
||||
torch_dtype,
|
||||
model_type=model_type,
|
||||
)
|
||||
return vae_components
|
||||
|
||||
@@ -124,11 +131,12 @@ def build_sub_model_components(
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint=None,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
|
||||
@@ -28,6 +28,7 @@ from ..schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EDMDPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
@@ -175,6 +176,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
@@ -305,7 +307,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
|
||||
return original_config
|
||||
|
||||
|
||||
def infer_model_type(original_config, model_type=None):
|
||||
def infer_model_type(original_config, checkpoint=None, model_type=None):
|
||||
if model_type is not None:
|
||||
return model_type
|
||||
|
||||
@@ -323,7 +325,9 @@ def infer_model_type(original_config, model_type=None):
|
||||
|
||||
elif has_network_config:
|
||||
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
|
||||
if context_dim == 2048:
|
||||
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
|
||||
model_type = "Playground"
|
||||
elif context_dim == 2048:
|
||||
model_type = "SDXL"
|
||||
else:
|
||||
model_type = "SDXL-Refiner"
|
||||
@@ -344,13 +348,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
|
||||
return image_size
|
||||
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
model_type = infer_model_type(original_config, model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint, model_type)
|
||||
|
||||
if pipeline_class_name == "StableDiffusionUpscalePipeline":
|
||||
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
|
||||
return image_size
|
||||
|
||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
|
||||
image_size = 1024
|
||||
return image_size
|
||||
|
||||
@@ -506,12 +510,14 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
|
||||
return controlnet_config
|
||||
|
||||
|
||||
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
|
||||
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
||||
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
|
||||
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
|
||||
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
|
||||
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
|
||||
scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
elif scaling_factor is None:
|
||||
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
|
||||
@@ -531,6 +537,8 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
|
||||
"layers_per_block": vae_params["num_res_blocks"],
|
||||
"scaling_factor": scaling_factor,
|
||||
}
|
||||
if latents_mean is not None and latents_std is not None:
|
||||
config.update({"latents_mean": latents_mean, "latents_std": latents_std})
|
||||
|
||||
return config
|
||||
|
||||
@@ -1172,6 +1180,7 @@ def create_diffusers_unet_model_from_ldm(
|
||||
extract_ema=False,
|
||||
image_size=None,
|
||||
torch_dtype=None,
|
||||
model_type=None,
|
||||
):
|
||||
from ..models import UNet2DConditionModel
|
||||
|
||||
@@ -1190,7 +1199,9 @@ def create_diffusers_unet_model_from_ldm(
|
||||
else:
|
||||
num_in_channels = 4
|
||||
|
||||
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
|
||||
image_size = set_image_size(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
|
||||
)
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["in_channels"] = num_in_channels
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
@@ -1223,14 +1234,40 @@ def create_diffusers_unet_model_from_ldm(
|
||||
|
||||
|
||||
def create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None, torch_dtype=None
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size=None,
|
||||
scaling_factor=None,
|
||||
torch_dtype=None,
|
||||
model_type=None,
|
||||
):
|
||||
# import here to avoid circular imports
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
|
||||
image_size = set_image_size(
|
||||
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
|
||||
)
|
||||
model_type = infer_model_type(original_config, checkpoint, model_type)
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
|
||||
if model_type == "Playground":
|
||||
edm_mean = (
|
||||
checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
|
||||
)
|
||||
edm_std = (
|
||||
checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
|
||||
)
|
||||
else:
|
||||
edm_mean = None
|
||||
edm_std = None
|
||||
|
||||
vae_config = create_vae_diffusers_config(
|
||||
original_config,
|
||||
image_size=image_size,
|
||||
scaling_factor=scaling_factor,
|
||||
latents_mean=edm_mean,
|
||||
latents_std=edm_std,
|
||||
)
|
||||
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
@@ -1265,7 +1302,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
||||
local_files_only=False,
|
||||
torch_dtype=None,
|
||||
):
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
config_name = "stabilityai/stable-diffusion-2"
|
||||
@@ -1332,7 +1369,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
||||
"text_encoder_2": text_encoder_2,
|
||||
}
|
||||
|
||||
elif model_type == "SDXL":
|
||||
elif model_type in ["SDXL", "Playground"]:
|
||||
try:
|
||||
config_name = "openai/clip-vit-large-patch14"
|
||||
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
|
||||
@@ -1383,7 +1420,7 @@ def create_scheduler_from_ldm(
|
||||
model_type=None,
|
||||
):
|
||||
scheduler_config = get_default_scheduler_config()
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
||||
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
|
||||
@@ -1406,7 +1443,8 @@ def create_scheduler_from_ldm(
|
||||
|
||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
||||
scheduler_type = "euler"
|
||||
|
||||
elif model_type == "Playground":
|
||||
scheduler_type = "edm_dpm_solver_multistep"
|
||||
else:
|
||||
beta_start = original_config["model"]["params"].get("linear_start", 0.02)
|
||||
beta_end = original_config["model"]["params"].get("linear_end", 0.085)
|
||||
@@ -1438,6 +1476,26 @@ def create_scheduler_from_ldm(
|
||||
elif scheduler_type == "ddim":
|
||||
scheduler = DDIMScheduler.from_config(scheduler_config)
|
||||
|
||||
elif scheduler_type == "edm_dpm_solver_multistep":
|
||||
scheduler_config = {
|
||||
"algorithm_type": "dpmsolver++",
|
||||
"dynamic_thresholding_ratio": 0.995,
|
||||
"euler_at_final": False,
|
||||
"final_sigmas_type": "zero",
|
||||
"lower_order_final": True,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"rho": 7.0,
|
||||
"sample_max_value": 1.0,
|
||||
"sigma_data": 0.5,
|
||||
"sigma_max": 80.0,
|
||||
"sigma_min": 0.002,
|
||||
"solver_order": 2,
|
||||
"solver_type": "midpoint",
|
||||
"thresholding": False,
|
||||
}
|
||||
scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user