mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
1 Commits
controlnet
...
improve_ck
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2dca95f56b |
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import io
|
||||
import requests
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@@ -1432,38 +1434,62 @@ class FromCkptMixin:
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
if Path(pretrained_model_link_or_path).is_file():
|
||||
pretrained_model_path_or_dict = pretrained_model_link_or_path
|
||||
elif not Path(pretrained_model_link_or_path).is_file():
|
||||
is_hf = False
|
||||
is_civit_ai = False
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
is_hf = True
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
for prefix in ["https://civitai.com/", "civitai.com"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
if "api" not in pretrained_model_link_or_path:
|
||||
raise ValueError(f"{pretrained_model_link_or_path} is not a valid Civitai link. Make sure to provide a link in the form: https://civitai.com/api/models/<num>")
|
||||
is_civit_ai = True
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
if is_hf:
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(pretrained_model_link_or_path.parts[:2])
|
||||
file_path = "/".join(pretrained_model_link_or_path.parts[2:])
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_path_or_dict = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
pretrained_model_path_or_dict = pretrained_model_link_or_path
|
||||
elif is_civit_ai:
|
||||
response = requests.get(pretrained_model_link_or_path)
|
||||
checkpoint_bytes = response.content
|
||||
|
||||
# Create an in-memory byte stream using io.BytesIO()
|
||||
buffer = io.BytesIO(checkpoint_bytes)
|
||||
|
||||
try:
|
||||
pretrained_model_path_or_dict = safetensors.torch.load(buffer)
|
||||
except IOError as e:
|
||||
pass
|
||||
|
||||
pretrained_model_path_or_dict = torch.load(buffer)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pretrained_model_path_or_dict,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
|
||||
Reference in New Issue
Block a user