Compare commits


1 Commit

Author              SHA1        Message               Date
Patrick von Platen  2dca95f56b  Improve loading ckpt  2023-07-03 15:38:39 +00:00


@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import io
+import requests
 import warnings
 from collections import defaultdict
 from pathlib import Path
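
The two imports added here, io and requests, exist so that a checkpoint can be fetched over HTTP and deserialized without ever being written to disk. A minimal, self-contained sketch of that pattern follows; the helper name and the safetensors-then-pickle fallback are illustrative assumptions, not part of the commit:

import io

import requests
import safetensors.torch
import torch


def load_state_dict_from_url(url):
    # Fetch the raw checkpoint bytes (illustrative helper, not from the diff).
    response = requests.get(url)
    response.raise_for_status()
    data = response.content
    try:
        # safetensors deserializes a state dict directly from bytes.
        return safetensors.torch.load(data)
    except Exception:
        # Fall back to a pickled PyTorch checkpoint read from an in-memory buffer.
        return torch.load(io.BytesIO(data), map_location="cpu")
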
@@ -1432,38 +1434,62 @@ class FromCkptMixin:
         else:
             raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
 
-        # remove huggingface url
-        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
-            if pretrained_model_link_or_path.startswith(prefix):
-                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
-
         # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
-        ckpt_path = Path(pretrained_model_link_or_path)
-        if not ckpt_path.is_file():
-            # get repo_id and (potentially nested) file path of ckpt in repo
-            repo_id = "/".join(ckpt_path.parts[:2])
-            file_path = "/".join(ckpt_path.parts[2:])
-
-            if file_path.startswith("blob/"):
-                file_path = file_path[len("blob/") :]
-
-            if file_path.startswith("main/"):
-                file_path = file_path[len("main/") :]
-
-            pretrained_model_link_or_path = hf_hub_download(
-                repo_id,
-                filename=file_path,
-                cache_dir=cache_dir,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                force_download=force_download,
-            )
+        if Path(pretrained_model_link_or_path).is_file():
+            pretrained_model_path_or_dict = pretrained_model_link_or_path
+        else:
+            is_hf = False
+            is_civit_ai = False
+
+            # strip a Hugging Face url prefix, if present
+            for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+                if pretrained_model_link_or_path.startswith(prefix):
+                    pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+                    is_hf = True
+
+            for prefix in ["https://civitai.com/", "civitai.com"]:
+                if pretrained_model_link_or_path.startswith(prefix):
+                    if "api" not in pretrained_model_link_or_path:
+                        raise ValueError(
+                            f"{pretrained_model_link_or_path} is not a valid Civitai link. Make sure to"
+                            " provide a link in the form: https://civitai.com/api/models/<num>"
+                        )
+                    is_civit_ai = True
+
+            if is_hf:
+                # get repo_id and (potentially nested) file path of ckpt in repo
+                repo_id = "/".join(Path(pretrained_model_link_or_path).parts[:2])
+                file_path = "/".join(Path(pretrained_model_link_or_path).parts[2:])
+
+                if file_path.startswith("blob/"):
+                    file_path = file_path[len("blob/") :]
+
+                if file_path.startswith("main/"):
+                    file_path = file_path[len("main/") :]
+
+                # download the checkpoint from the Hub and keep the local file path
+                pretrained_model_path_or_dict = hf_hub_download(
+                    repo_id,
+                    filename=file_path,
+                    cache_dir=cache_dir,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    force_download=force_download,
+                )
+            elif is_civit_ai:
+                # download the raw checkpoint bytes and deserialize them in memory
+                response = requests.get(pretrained_model_link_or_path)
+                checkpoint_bytes = response.content
+                buffer = io.BytesIO(checkpoint_bytes)
+                try:
+                    # safetensors loads a state dict directly from bytes
+                    pretrained_model_path_or_dict = safetensors.torch.load(checkpoint_bytes)
+                except Exception:
+                    # fall back to a pickled PyTorch checkpoint
+                    pretrained_model_path_or_dict = torch.load(buffer, map_location="cpu")
 
         pipe = download_from_original_stable_diffusion_ckpt(
-            pretrained_model_link_or_path,
+            pretrained_model_path_or_dict,
             pipeline_class=cls,
             model_type=model_type,
             stable_unclip=stable_unclip,
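
For context, the branch above is reached through the mixin's from_ckpt entry point. A hypothetical usage sketch, assuming this branch lands as shown in the diff; the model links below are placeholders, not real checkpoints:

from diffusers import StableDiffusionPipeline

# Hugging Face Hub link: resolved to a repo_id plus file path and fetched via hf_hub_download.
pipe = StableDiffusionPipeline.from_ckpt(
    "https://huggingface.co/<repo_id>/blob/main/<checkpoint>.safetensors"
)

# Civitai API link: downloaded into memory and parsed straight into a state dict.
pipe = StableDiffusionPipeline.from_ckpt("https://civitai.com/api/models/<num>")
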