mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-04 01:45:15 +08:00
Compare commits
2 Commits
qwen-test-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ebd06f9b11 | ||
|
|
b712042da1 |
@@ -2321,8 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
num_double_layers = 0
|
||||
num_single_layers = 0
|
||||
for key in original_state_dict.keys():
|
||||
if key.startswith("single_blocks."):
|
||||
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
|
||||
elif key.startswith("double_blocks."):
|
||||
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
|
||||
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Optional, Union
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
|
||||
@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
|
||||
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
|
||||
load_id = "|".join("null" if p is None else p for p in parts)
|
||||
model._diffusers_load_id = load_id
|
||||
|
||||
return model
|
||||
|
||||
@@ -2143,6 +2143,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
and self._component_specs[name].pretrained_model_name_or_path is not None
|
||||
and getattr(self, name, None) is None
|
||||
]
|
||||
elif isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import inspect
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import PIL.Image
|
||||
@@ -23,7 +23,7 @@ import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import is_torch_available, logging
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -186,7 +186,7 @@ class ComponentSpec:
|
||||
"""
|
||||
Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
|
||||
"""
|
||||
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
|
||||
return DIFFUSERS_LOAD_ID_FIELDS.copy()
|
||||
|
||||
@property
|
||||
def load_id(self) -> str:
|
||||
@@ -198,7 +198,7 @@ class ComponentSpec:
|
||||
return "null"
|
||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||
parts = ["null" if p is None else p for p in parts]
|
||||
return "|".join(p for p in parts if p)
|
||||
return "|".join(parts)
|
||||
|
||||
@classmethod
|
||||
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
|
||||
|
||||
@@ -23,6 +23,7 @@ from .constants import (
|
||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
DIFFUSERS_LOAD_ID_FIELDS,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
GGUF_FILE_EXTENSION,
|
||||
HF_ENABLE_PARALLEL_LOADING,
|
||||
|
||||
@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
|
||||
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
|
||||
|
||||
DIFFUSERS_LOAD_ID_FIELDS = [
|
||||
"pretrained_model_name_or_path",
|
||||
"subfolder",
|
||||
"variant",
|
||||
"revision",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user