mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-14 00:14:23 +08:00
Compare commits
1 Commits
fix/lora-l
...
improve_ck
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2dca95f56b |
@@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
|
import requests
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -1432,17 +1434,27 @@ class FromCkptMixin:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||||
|
|
||||||
# remove huggingface url
|
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||||
|
if Path(pretrained_model_link_or_path).is_file():
|
||||||
|
pretrained_model_path_or_dict = pretrained_model_link_or_path
|
||||||
|
elif not Path(pretrained_model_link_or_path).is_file():
|
||||||
|
is_hf = False
|
||||||
|
is_civit_ai = False
|
||||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||||
if pretrained_model_link_or_path.startswith(prefix):
|
if pretrained_model_link_or_path.startswith(prefix):
|
||||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||||
|
is_hf = True
|
||||||
|
|
||||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
for prefix in ["https://civitai.com/", "civitai.com"]:
|
||||||
ckpt_path = Path(pretrained_model_link_or_path)
|
if pretrained_model_link_or_path.startswith(prefix):
|
||||||
if not ckpt_path.is_file():
|
if "api" not in pretrained_model_link_or_path:
|
||||||
|
raise ValueError(f"{pretrained_model_link_or_path} is not a valid Civitai link. Make sure to provide a link in the form: https://civitai.com/api/models/<num>")
|
||||||
|
is_civit_ai = True
|
||||||
|
|
||||||
|
if is_hf:
|
||||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||||
repo_id = "/".join(ckpt_path.parts[:2])
|
repo_id = "/".join(pretrained_model_link_or_path.parts[:2])
|
||||||
file_path = "/".join(ckpt_path.parts[2:])
|
file_path = "/".join(pretrained_model_link_or_path.parts[2:])
|
||||||
|
|
||||||
if file_path.startswith("blob/"):
|
if file_path.startswith("blob/"):
|
||||||
file_path = file_path[len("blob/") :]
|
file_path = file_path[len("blob/") :]
|
||||||
@@ -1450,7 +1462,7 @@ class FromCkptMixin:
|
|||||||
if file_path.startswith("main/"):
|
if file_path.startswith("main/"):
|
||||||
file_path = file_path[len("main/") :]
|
file_path = file_path[len("main/") :]
|
||||||
|
|
||||||
pretrained_model_link_or_path = hf_hub_download(
|
pretrained_model_path_or_dict = hf_hub_download(
|
||||||
repo_id,
|
repo_id,
|
||||||
filename=file_path,
|
filename=file_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
@@ -1461,9 +1473,23 @@ class FromCkptMixin:
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
)
|
)
|
||||||
|
pretrained_model_path_or_dict = pretrained_model_link_or_path
|
||||||
|
elif is_civit_ai:
|
||||||
|
response = requests.get(pretrained_model_link_or_path)
|
||||||
|
checkpoint_bytes = response.content
|
||||||
|
|
||||||
|
# Create an in-memory byte stream using io.BytesIO()
|
||||||
|
buffer = io.BytesIO(checkpoint_bytes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
pretrained_model_path_or_dict = safetensors.torch.load(buffer)
|
||||||
|
except IOError as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
pretrained_model_path_or_dict = torch.load(buffer)
|
||||||
|
|
||||||
pipe = download_from_original_stable_diffusion_ckpt(
|
pipe = download_from_original_stable_diffusion_ckpt(
|
||||||
pretrained_model_link_or_path,
|
pretrained_model_path_or_dict,
|
||||||
pipeline_class=cls,
|
pipeline_class=cls,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
stable_unclip=stable_unclip,
|
stable_unclip=stable_unclip,
|
||||||
|
|||||||
Reference in New Issue
Block a user