mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
[Modular] Add single file support to Modular (#12383)
* update * update * update * update * Apply style fixes * update * update * update * update * update --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
|
||||
```py
|
||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||
guider_spec.default_creation_method="from_pretrained"
|
||||
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
|
||||
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
|
||||
guider_spec.subfolder="pag_guider"
|
||||
pag_guider = guider_spec.load()
|
||||
t2i_pipeline.update_components(guider=pag_guider)
|
||||
|
||||
@@ -313,14 +313,14 @@ unet_spec
|
||||
ComponentSpec(
|
||||
name='unet',
|
||||
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
||||
repo='RunDiffusion/Juggernaut-XL-v9',
|
||||
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
|
||||
subfolder='unet',
|
||||
variant='fp16',
|
||||
default_creation_method='from_pretrained'
|
||||
)
|
||||
|
||||
# modify to load from a different repository
|
||||
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
# load component with modified spec
|
||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||
|
||||
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
|
||||
```py
|
||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||
guider_spec.default_creation_method="from_pretrained"
|
||||
guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
|
||||
guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
|
||||
guider_spec.subfolder="pag_guider"
|
||||
pag_guider = guider_spec.load()
|
||||
t2i_pipeline.update_components(guider=pag_guider)
|
||||
|
||||
@@ -313,14 +313,14 @@ unet_spec
|
||||
ComponentSpec(
|
||||
name='unet',
|
||||
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
||||
repo='RunDiffusion/Juggernaut-XL-v9',
|
||||
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
|
||||
subfolder='unet',
|
||||
variant='fp16',
|
||||
default_creation_method='from_pretrained'
|
||||
)
|
||||
|
||||
# 修改以从不同的仓库加载
|
||||
unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
# 使用修改后的规范加载组件
|
||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||
|
||||
@@ -389,6 +389,14 @@ def is_valid_url(url):
|
||||
return False
|
||||
|
||||
|
||||
def _is_single_file_path_or_url(pretrained_model_name_or_path):
|
||||
if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path):
|
||||
return False
|
||||
|
||||
repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
|
||||
return bool(repo_id and weight_name)
|
||||
|
||||
|
||||
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
|
||||
if not is_valid_url(pretrained_model_name_or_path):
|
||||
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
|
||||
@@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
|
||||
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
|
||||
match = re.match(pattern, pretrained_model_name_or_path)
|
||||
if not match:
|
||||
logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
|
||||
return repo_id, weights_name
|
||||
|
||||
repo_id = f"{match.group(1)}/{match.group(2)}"
|
||||
|
||||
@@ -360,7 +360,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
collection: Optional[str] = None,
|
||||
) -> "ModularPipeline":
|
||||
"""
|
||||
create a ModularPipeline, optionally accept modular_repo to load from hub.
|
||||
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
|
||||
"""
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
@@ -1645,8 +1645,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
|
||||
Path to a pretrained pipeline configuration. It will first try to load config from
|
||||
`modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
|
||||
non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
|
||||
during initialization.
|
||||
non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
|
||||
will be set to None during initialization.
|
||||
trust_remote_code (`bool`, optional):
|
||||
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
|
||||
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
|
||||
@@ -1807,7 +1807,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
library, class_name = None, None
|
||||
|
||||
# extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
|
||||
# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
|
||||
# e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
|
||||
# "type_hint": ("diffusers", "UNet2DConditionModel"),
|
||||
# "subfolder": "unet",
|
||||
# "variant": None,
|
||||
@@ -2111,8 +2111,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
**kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
|
||||
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
|
||||
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
|
||||
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
|
||||
`variant`, `revision`, etc.
|
||||
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
|
||||
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
|
||||
- if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
|
||||
`pretrained_model_name_or_path`, `variant`, `revision`, etc.
|
||||
"""
|
||||
|
||||
if names is None:
|
||||
@@ -2378,10 +2380,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- "type_hint": Tuple[str, str]
|
||||
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
|
||||
- All loading fields defined by `component_spec.loading_fields()`, typically:
|
||||
- "repo": Optional[str]
|
||||
The model repository (e.g., "stabilityai/stable-diffusion-xl").
|
||||
- "pretrained_model_name_or_path": Optional[str]
|
||||
The model pretrained_model_name_or_pathsitory (e.g., "stabilityai/stable-diffusion-xl").
|
||||
- "subfolder": Optional[str]
|
||||
A subfolder within the repo where this component lives.
|
||||
A subfolder within the pretrained_model_name_or_path where this component lives.
|
||||
- "variant": Optional[str]
|
||||
An optional variant identifier for the model.
|
||||
- "revision": Optional[str]
|
||||
@@ -2398,11 +2400,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
Example:
|
||||
>>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
|
||||
UNet2DConditionModel >>> spec = ComponentSpec(
|
||||
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
|
||||
subfolder="subfolder", ... variant=None, ... revision=None, ...
|
||||
default_creation_method="from_pretrained",
|
||||
... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
|
||||
pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
|
||||
variant=None, ... revision=None, ... default_creation_method="from_pretrained",
|
||||
... ) >>> ModularPipeline._component_spec_to_dict(spec) {
|
||||
"type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
|
||||
"type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
|
||||
"subfolder": "subfolder", "variant": None, "revision": None, "type_hint": ("diffusers",
|
||||
"UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder",
|
||||
"variant": None, "revision": None,
|
||||
}
|
||||
"""
|
||||
@@ -2432,10 +2436,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- "type_hint": Tuple[str, str]
|
||||
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
|
||||
- All loading fields defined by `component_spec.loading_fields()`, typically:
|
||||
- "repo": Optional[str]
|
||||
- "pretrained_model_name_or_path": Optional[str]
|
||||
The model repository (e.g., "stabilityai/stable-diffusion-xl").
|
||||
- "subfolder": Optional[str]
|
||||
A subfolder within the repo where this component lives.
|
||||
A subfolder within the pretrained_model_name_or_path where this component lives.
|
||||
- "variant": Optional[str]
|
||||
An optional variant identifier for the model.
|
||||
- "revision": Optional[str]
|
||||
@@ -2452,11 +2456,20 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
ComponentSpec: A reconstructed ComponentSpec object.
|
||||
|
||||
Example:
|
||||
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
|
||||
"stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
|
||||
} >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
|
||||
name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
|
||||
subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
|
||||
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
|
||||
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
|
||||
ComponentSpec(
|
||||
name="unet", type_hint=UNet2DConditionModel, config=None,
|
||||
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
|
||||
revision=None, default_creation_method="from_pretrained"
|
||||
>>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
|
||||
None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
|
||||
ComponentSpec(
|
||||
name="unet", type_hint=UNet2DConditionModel, config=None,
|
||||
pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
|
||||
revision=None, default_creation_method="from_pretrained"
|
||||
)
|
||||
"""
|
||||
# make a shallow copy so we can pop() safely
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import is_torch_available, logging
|
||||
|
||||
|
||||
@@ -80,10 +81,10 @@ class ComponentSpec:
|
||||
type_hint: Type of the component (e.g. UNet2DConditionModel)
|
||||
description: Optional description of the component
|
||||
config: Optional config dict for __init__ creation
|
||||
repo: Optional repo path for from_pretrained creation
|
||||
subfolder: Optional subfolder in repo
|
||||
variant: Optional variant in repo
|
||||
revision: Optional revision in repo
|
||||
pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
|
||||
subfolder: Optional subfolder in pretrained_model_name_or_path
|
||||
variant: Optional variant in pretrained_model_name_or_path
|
||||
revision: Optional revision in pretrained_model_name_or_path
|
||||
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
|
||||
"""
|
||||
|
||||
@@ -91,13 +92,20 @@ class ComponentSpec:
|
||||
type_hint: Optional[Type] = None
|
||||
description: Optional[str] = None
|
||||
config: Optional[FrozenDict] = None
|
||||
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
|
||||
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
||||
pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
|
||||
subfolder: Optional[str] = field(default="", metadata={"loading": True})
|
||||
variant: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
revision: Optional[str] = field(default=None, metadata={"loading": True})
|
||||
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
|
||||
|
||||
# Deprecated
|
||||
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
|
||||
|
||||
def __post_init__(self):
|
||||
repo_value = self.repo
|
||||
if repo_value is not None and self.pretrained_model_name_or_path is None:
|
||||
object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
|
||||
|
||||
def __hash__(self):
|
||||
"""Make ComponentSpec hashable, using load_id as the hash value."""
|
||||
return hash((self.name, self.load_id, self.default_creation_method))
|
||||
@@ -182,8 +190,8 @@ class ComponentSpec:
|
||||
@property
|
||||
def load_id(self) -> str:
|
||||
"""
|
||||
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
|
||||
segments).
|
||||
Unique identifier for this spec's pretrained load, composed of
|
||||
pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
|
||||
"""
|
||||
if self.default_creation_method == "from_config":
|
||||
return "null"
|
||||
@@ -197,12 +205,13 @@ class ComponentSpec:
|
||||
Decode a load_id string back into a dictionary of loading fields and values.
|
||||
|
||||
Args:
|
||||
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
|
||||
load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
|
||||
where None values are represented as "null"
|
||||
|
||||
Returns:
|
||||
Dict mapping loading field names to their values. e.g. {
|
||||
"repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
|
||||
"pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
|
||||
"revision": "revision"
|
||||
} If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
|
||||
component not created with `load` method).
|
||||
"""
|
||||
@@ -259,34 +268,45 @@ class ComponentSpec:
|
||||
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
|
||||
def load(self, **kwargs) -> Any:
|
||||
"""Load component using from_pretrained."""
|
||||
|
||||
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
|
||||
# select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
|
||||
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
|
||||
# merge loading field value in the spec with user passed values to create load_kwargs
|
||||
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
|
||||
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
|
||||
repo = load_kwargs.pop("repo", None)
|
||||
if repo is None:
|
||||
|
||||
pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
|
||||
if pretrained_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
|
||||
"`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
|
||||
)
|
||||
is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
|
||||
if is_single_file and self.type_hint is None:
|
||||
raise ValueError(
|
||||
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
|
||||
)
|
||||
|
||||
if self.type_hint is None:
|
||||
try:
|
||||
from diffusers import AutoModel
|
||||
|
||||
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
|
||||
# update type_hint if AutoModel load successfully
|
||||
self.type_hint = component.__class__
|
||||
else:
|
||||
# determine load method
|
||||
load_method = (
|
||||
getattr(self.type_hint, "from_single_file")
|
||||
if is_single_file
|
||||
else getattr(self.type_hint, "from_pretrained")
|
||||
)
|
||||
|
||||
try:
|
||||
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
|
||||
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to load {self.name} using load method: {e}")
|
||||
|
||||
self.repo = repo
|
||||
self.pretrained_model_name_or_path = pretrained_model_name_or_path
|
||||
for k, v in load_kwargs.items():
|
||||
setattr(self, k, v)
|
||||
component._diffusers_load_id = self.load_id
|
||||
|
||||
@@ -36,7 +36,7 @@ from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
@@ -62,7 +62,7 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
@@ -128,7 +128,7 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
@@ -32,7 +32,7 @@ from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPip
|
||||
class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
||||
pipeline_class = QwenImageModularPipeline
|
||||
pipeline_blocks_class = QwenImageAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-qwenimage-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
@@ -58,7 +58,7 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
|
||||
class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
||||
pipeline_class = QwenImageEditModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
@@ -84,7 +84,7 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
|
||||
class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
|
||||
pipeline_class = QwenImageEditPlusModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||
|
||||
# No `mask_image` yet.
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
||||
|
||||
@@ -105,7 +105,7 @@ class SDXLModularIPAdapterTesterMixin:
|
||||
|
||||
blocks = self.pipeline_blocks_class()
|
||||
_ = blocks.sub_blocks.pop("ip_adapter")
|
||||
pipe = blocks.init_pipeline(self.repo)
|
||||
pipe = blocks.init_pipeline(self.pretrained_model_name_or_path)
|
||||
pipe.load_components(torch_dtype=torch.float32)
|
||||
pipe = pipe.to(torch_device)
|
||||
|
||||
@@ -278,7 +278,7 @@ class TestSDXLModularPipelineFast(
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -325,7 +325,7 @@ class TestSDXLImg2ImgModularPipelineFast(
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -378,7 +378,7 @@ class SDXLInpaintingModularPipelineFastTests(
|
||||
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
repo = "hf-internal-testing/tiny-sdxl-modular"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
|
||||
@@ -43,9 +43,9 @@ class ModularPipelineTesterMixin:
|
||||
)
|
||||
|
||||
@property
|
||||
def repo(self) -> str:
|
||||
def pretrained_model_name_or_path(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
|
||||
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -103,7 +103,9 @@ class ModularPipelineTesterMixin:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
|
||||
pipeline = self.pipeline_blocks_class().init_pipeline(
|
||||
self.pretrained_model_name_or_path, components_manager=components_manager
|
||||
)
|
||||
pipeline.load_components(torch_dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
return pipeline
|
||||
|
||||
Reference in New Issue
Block a user