mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-14 08:24:32 +08:00
Compare commits
13 Commits
enable-cp-
...
sf-modular
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
15fb54de11 | ||
|
|
207847552d | ||
|
|
2cd7256ecb | ||
|
|
d391ace1ca | ||
|
|
546b6945df | ||
|
|
6dda0ecf67 | ||
|
|
fb1766ee4f | ||
|
|
8733fef39d | ||
|
|
a707e314ad | ||
|
|
c1c0e9a481 | ||
|
|
0a7bde9200 | ||
|
|
af48d815d8 | ||
|
|
bea02ccba3 |
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
|
|||||||
```py
|
```py
|
||||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||||
guider_spec.default_creation_method="from_pretrained"
|
guider_spec.default_creation_method="from_pretrained"
|
||||||
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
|
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
|
||||||
guider_spec.subfolder="pag_guider"
|
guider_spec.subfolder="pag_guider"
|
||||||
pag_guider = guider_spec.load()
|
pag_guider = guider_spec.load()
|
||||||
t2i_pipeline.update_components(guider=pag_guider)
|
t2i_pipeline.update_components(guider=pag_guider)
|
||||||
|
|||||||
@@ -313,14 +313,14 @@ unet_spec
|
|||||||
ComponentSpec(
|
ComponentSpec(
|
||||||
name='unet',
|
name='unet',
|
||||||
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
||||||
repo='RunDiffusion/Juggernaut-XL-v9',
|
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
|
||||||
subfolder='unet',
|
subfolder='unet',
|
||||||
variant='fp16',
|
variant='fp16',
|
||||||
default_creation_method='from_pretrained'
|
default_creation_method='from_pretrained'
|
||||||
)
|
)
|
||||||
|
|
||||||
# modify to load from a different repository
|
# modify to load from a different repository
|
||||||
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
|
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||||
|
|
||||||
# load component with modified spec
|
# load component with modified spec
|
||||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
|
|||||||
```py
|
```py
|
||||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||||
guider_spec.default_creation_method="from_pretrained"
|
guider_spec.default_creation_method="from_pretrained"
|
||||||
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
|
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
|
||||||
guider_spec.subfolder="pag_guider"
|
guider_spec.subfolder="pag_guider"
|
||||||
pag_guider = guider_spec.load()
|
pag_guider = guider_spec.load()
|
||||||
t2i_pipeline.update_components(guider=pag_guider)
|
t2i_pipeline.update_components(guider=pag_guider)
|
||||||
|
|||||||
@@ -313,14 +313,14 @@ unet_spec
|
|||||||
ComponentSpec(
|
ComponentSpec(
|
||||||
name='unet',
|
name='unet',
|
||||||
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
||||||
repo='RunDiffusion/Juggernaut-XL-v9',
|
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
|
||||||
subfolder='unet',
|
subfolder='unet',
|
||||||
variant='fp16',
|
variant='fp16',
|
||||||
default_creation_method='from_pretrained'
|
default_creation_method='from_pretrained'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 修改以从不同的仓库加载
|
# 修改以从不同的仓库加载
|
||||||
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
|
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||||
|
|
||||||
# 使用修改后的规范加载组件
|
# 使用修改后的规范加载组件
|
||||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||||
|
|||||||
@@ -389,6 +389,14 @@ def is_valid_url(url):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_single_file_path_or_url(pretrained_model_name_or_path):
|
||||||
|
if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
|
||||||
|
return bool(repo_id and weight_name)
|
||||||
|
|
||||||
|
|
||||||
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
|
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
|
||||||
if not is_valid_url(pretrained_model_name_or_path):
|
if not is_valid_url(pretrained_model_name_or_path):
|
||||||
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
|
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
|
||||||
@@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
|
|||||||
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
|
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
|
||||||
match = re.match(pattern, pretrained_model_name_or_path)
|
match = re.match(pattern, pretrained_model_name_or_path)
|
||||||
if not match:
|
if not match:
|
||||||
logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
|
|
||||||
return repo_id, weights_name
|
return repo_id, weights_name
|
||||||
|
|
||||||
repo_id = f"{match.group(1)}/{match.group(2)}"
|
repo_id = f"{match.group(1)}/{match.group(2)}"
|
||||||
|
|||||||
@@ -360,7 +360,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
collection: Optional[str] = None,
|
collection: Optional[str] = None,
|
||||||
) -> "ModularPipeline":
|
) -> "ModularPipeline":
|
||||||
"""
|
"""
|
||||||
create a ModularPipeline, optionally accept modular_repo to load from hub.
|
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
|
||||||
"""
|
"""
|
||||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
|
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
|
||||||
diffusers_module = importlib.import_module("diffusers")
|
diffusers_module = importlib.import_module("diffusers")
|
||||||
@@ -1645,8 +1645,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
|
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
|
||||||
Path to a pretrained pipeline configuration. It will first try to load config from
|
Path to a pretrained pipeline configuration. It will first try to load config from
|
||||||
`modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
|
`modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
|
||||||
non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
|
non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
|
||||||
during initialization.
|
will be set to None during initialization.
|
||||||
trust_remote_code (`bool`, optional):
|
trust_remote_code (`bool`, optional):
|
||||||
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
|
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
|
||||||
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
|
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
|
||||||
@@ -1807,7 +1807,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
library, class_name = None, None
|
library, class_name = None, None
|
||||||
|
|
||||||
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
|
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
|
||||||
# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
|
# e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
|
||||||
# "type_hint": ("diffusers", "UNet2DConditionModel"),
|
# "type_hint": ("diffusers", "UNet2DConditionModel"),
|
||||||
# "subfolder": "unet",
|
# "subfolder": "unet",
|
||||||
# "variant": None,
|
# "variant": None,
|
||||||
@@ -2111,8 +2111,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
||||||
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
||||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||||
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
|
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
|
||||||
`variant`, `revision`, etc.
|
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
|
||||||
|
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
|
||||||
|
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if names is None:
|
if names is None:
|
||||||
@@ -2378,10 +2380,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
- "type_hint": Tuple[str, str]
|
- "type_hint": Tuple[str, str]
|
||||||
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
|
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
|
||||||
- All loading fields defined by `component_spec.loading_fields()`, typically:
|
- All loading fields defined by `component_spec.loading_fields()`, typically:
|
||||||
- "repo": Optional[str]
|
- "pretrained_model_name_or_path": Optional[str]
|
||||||
The model repository (e.g., "stabilityai/stable-diffusion-xl").
|
The model pretrained_model_name_or_pathsitory (e.g., "stabilityai/stable-diffusion-xl").
|
||||||
- "subfolder": Optional[str]
|
- "subfolder": Optional[str]
|
||||||
A subfolder within the repo where this component lives.
|
A subfolder within the pretrained_model_name_or_path where this component lives.
|
||||||
- "variant": Optional[str]
|
- "variant": Optional[str]
|
||||||
An optional variant identifier for the model.
|
An optional variant identifier for the model.
|
||||||
- "revision": Optional[str]
|
- "revision": Optional[str]
|
||||||
@@ -2398,11 +2400,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
Example:
|
Example:
|
||||||
>>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
|
>>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
|
||||||
UNet2DConditionModel >>> spec = ComponentSpec(
|
UNet2DConditionModel >>> spec = ComponentSpec(
|
||||||
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
|
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
|
||||||
subfolder="subfolder", ... variant=None, ... revision=None, ...
|
pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
|
||||||
default_creation_method="from_pretrained",
|
variant=None, ... revision=None, ... default_creation_method="from_pretrained",
|
||||||
... ) >>> ModularPipeline._component_spec_to_dict(spec) {
|
... ) >>> ModularPipeline._component_spec_to_dict(spec) {
|
||||||
"type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
|
"type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
|
||||||
|
"subfolder": "subfolder", "variant": None, "revision": None, "type_hint": ("diffusers",
|
||||||
|
"UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder",
|
||||||
"variant": None, "revision": None,
|
"variant": None, "revision": None,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
@@ -2432,10 +2436,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
- "type_hint": Tuple[str, str]
|
- "type_hint": Tuple[str, str]
|
||||||
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
|
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
|
||||||
- All loading fields defined by `component_spec.loading_fields()`, typically:
|
- All loading fields defined by `component_spec.loading_fields()`, typically:
|
||||||
- "repo": Optional[str]
|
- "pretrained_model_name_or_path": Optional[str]
|
||||||
The model repository (e.g., "stabilityai/stable-diffusion-xl").
|
The model repository (e.g., "stabilityai/stable-diffusion-xl").
|
||||||
- "subfolder": Optional[str]
|
- "subfolder": Optional[str]
|
||||||
A subfolder within the repo where this component lives.
|
A subfolder within the pretrained_model_name_or_path where this component lives.
|
||||||
- "variant": Optional[str]
|
- "variant": Optional[str]
|
||||||
An optional variant identifier for the model.
|
An optional variant identifier for the model.
|
||||||
- "revision": Optional[str]
|
- "revision": Optional[str]
|
||||||
@@ -2452,11 +2456,20 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
ComponentSpec: A reconstructed ComponentSpec object.
|
ComponentSpec: A reconstructed ComponentSpec object.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
|
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
|
||||||
"stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
|
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
|
||||||
} >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
|
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
|
||||||
name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
|
ComponentSpec(
|
||||||
subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
|
name="unet", type_hint=UNet2DConditionModel, config=None,
|
||||||
|
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
|
||||||
|
revision=None, default_creation_method="from_pretrained"
|
||||||
|
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
|
||||||
|
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
|
||||||
|
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
|
||||||
|
ComponentSpec(
|
||||||
|
name="unet", type_hint=UNet2DConditionModel, config=None,
|
||||||
|
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
|
||||||
|
revision=None, default_creation_method="from_pretrained"
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
# make a shallow copy so we can pop() safely
|
# make a shallow copy so we can pop() safely
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||||
|
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||||
from ..utils import is_torch_available, logging
|
from ..utils import is_torch_available, logging
|
||||||
|
|
||||||
|
|
||||||
@@ -80,10 +81,10 @@ class ComponentSpec:
|
|||||||
type_hint: Type of the component (e.g. UNet2DConditionModel)
|
type_hint: Type of the component (e.g. UNet2DConditionModel)
|
||||||
description: Optional description of the component
|
description: Optional description of the component
|
||||||
config: Optional config dict for __init__ creation
|
config: Optional config dict for __init__ creation
|
||||||
repo: Optional repo path for from_pretrained creation
|
pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
|
||||||
subfolder: Optional subfolder in repo
|
subfolder: Optional subfolder in pretrained_model_name_or_path
|
||||||
variant: Optional variant in repo
|
variant: Optional variant in pretrained_model_name_or_path
|
||||||
revision: Optional revision in repo
|
revision: Optional revision in pretrained_model_name_or_path
|
||||||
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
|
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -91,13 +92,20 @@ class ComponentSpec:
|
|||||||
type_hint: Optional[Type] = None
|
type_hint: Optional[Type] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
config: Optional[FrozenDict] = None
|
config: Optional[FrozenDict] = None
|
||||||
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
|
pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
||||||
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
|
||||||
subfolder: Optional[str] = field(default="", metadata={"loading": True})
|
subfolder: Optional[str] = field(default="", metadata={"loading": True})
|
||||||
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
||||||
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
||||||
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
||||||
|
|
||||||
|
# Deprecated
|
||||||
|
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
repo_value = self.repo
|
||||||
|
if repo_value is not None and self.pretrained_model_name_or_path is None:
|
||||||
|
object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
"""Make ComponentSpec hashable, using load_id as the hash value."""
|
"""Make ComponentSpec hashable, using load_id as the hash value."""
|
||||||
return hash((self.name, self.load_id, self.default_creation_method))
|
return hash((self.name, self.load_id, self.default_creation_method))
|
||||||
@@ -182,8 +190,8 @@ class ComponentSpec:
|
|||||||
@property
|
@property
|
||||||
def load_id(self) -> str:
|
def load_id(self) -> str:
|
||||||
"""
|
"""
|
||||||
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
|
Unique identifier for this spec's pretrained load, composed of
|
||||||
segments).
|
pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
|
||||||
"""
|
"""
|
||||||
if self.default_creation_method == "from_config":
|
if self.default_creation_method == "from_config":
|
||||||
return "null"
|
return "null"
|
||||||
@@ -197,12 +205,13 @@ class ComponentSpec:
|
|||||||
Decode a load_id string back into a dictionary of loading fields and values.
|
Decode a load_id string back into a dictionary of loading fields and values.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
|
load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
|
||||||
where None values are represented as "null"
|
where None values are represented as "null"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping loading field names to their values. e.g. {
|
Dict mapping loading field names to their values. e.g. {
|
||||||
"repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
|
"pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
|
||||||
|
"revision": "revision"
|
||||||
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
|
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
|
||||||
component not created with `load` method).
|
component not created with `load` method).
|
||||||
"""
|
"""
|
||||||
@@ -259,34 +268,45 @@ class ComponentSpec:
|
|||||||
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
||||||
def load(self, **kwargs) -> Any:
|
def load(self, **kwargs) -> Any:
|
||||||
"""Load component using from_pretrained."""
|
"""Load component using from_pretrained."""
|
||||||
|
# select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
|
||||||
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
|
|
||||||
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
||||||
# merge loading field value in the spec with user passed values to create load_kwargs
|
# merge loading field value in the spec with user passed values to create load_kwargs
|
||||||
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
||||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
|
||||||
repo = load_kwargs.pop("repo", None)
|
pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
|
||||||
if repo is None:
|
if pretrained_model_name_or_path is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
|
"`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
|
||||||
|
)
|
||||||
|
is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
|
||||||
|
if is_single_file and self.type_hint is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.type_hint is None:
|
if self.type_hint is None:
|
||||||
try:
|
try:
|
||||||
from diffusers import AutoModel
|
from diffusers import AutoModel
|
||||||
|
|
||||||
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
|
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
|
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
|
||||||
# update type_hint if AutoModel load successfully
|
# update type_hint if AutoModel load successfully
|
||||||
self.type_hint = component.__class__
|
self.type_hint = component.__class__
|
||||||
else:
|
else:
|
||||||
|
# determine load method
|
||||||
|
load_method = (
|
||||||
|
getattr(self.type_hint, "from_single_file")
|
||||||
|
if is_single_file
|
||||||
|
else getattr(self.type_hint, "from_pretrained")
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Unable to load {self.name} using load method: {e}")
|
raise ValueError(f"Unable to load {self.name} using load method: {e}")
|
||||||
|
|
||||||
self.repo = repo
|
self.pretrained_model_name_or_path = pretrained_model_name_or_path
|
||||||
for k, v in load_kwargs.items():
|
for k, v in load_kwargs.items():
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
component._diffusers_load_id = self.load_id
|
component._diffusers_load_id = self.load_id
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
|||||||
class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||||
pipeline_class = FluxModularPipeline
|
pipeline_class = FluxModularPipeline
|
||||||
pipeline_blocks_class = FluxAutoBlocks
|
pipeline_blocks_class = FluxAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-flux-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||||
|
|
||||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||||
batch_params = frozenset(["prompt"])
|
batch_params = frozenset(["prompt"])
|
||||||
@@ -62,7 +62,7 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
|||||||
class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||||
pipeline_class = FluxModularPipeline
|
pipeline_class = FluxModularPipeline
|
||||||
pipeline_blocks_class = FluxAutoBlocks
|
pipeline_blocks_class = FluxAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-flux-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||||
|
|
||||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||||
batch_params = frozenset(["prompt", "image"])
|
batch_params = frozenset(["prompt", "image"])
|
||||||
@@ -128,7 +128,7 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
|||||||
class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||||
pipeline_class = FluxKontextModularPipeline
|
pipeline_class = FluxKontextModularPipeline
|
||||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||||
|
|
||||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||||
batch_params = frozenset(["prompt", "image"])
|
batch_params = frozenset(["prompt", "image"])
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPip
|
|||||||
class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
||||||
pipeline_class = QwenImageModularPipeline
|
pipeline_class = QwenImageModularPipeline
|
||||||
pipeline_blocks_class = QwenImageAutoBlocks
|
pipeline_blocks_class = QwenImageAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-qwenimage-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
|
||||||
|
|
||||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||||
@@ -58,7 +58,7 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
|
|||||||
class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
||||||
pipeline_class = QwenImageEditModularPipeline
|
pipeline_class = QwenImageEditModularPipeline
|
||||||
pipeline_blocks_class = QwenImageEditAutoBlocks
|
pipeline_blocks_class = QwenImageEditAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||||
|
|
||||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||||
@@ -84,7 +84,7 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
|
|||||||
class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
||||||
pipeline_class = QwenImageEditPlusModularPipeline
|
pipeline_class = QwenImageEditPlusModularPipeline
|
||||||
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||||
|
|
||||||
# No `mask_image` yet.
|
# No `mask_image` yet.
|
||||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ class SDXLModularIPAdapterTesterMixin:
|
|||||||
|
|
||||||
blocks = self.pipeline_blocks_class()
|
blocks = self.pipeline_blocks_class()
|
||||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||||
pipe = blocks.init_pipeline(self.repo)
|
pipe = blocks.init_pipeline(self.pretrained_model_name_or_path)
|
||||||
pipe.load_components(torch_dtype=torch.float32)
|
pipe.load_components(torch_dtype=torch.float32)
|
||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
|
|
||||||
@@ -278,7 +278,7 @@ class TestSDXLModularPipelineFast(
|
|||||||
|
|
||||||
pipeline_class = StableDiffusionXLModularPipeline
|
pipeline_class = StableDiffusionXLModularPipeline
|
||||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||||
params = frozenset(
|
params = frozenset(
|
||||||
[
|
[
|
||||||
"prompt",
|
"prompt",
|
||||||
@@ -325,7 +325,7 @@ class TestSDXLImg2ImgModularPipelineFast(
|
|||||||
|
|
||||||
pipeline_class = StableDiffusionXLModularPipeline
|
pipeline_class = StableDiffusionXLModularPipeline
|
||||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||||
params = frozenset(
|
params = frozenset(
|
||||||
[
|
[
|
||||||
"prompt",
|
"prompt",
|
||||||
@@ -378,7 +378,7 @@ class SDXLInpaintingModularPipelineFastTests(
|
|||||||
|
|
||||||
pipeline_class = StableDiffusionXLModularPipeline
|
pipeline_class = StableDiffusionXLModularPipeline
|
||||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||||
params = frozenset(
|
params = frozenset(
|
||||||
[
|
[
|
||||||
"prompt",
|
"prompt",
|
||||||
|
|||||||
@@ -43,9 +43,9 @@ class ModularPipelineTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def repo(self) -> str:
|
def pretrained_model_name_or_path(self) -> str:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
|
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -103,7 +103,9 @@ class ModularPipelineTesterMixin:
|
|||||||
backend_empty_cache(torch_device)
|
backend_empty_cache(torch_device)
|
||||||
|
|
||||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
pipeline = self.pipeline_blocks_class().init_pipeline(
|
||||||
|
self.pretrained_model_name_or_path, components_manager=components_manager
|
||||||
|
)
|
||||||
pipeline.load_components(torch_dtype=torch_dtype)
|
pipeline.load_components(torch_dtype=torch_dtype)
|
||||||
pipeline.set_progress_bar_config(disable=None)
|
pipeline.set_progress_bar_config(disable=None)
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|||||||
Reference in New Issue
Block a user