Compare commits

...

1 Commits

Author SHA1 Message Date
Patrick von Platen
2dca95f56b Improve loading ckpt 2023-07-03 15:38:39 +00:00

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import io
import requests
import warnings
from collections import defaultdict
from pathlib import Path
@@ -1432,38 +1434,62 @@ class FromCkptMixin:
else:
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])
if Path(pretrained_model_link_or_path).is_file():
pretrained_model_path_or_dict = pretrained_model_link_or_path
elif not Path(pretrained_model_link_or_path).is_file():
is_hf = False
is_civit_ai = False
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
is_hf = True
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
for prefix in ["https://civitai.com/", "civitai.com"]:
if pretrained_model_link_or_path.startswith(prefix):
if "api" not in pretrained_model_link_or_path:
raise ValueError(f"{pretrained_model_link_or_path} is not a valid Civitai link. Make sure to provide a link in the form: https://civitai.com/api/models/<num>")
is_civit_ai = True
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
if is_hf:
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(pretrained_model_link_or_path.parts[:2])
file_path = "/".join(pretrained_model_link_or_path.parts[2:])
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_path_or_dict = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
pretrained_model_path_or_dict = pretrained_model_link_or_path
elif is_civit_ai:
response = requests.get(pretrained_model_link_or_path)
checkpoint_bytes = response.content
# Create an in-memory byte stream using io.BytesIO()
buffer = io.BytesIO(checkpoint_bytes)
try:
pretrained_model_path_or_dict = safetensors.torch.load(buffer)
except IOError as e:
pass
pretrained_model_path_or_dict = torch.load(buffer)
pipe = download_from_original_stable_diffusion_ckpt(
pretrained_model_link_or_path,
pretrained_model_path_or_dict,
pipeline_class=cls,
model_type=model_type,
stable_unclip=stable_unclip,