Compare commits

...

13 Commits

Author | SHA1 | Message | Date
DN6 | 15fb54de11 | update | 2025-11-28 21:34:32 +05:30
DN6 | 207847552d | update | 2025-11-28 18:52:39 +05:30
DN6 | 2cd7256ecb | update | 2025-11-28 18:29:02 +05:30
DN6 | d391ace1ca | update | 2025-11-28 15:52:22 +05:30
DN6 | 546b6945df | Merge branch 'main' into sf-modular | 2025-11-28 15:30:13 +05:30
Dhruv Nair | 6dda0ecf67 | Merge branch 'main' into sf-modular | 2025-10-28 23:03:45 +05:30
Dhruv Nair | fb1766ee4f | Merge branch 'sf-modular' of https://github.com/huggingface/diffusers into sf-modular | 2025-10-23 19:00:13 +02:00
Dhruv Nair | 8733fef39d | update | 2025-10-23 18:56:14 +02:00
github-actions[bot] | a707e314ad | Apply style fixes | 2025-10-20 09:44:50 +00:00
DN6 | c1c0e9a481 | update | 2025-09-29 12:35:55 +05:30
DN6 | 0a7bde9200 | update | 2025-09-24 16:33:09 +05:30
DN6 | af48d815d8 | update | 2025-09-24 16:31:07 +05:30
DN6 | bea02ccba3 | update | 2025-09-18 23:31:07 +05:30
11 changed files with 101 additions and 59 deletions

View File

@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
 ```py
 guider_spec = t2i_pipeline.get_component_spec("guider")
 guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
 guider_spec.subfolder="pag_guider"
 pag_guider = guider_spec.load()
 t2i_pipeline.update_components(guider=pag_guider)

View File

@@ -313,14 +313,14 @@ unet_spec
 ComponentSpec(
     name='unet',
     type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-    repo='RunDiffusion/Juggernaut-XL-v9',
+    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
     subfolder='unet',
     variant='fp16',
     default_creation_method='from_pretrained'
 )
 # modify to load from a different repository
-unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
 # load component with modified spec
 unet = unet_spec.load(torch_dtype=torch.float16)

View File

@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
 ```py
 guider_spec = t2i_pipeline.get_component_spec("guider")
 guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
 guider_spec.subfolder="pag_guider"
 pag_guider = guider_spec.load()
 t2i_pipeline.update_components(guider=pag_guider)

View File

@@ -313,14 +313,14 @@ unet_spec
 ComponentSpec(
     name='unet',
     type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-    repo='RunDiffusion/Juggernaut-XL-v9',
+    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
     subfolder='unet',
     variant='fp16',
     default_creation_method='from_pretrained'
 )
 # modify to load from a different repository
-unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
 # load component with modified spec
 unet = unet_spec.load(torch_dtype=torch.float16)

View File

@@ -389,6 +389,14 @@ def is_valid_url(url):
     return False
+
+def _is_single_file_path_or_url(pretrained_model_name_or_path):
+    if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path):
+        return False
+    repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
+    return bool(repo_id and weight_name)
+
 def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     if not is_valid_url(pretrained_model_name_or_path):
         raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")

@@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
         pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
     match = re.match(pattern, pretrained_model_name_or_path)
     if not match:
-        logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
         return repo_id, weights_name
     repo_id = f"{match.group(1)}/{match.group(2)}"
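
For orientation, `_extract_repo_id_and_weights_name` splits a direct Hub file URL into a repo id and a weights filename, which is what the new `_is_single_file_path_or_url` helper builds on. A minimal sketch, under the assumption that the function behaves as its name and the surrounding code suggest (the URL below is illustrative):

```py
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

# A direct link to a checkpoint file inside a Hub repository.
url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"

repo_id, weights_name = _extract_repo_id_and_weights_name(url)
print(repo_id)       # "stabilityai/stable-diffusion-xl-base-1.0"
print(weights_name)  # "sd_xl_base_1.0.safetensors"
```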

View File

@@ -360,7 +360,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         collection: Optional[str] = None,
     ) -> "ModularPipeline":
         """
-        create a ModularPipeline, optionally accept modular_repo to load from hub.
+        create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
         """
         pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
         diffusers_module = importlib.import_module("diffusers")
@@ -1645,8 +1645,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
                 Path to a pretrained pipeline configuration. It will first try to load config from
                 `modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
-                non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
-                during initialization.
+                non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
+                will be set to None during initialization.
             trust_remote_code (`bool`, optional):
                 Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
                 pipeline blocks based on the custom code in `pretrained_model_name_or_path`
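
In practice the documented fallback means a modular pipeline can be instantiated from either a modular repository or a standard one. A minimal sketch, assuming `ModularPipeline` is importable from the top-level `diffusers` namespace and using an illustrative repo id:

```py
import torch
from diffusers import ModularPipeline

# Config is read from modular_model_index.json if present, otherwise from
# model_index.json of a standard (non-modular) repository.
pipe = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Component weights are only fetched once load_components() is called.
pipe.load_components(torch_dtype=torch.float16)
```
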
@@ -1807,7 +1807,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         library, class_name = None, None
         # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
-        # e.g. {"repo": "stabilityai/stable-diffusion-2-1",
+        # e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
         #       "type_hint": ("diffusers", "UNet2DConditionModel"),
         #       "subfolder": "unet",
         #       "variant": None,
@@ -2111,8 +2111,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
                 - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
                 - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
-                - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
-                  `variant`, `revision`, etc.
+                - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
+                  `pretrained_model_name_or_path`, `variant`, `revision`, etc.
+                - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
+                  `pretrained_model_name_or_path`, `variant`, `revision`, etc.
         """
         if names is None:
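
A brief sketch of the kwarg forms described above, built on a pipeline assembled from SDXL auto blocks (the tiny repo id is the one used by the tests in this diff; the import path is assumed):

```py
import torch
from diffusers.modular_pipelines import StableDiffusionXLAutoBlocks

pipe = StableDiffusionXLAutoBlocks().init_pipeline("hf-internal-testing/tiny-sdxl-modular")

# Single value: applied to every component being loaded.
pipe.load_components(torch_dtype=torch.float16)

# Dict form: per-component value with a "default" fallback for the rest.
# pipe.load_components(torch_dtype={"unet": torch.bfloat16, "default": torch.float32})
```
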
@@ -2378,10 +2380,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- "type_hint": Tuple[str, str] - "type_hint": Tuple[str, str]
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
- All loading fields defined by `component_spec.loading_fields()`, typically: - All loading fields defined by `component_spec.loading_fields()`, typically:
- "repo": Optional[str] - "pretrained_model_name_or_path": Optional[str]
The model repository (e.g., "stabilityai/stable-diffusion-xl"). The model pretrained_model_name_or_pathsitory (e.g., "stabilityai/stable-diffusion-xl").
- "subfolder": Optional[str] - "subfolder": Optional[str]
A subfolder within the repo where this component lives. A subfolder within the pretrained_model_name_or_path where this component lives.
- "variant": Optional[str] - "variant": Optional[str]
An optional variant identifier for the model. An optional variant identifier for the model.
- "revision": Optional[str] - "revision": Optional[str]
@@ -2398,11 +2400,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         Example:
             >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
             UNet2DConditionModel >>> spec = ComponentSpec(
-            ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
-            subfolder="subfolder", ... variant=None, ... revision=None, ...
-            default_creation_method="from_pretrained",
+            ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
+            pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
+            variant=None, ... revision=None, ... default_creation_method="from_pretrained",
             ... ) >>> ModularPipeline._component_spec_to_dict(spec) {
-            "type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
+            "type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
+            "subfolder": "subfolder", "variant": None, "revision": None, "type_hint": ("diffusers",
+            "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder",
             "variant": None, "revision": None,
             }
         """
@@ -2432,10 +2436,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- "type_hint": Tuple[str, str] - "type_hint": Tuple[str, str]
Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
- All loading fields defined by `component_spec.loading_fields()`, typically: - All loading fields defined by `component_spec.loading_fields()`, typically:
- "repo": Optional[str] - "pretrained_model_name_or_path": Optional[str]
The model repository (e.g., "stabilityai/stable-diffusion-xl"). The model repository (e.g., "stabilityai/stable-diffusion-xl").
- "subfolder": Optional[str] - "subfolder": Optional[str]
A subfolder within the repo where this component lives. A subfolder within the pretrained_model_name_or_path where this component lives.
- "variant": Optional[str] - "variant": Optional[str]
An optional variant identifier for the model. An optional variant identifier for the model.
- "revision": Optional[str] - "revision": Optional[str]
@@ -2452,11 +2456,20 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             ComponentSpec: A reconstructed ComponentSpec object.
         Example:
-            >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
-            "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
-            } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
-            name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
-            subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
+            >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
+            "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
+            None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+            ComponentSpec(
+            name="unet", type_hint=UNet2DConditionModel, config=None,
+            pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
+            revision=None, default_creation_method="from_pretrained"
+            >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
+            "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
+            None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+            ComponentSpec(
+            name="unet", type_hint=UNet2DConditionModel, config=None,
+            pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
+            revision=None, default_creation_method="from_pretrained"
             )
             """
         # make a shallow copy so we can pop() safely
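
Because the doctest wrapping above is hard to read, here is a compact sketch of the round trip those docstrings describe, written against the renamed loading field (import paths and repo id are illustrative assumptions):

```py
from diffusers import UNet2DConditionModel
from diffusers.modular_pipelines import ComponentSpec, ModularPipeline

spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
)

# Serialize the spec into the dict stored in modular_model_index.json ...
spec_dict = ModularPipeline._component_spec_to_dict(spec)
# roughly: {"type_hint": ("diffusers", "UNet2DConditionModel"),
#           "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
#           "subfolder": "unet", "variant": None, "revision": None}

# ... and rebuild an equivalent ComponentSpec from it.
restored = ModularPipeline._dict_to_component_spec("unet", spec_dict)
```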

View File

@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
 import torch
 from ..configuration_utils import ConfigMixin, FrozenDict
+from ..loaders.single_file_utils import _is_single_file_path_or_url
 from ..utils import is_torch_available, logging
@@ -80,10 +81,10 @@ class ComponentSpec:
         type_hint: Type of the component (e.g. UNet2DConditionModel)
         description: Optional description of the component
         config: Optional config dict for __init__ creation
-        repo: Optional repo path for from_pretrained creation
-        subfolder: Optional subfolder in repo
-        variant: Optional variant in repo
-        revision: Optional revision in repo
+        pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
+        subfolder: Optional subfolder in pretrained_model_name_or_path
+        variant: Optional variant in pretrained_model_name_or_path
+        revision: Optional revision in pretrained_model_name_or_path
         default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
         """
@@ -91,13 +92,20 @@ class ComponentSpec:
     type_hint: Optional[Type] = None
     description: Optional[str] = None
     config: Optional[FrozenDict] = None
-    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
-    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
+    pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
     subfolder: Optional[str] = field(default="", metadata={"loading": True})
     variant: Optional[str] = field(default=None, metadata={"loading": True})
     revision: Optional[str] = field(default=None, metadata={"loading": True})
     default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
+    # Deprecated
+    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
+
+    def __post_init__(self):
+        repo_value = self.repo
+        if repo_value is not None and self.pretrained_model_name_or_path is None:
+            object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
+
     def __hash__(self):
         """Make ComponentSpec hashable, using load_id as the hash value."""
         return hash((self.name, self.load_id, self.default_creation_method))
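
The deprecated field plus the `__post_init__` shim are there so existing callers that still pass `repo` keep working. A small sketch of that behavior, assuming `ComponentSpec` is importable from `diffusers.modular_pipelines` (the repo id is illustrative):

```py
from diffusers import UNet2DConditionModel
from diffusers.modular_pipelines import ComponentSpec

# Old-style construction that only sets the deprecated `repo` field.
spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
)

# __post_init__ mirrors the value into the new loading field.
print(spec.pretrained_model_name_or_path)
# "stabilityai/stable-diffusion-xl-base-1.0"
```
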
@@ -182,8 +190,8 @@ class ComponentSpec:
     @property
     def load_id(self) -> str:
         """
-        Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
-        segments).
+        Unique identifier for this spec's pretrained load, composed of
+        pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
         """
         if self.default_creation_method == "from_config":
             return "null"
@@ -197,12 +205,13 @@ class ComponentSpec:
         Decode a load_id string back into a dictionary of loading fields and values.
         Args:
-            load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
+            load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
                 where None values are represented as "null"
         Returns:
             Dict mapping loading field names to their values. e.g. {
-                "repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
+                "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
+                "revision": "revision"
             } If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
             component not created with `load` method).
         """
@@ -259,34 +268,45 @@ class ComponentSpec:
     # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
     def load(self, **kwargs) -> Any:
         """Load component using from_pretrained."""
-        # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
+        # select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
         passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
         # merge loading field value in the spec with user passed values to create load_kwargs
         load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
-        # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
-        repo = load_kwargs.pop("repo", None)
-        if repo is None:
-            raise ValueError(
-                "`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
+        pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
+        if pretrained_model_name_or_path is None:
+            raise ValueError(
+                "`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
+            )
+        is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
+        if is_single_file and self.type_hint is None:
+            raise ValueError(
+                f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
             )
         if self.type_hint is None:
             try:
                 from diffusers import AutoModel
-                component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
+                component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
             except Exception as e:
                 raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
             # update type_hint if AutoModel load successfully
             self.type_hint = component.__class__
         else:
+            # determine load method
+            load_method = (
+                getattr(self.type_hint, "from_single_file")
+                if is_single_file
+                else getattr(self.type_hint, "from_pretrained")
+            )
             try:
-                component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
+                component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
             except Exception as e:
                 raise ValueError(f"Unable to load {self.name} using load method: {e}")
-        self.repo = repo
+        self.pretrained_model_name_or_path = pretrained_model_name_or_path
         for k, v in load_kwargs.items():
             setattr(self, k, v)
         component._diffusers_load_id = self.load_id
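
Taken together, the new branch means a spec whose `pretrained_model_name_or_path` points at a single checkpoint file is routed through `from_single_file()` instead of `from_pretrained()`, and `type_hint` becomes mandatory in that case. A hedged sketch of that flow, using an illustrative checkpoint URL and a model class known to support single-file loading:

```py
import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import ComponentSpec

# Spec pointing at a single .safetensors file rather than a repo id.
# type_hint is required here because AutoModel cannot infer the class from a single file.
spec = ComponentSpec(
    name="transformer",
    type_hint=FluxTransformer2DModel,
    pretrained_model_name_or_path="https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors",
)

# load() detects the single-file source and dispatches to
# FluxTransformer2DModel.from_single_file() under the hood.
transformer = spec.load(torch_dtype=torch.bfloat16)
```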

View File

@@ -36,7 +36,7 @@ from ..test_modular_pipelines_common import ModularPipelineTesterMixin
 class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
     pipeline_blocks_class = FluxAutoBlocks
-    repo = "hf-internal-testing/tiny-flux-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
     params = frozenset(["prompt", "height", "width", "guidance_scale"])
     batch_params = frozenset(["prompt"])

@@ -62,7 +62,7 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
 class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
     pipeline_blocks_class = FluxAutoBlocks
-    repo = "hf-internal-testing/tiny-flux-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
     params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
     batch_params = frozenset(["prompt", "image"])

@@ -128,7 +128,7 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
 class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
     pipeline_blocks_class = FluxKontextAutoBlocks
-    repo = "hf-internal-testing/tiny-flux-kontext-pipe"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
     params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
     batch_params = frozenset(["prompt", "image"])

View File

@@ -32,7 +32,7 @@ from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPip
 class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
     pipeline_class = QwenImageModularPipeline
     pipeline_blocks_class = QwenImageAutoBlocks
-    repo = "hf-internal-testing/tiny-qwenimage-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
     params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
     batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

@@ -58,7 +58,7 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
 class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
     pipeline_class = QwenImageEditModularPipeline
     pipeline_blocks_class = QwenImageEditAutoBlocks
-    repo = "hf-internal-testing/tiny-qwenimage-edit-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
     params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
     batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

@@ -84,7 +84,7 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
 class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
     pipeline_class = QwenImageEditPlusModularPipeline
     pipeline_blocks_class = QwenImageEditPlusAutoBlocks
-    repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
     # No `mask_image` yet.
     params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])

View File

@@ -105,7 +105,7 @@ class SDXLModularIPAdapterTesterMixin:
         blocks = self.pipeline_blocks_class()
         _ = blocks.sub_blocks.pop("ip_adapter")
-        pipe = blocks.init_pipeline(self.repo)
+        pipe = blocks.init_pipeline(self.pretrained_model_name_or_path)
         pipe.load_components(torch_dtype=torch.float32)
         pipe = pipe.to(torch_device)

@@ -278,7 +278,7 @@ class TestSDXLModularPipelineFast(
     pipeline_class = StableDiffusionXLModularPipeline
     pipeline_blocks_class = StableDiffusionXLAutoBlocks
-    repo = "hf-internal-testing/tiny-sdxl-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
     params = frozenset(
         [
             "prompt",

@@ -325,7 +325,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     pipeline_class = StableDiffusionXLModularPipeline
     pipeline_blocks_class = StableDiffusionXLAutoBlocks
-    repo = "hf-internal-testing/tiny-sdxl-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
     params = frozenset(
         [
             "prompt",

@@ -378,7 +378,7 @@ class SDXLInpaintingModularPipelineFastTests(
     pipeline_class = StableDiffusionXLModularPipeline
     pipeline_blocks_class = StableDiffusionXLAutoBlocks
-    repo = "hf-internal-testing/tiny-sdxl-modular"
+    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
     params = frozenset(
         [
             "prompt",

View File

@@ -43,9 +43,9 @@ class ModularPipelineTesterMixin:
     )
     @property
-    def repo(self) -> str:
+    def pretrained_model_name_or_path(self) -> str:
         raise NotImplementedError(
-            "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
+            "You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
         )
     @property

@@ -103,7 +103,9 @@ class ModularPipelineTesterMixin:
         backend_empty_cache(torch_device)
     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
-        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
+        pipeline = self.pipeline_blocks_class().init_pipeline(
+            self.pretrained_model_name_or_path, components_manager=components_manager
+        )
         pipeline.load_components(torch_dtype=torch_dtype)
         pipeline.set_progress_bar_config(disable=None)
         return pipeline