Compare commits

...

1 Commit

Author               SHA1         Message                Date
Patrick von Platen   2dca95f56b   Improve loading ckpt   2023-07-03 15:38:39 +00:00


@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import io
+import requests
 import warnings
 from collections import defaultdict
 from pathlib import Path
@@ -1432,17 +1434,27 @@ class FromCkptMixin:
         else:
             raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

-        # remove huggingface url
-        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
-            if pretrained_model_link_or_path.startswith(prefix):
-                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
-
-        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
-        ckpt_path = Path(pretrained_model_link_or_path)
-        if not ckpt_path.is_file():
-            # get repo_id and (potentially nested) file path of ckpt in repo
-            repo_id = "/".join(ckpt_path.parts[:2])
-            file_path = "/".join(ckpt_path.parts[2:])
-
-            if file_path.startswith("blob/"):
-                file_path = file_path[len("blob/") :]
+        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+        if Path(pretrained_model_link_or_path).is_file():
+            pretrained_model_path_or_dict = pretrained_model_link_or_path
+        elif not Path(pretrained_model_link_or_path).is_file():
+            is_hf = False
+            is_civit_ai = False
+            for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+                if pretrained_model_link_or_path.startswith(prefix):
+                    pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+                    is_hf = True
+
+            for prefix in ["https://civitai.com/", "civitai.com"]:
+                if pretrained_model_link_or_path.startswith(prefix):
+                    if "api" not in pretrained_model_link_or_path:
+                        raise ValueError(f"{pretrained_model_link_or_path} is not a valid Civitai link. Make sure to provide a link in the form: https://civitai.com/api/models/<num>")
+                    is_civit_ai = True
+
+            if is_hf:
+                # get repo_id and (potentially nested) file path of ckpt in repo
+                repo_id = "/".join(pretrained_model_link_or_path.parts[:2])
+                file_path = "/".join(pretrained_model_link_or_path.parts[2:])
+
+                if file_path.startswith("blob/"):
+                    file_path = file_path[len("blob/") :]
@@ -1450,7 +1462,7 @@ class FromCkptMixin:
                 if file_path.startswith("main/"):
                     file_path = file_path[len("main/") :]

-            pretrained_model_link_or_path = hf_hub_download(
+                pretrained_model_path_or_dict = hf_hub_download(
                     repo_id,
                     filename=file_path,
                     cache_dir=cache_dir,
@@ -1461,9 +1473,23 @@ class FromCkptMixin:
                     revision=revision,
                     force_download=force_download,
                 )
+                pretrained_model_path_or_dict = pretrained_model_link_or_path
+            elif is_civit_ai:
+                response = requests.get(pretrained_model_link_or_path)
+                checkpoint_bytes = response.content
+                # Create an in-memory byte stream using io.BytesIO()
+                buffer = io.BytesIO(checkpoint_bytes)
+                try:
+                    pretrained_model_path_or_dict = safetensors.torch.load(buffer)
+                except IOError as e:
+                    pass
+                    pretrained_model_path_or_dict = torch.load(buffer)

         pipe = download_from_original_stable_diffusion_ckpt(
-            pretrained_model_link_or_path,
+            pretrained_model_path_or_dict,
             pipeline_class=cls,
             model_type=model_type,
             stable_unclip=stable_unclip,
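
The new branch boils down to: strip a known Hugging Face or Civitai prefix, then either resolve the checkpoint through hf_hub_download or fetch the raw bytes over HTTP and deserialize them in memory. Below is a minimal, self-contained sketch of that in-memory loading step, assuming only requests, safetensors, and torch; load_checkpoint_from_url is a hypothetical helper for illustration, not part of the diffusers API, and the broad exception fallback is a simplification of the try/except in the diff.

    # Illustrative sketch, not the committed implementation:
    # download a checkpoint into memory, try safetensors first,
    # then fall back to a pickled PyTorch checkpoint.
    import io

    import requests
    import safetensors.torch
    import torch


    def load_checkpoint_from_url(url: str) -> dict:
        response = requests.get(url)
        response.raise_for_status()
        data = response.content
        try:
            # safetensors.torch.load deserializes raw bytes into a state dict
            return safetensors.torch.load(data)
        except Exception:
            # not a safetensors payload; assume a pickled torch checkpoint
            return torch.load(io.BytesIO(data), map_location="cpu")

A state dict obtained this way is what the rewritten call site hands to download_from_original_stable_diffusion_ckpt as pretrained_model_path_or_dict, in place of the local file path used before this change.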