mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-17 09:54:41 +08:00
Compare commits
2 Commits
diffusers-
...
sdxl-inpat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61481575c0 | ||
|
|
7d7bd7f77b |
@@ -169,10 +169,12 @@ class FromSingleFileMixin:
|
|||||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||||
prediction_type = kwargs.pop("prediction_type", None)
|
prediction_type = kwargs.pop("prediction_type", None)
|
||||||
text_encoder = kwargs.pop("text_encoder", None)
|
text_encoder = kwargs.pop("text_encoder", None)
|
||||||
|
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
||||||
vae = kwargs.pop("vae", None)
|
vae = kwargs.pop("vae", None)
|
||||||
controlnet = kwargs.pop("controlnet", None)
|
controlnet = kwargs.pop("controlnet", None)
|
||||||
adapter = kwargs.pop("adapter", None)
|
adapter = kwargs.pop("adapter", None)
|
||||||
tokenizer = kwargs.pop("tokenizer", None)
|
tokenizer = kwargs.pop("tokenizer", None)
|
||||||
|
tokenizer_2 = kwargs.pop("tokenizer_2", None)
|
||||||
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
|
||||||
@@ -274,8 +276,10 @@ class FromSingleFileMixin:
|
|||||||
load_safety_checker=load_safety_checker,
|
load_safety_checker=load_safety_checker,
|
||||||
prediction_type=prediction_type,
|
prediction_type=prediction_type,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
text_encoder_2=text_encoder_2,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
tokenizer_2=tokenizer_2,
|
||||||
original_config_file=original_config_file,
|
original_config_file=original_config_file,
|
||||||
config_files=config_files,
|
config_files=config_files,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
|||||||
@@ -1153,7 +1153,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
vae_path=None,
|
vae_path=None,
|
||||||
vae=None,
|
vae=None,
|
||||||
text_encoder=None,
|
text_encoder=None,
|
||||||
|
text_encoder_2=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
|
tokenizer_2=None,
|
||||||
config_files=None,
|
config_files=None,
|
||||||
) -> DiffusionPipeline:
|
) -> DiffusionPipeline:
|
||||||
"""
|
"""
|
||||||
@@ -1232,7 +1234,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
StableDiffusionInpaintPipeline,
|
StableDiffusionInpaintPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
|
StableDiffusionXLControlNetInpaintPipeline,
|
||||||
StableDiffusionXLImg2ImgPipeline,
|
StableDiffusionXLImg2ImgPipeline,
|
||||||
|
StableDiffusionXLInpaintPipeline,
|
||||||
StableDiffusionXLPipeline,
|
StableDiffusionXLPipeline,
|
||||||
StableUnCLIPImg2ImgPipeline,
|
StableUnCLIPImg2ImgPipeline,
|
||||||
StableUnCLIPPipeline,
|
StableUnCLIPPipeline,
|
||||||
@@ -1339,7 +1343,11 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
else:
|
else:
|
||||||
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
|
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
|
||||||
|
|
||||||
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
|
if num_in_channels is None and pipeline_class in [
|
||||||
|
StableDiffusionInpaintPipeline,
|
||||||
|
StableDiffusionXLInpaintPipeline,
|
||||||
|
StableDiffusionXLControlNetInpaintPipeline,
|
||||||
|
]:
|
||||||
num_in_channels = 9
|
num_in_channels = 9
|
||||||
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
|
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
|
||||||
num_in_channels = 7
|
num_in_channels = 7
|
||||||
@@ -1686,7 +1694,9 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
)
|
)
|
||||||
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
||||||
if model_type == "SDXL":
|
is_refiner = model_type == "SDXL-Refiner"
|
||||||
|
|
||||||
|
if (is_refiner is False) and (tokenizer is None):
|
||||||
try:
|
try:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
||||||
@@ -1695,7 +1705,11 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (is_refiner is False) and (text_encoder is None):
|
||||||
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
||||||
|
|
||||||
|
if tokenizer_2 is None:
|
||||||
try:
|
try:
|
||||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
||||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
||||||
@@ -1705,19 +1719,21 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if text_encoder_2 is None:
|
||||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
config_kwargs = {"projection_dim": 1280}
|
config_kwargs = {"projection_dim": 1280}
|
||||||
|
prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."
|
||||||
|
|
||||||
text_encoder_2 = convert_open_clip_checkpoint(
|
text_encoder_2 = convert_open_clip_checkpoint(
|
||||||
checkpoint,
|
checkpoint,
|
||||||
config_name,
|
config_name,
|
||||||
prefix="conditioner.embedders.1.model.",
|
prefix=prefix,
|
||||||
has_projection=True,
|
has_projection=True,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
**config_kwargs,
|
**config_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
if is_accelerate_available(): # SBM Now move model to cpu.
|
||||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
|
||||||
for param_name, param in converted_unet_checkpoint.items():
|
for param_name, param in converted_unet_checkpoint.items():
|
||||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||||
|
|
||||||
@@ -1745,55 +1761,24 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
force_zeros_for_empty_prompt=True,
|
force_zeros_for_empty_prompt=True,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
pipe = pipeline_class(
|
|
||||||
vae=vae,
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
text_encoder_2=text_encoder_2,
|
|
||||||
tokenizer_2=tokenizer_2,
|
|
||||||
unet=unet,
|
|
||||||
scheduler=scheduler,
|
|
||||||
force_zeros_for_empty_prompt=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tokenizer = None
|
|
||||||
text_encoder = None
|
|
||||||
try:
|
|
||||||
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
|
||||||
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
raise ValueError(
|
|
||||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
|
||||||
)
|
|
||||||
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|
||||||
config_kwargs = {"projection_dim": 1280}
|
|
||||||
text_encoder_2 = convert_open_clip_checkpoint(
|
|
||||||
checkpoint,
|
|
||||||
config_name,
|
|
||||||
prefix="conditioner.embedders.0.model.",
|
|
||||||
has_projection=True,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
**config_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_accelerate_available(): # SBM Now move model to cpu.
|
else:
|
||||||
if model_type in ["SDXL", "SDXL-Refiner"]:
|
pipeline_kwargs = {
|
||||||
for param_name, param in converted_unet_checkpoint.items():
|
"vae": vae,
|
||||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
"text_encoder": text_encoder,
|
||||||
|
"tokenizer": tokenizer,
|
||||||
|
"text_encoder_2": text_encoder_2,
|
||||||
|
"tokenizer_2": tokenizer_2,
|
||||||
|
"unet": unet,
|
||||||
|
"scheduler": scheduler,
|
||||||
|
}
|
||||||
|
|
||||||
pipe = StableDiffusionXLImg2ImgPipeline(
|
if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
|
||||||
vae=vae,
|
pipeline_class == StableDiffusionXLInpaintPipeline
|
||||||
text_encoder=text_encoder,
|
):
|
||||||
tokenizer=tokenizer,
|
pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
|
||||||
text_encoder_2=text_encoder_2,
|
|
||||||
tokenizer_2=tokenizer_2,
|
pipe = pipeline_class(**pipeline_kwargs)
|
||||||
unet=unet,
|
|
||||||
scheduler=scheduler,
|
|
||||||
requires_aesthetics_score=True,
|
|
||||||
force_zeros_for_empty_prompt=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
text_config = create_ldm_bert_config(original_config)
|
text_config = create_ldm_bert_config(original_config)
|
||||||
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user