Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Nair
61481575c0 update 2023-12-12 12:33:53 +00:00
Dhruv Nair
7d7bd7f77b update 2023-12-12 10:37:04 +00:00
2 changed files with 61 additions and 72 deletions

View File

@@ -169,10 +169,12 @@ class FromSingleFileMixin:
load_safety_checker = kwargs.pop("load_safety_checker", True) load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None) prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None) text_encoder = kwargs.pop("text_encoder", None)
text_encoder_2 = kwargs.pop("text_encoder_2", None)
vae = kwargs.pop("vae", None) vae = kwargs.pop("vae", None)
controlnet = kwargs.pop("controlnet", None) controlnet = kwargs.pop("controlnet", None)
adapter = kwargs.pop("adapter", None) adapter = kwargs.pop("adapter", None)
tokenizer = kwargs.pop("tokenizer", None) tokenizer = kwargs.pop("tokenizer", None)
tokenizer_2 = kwargs.pop("tokenizer_2", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
@@ -274,8 +276,10 @@ class FromSingleFileMixin:
load_safety_checker=load_safety_checker, load_safety_checker=load_safety_checker,
prediction_type=prediction_type, prediction_type=prediction_type,
text_encoder=text_encoder, text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae, vae=vae,
tokenizer=tokenizer, tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
original_config_file=original_config_file, original_config_file=original_config_file,
config_files=config_files, config_files=config_files,
local_files_only=local_files_only, local_files_only=local_files_only,

View File

@@ -1153,7 +1153,9 @@ def download_from_original_stable_diffusion_ckpt(
vae_path=None, vae_path=None,
vae=None, vae=None,
text_encoder=None, text_encoder=None,
text_encoder_2=None,
tokenizer=None, tokenizer=None,
tokenizer_2=None,
config_files=None, config_files=None,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
""" """
@@ -1232,7 +1234,9 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionInpaintPipeline, StableDiffusionInpaintPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline, StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline, StableUnCLIPPipeline,
@@ -1339,7 +1343,11 @@ def download_from_original_stable_diffusion_ckpt(
else: else:
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: if num_in_channels is None and pipeline_class in [
StableDiffusionInpaintPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLControlNetInpaintPipeline,
]:
num_in_channels = 9 num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
num_in_channels = 7 num_in_channels = 7
@@ -1686,7 +1694,9 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
elif model_type in ["SDXL", "SDXL-Refiner"]: elif model_type in ["SDXL", "SDXL-Refiner"]:
if model_type == "SDXL": is_refiner = model_type == "SDXL-Refiner"
if (is_refiner is False) and (tokenizer is None):
try: try:
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = CLIPTokenizer.from_pretrained(
"openai/clip-vit-large-patch14", local_files_only=local_files_only "openai/clip-vit-large-patch14", local_files_only=local_files_only
@@ -1695,7 +1705,11 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError( raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
) )
if (is_refiner is False) and (text_encoder is None):
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
if tokenizer_2 is None:
try: try:
tokenizer_2 = CLIPTokenizer.from_pretrained( tokenizer_2 = CLIPTokenizer.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
@@ -1705,95 +1719,66 @@ def download_from_original_stable_diffusion_ckpt(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
) )
if text_encoder_2 is None:
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280} config_kwargs = {"projection_dim": 1280}
prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model."
text_encoder_2 = convert_open_clip_checkpoint( text_encoder_2 = convert_open_clip_checkpoint(
checkpoint, checkpoint,
config_name, config_name,
prefix="conditioner.embedders.1.model.", prefix=prefix,
has_projection=True, has_projection=True,
local_files_only=local_files_only, local_files_only=local_files_only,
**config_kwargs, **config_kwargs,
) )
if is_accelerate_available(): # SBM Now move model to cpu. if is_accelerate_available(): # SBM Now move model to cpu.
if model_type in ["SDXL", "SDXL-Refiner"]: for param_name, param in converted_unet_checkpoint.items():
for param_name, param in converted_unet_checkpoint.items(): set_module_tensor_to_device(unet, param_name, "cpu", value=param)
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
if controlnet: if controlnet:
pipe = pipeline_class( pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
elif adapter:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
adapter=adapter,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
tokenizer = None
text_encoder = None
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
)
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
checkpoint,
config_name,
prefix="conditioner.embedders.0.model.",
has_projection=True,
local_files_only=local_files_only,
**config_kwargs,
)
if is_accelerate_available(): # SBM Now move model to cpu.
if model_type in ["SDXL", "SDXL-Refiner"]:
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
pipe = StableDiffusionXLImg2ImgPipeline(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
tokenizer=tokenizer, tokenizer=tokenizer,
text_encoder_2=text_encoder_2, text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2, tokenizer_2=tokenizer_2,
unet=unet, unet=unet,
controlnet=controlnet,
scheduler=scheduler, scheduler=scheduler,
requires_aesthetics_score=True, force_zeros_for_empty_prompt=True,
force_zeros_for_empty_prompt=False,
) )
elif adapter:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
adapter=adapter,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
pipeline_kwargs = {
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"unet": unet,
"scheduler": scheduler,
}
if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (
pipeline_class == StableDiffusionXLInpaintPipeline
):
pipeline_kwargs.update({"requires_aesthetics_score": is_refiner})
pipe = pipeline_class(**pipeline_kwargs)
else: else:
text_config = create_ldm_bert_config(original_config) text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)