mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-24 09:28:19 +08:00
Compare commits
7 Commits
properly-d
...
tests-cond
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
585aa6667f | ||
|
|
ab07b603ab | ||
|
|
7601432849 | ||
|
|
2a5f136142 | ||
|
|
4ade16db58 | ||
|
|
58c304595d | ||
|
|
55c563281a |
@@ -143,7 +143,6 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -229,7 +229,6 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -359,11 +358,6 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -527,7 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -538,11 +531,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -2688,37 +2676,6 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -324,18 +324,17 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
if hasattr(self.language_model, "_get_initial_cache_position"):
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
|
||||
@@ -29,7 +29,6 @@ from numpy.linalg import norm
|
||||
from packaging import version
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .deprecation_utils import deprecate
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
@@ -68,11 +67,9 @@ else:
|
||||
global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
deprecate(
|
||||
"diffusers.utils.testing_utils",
|
||||
"1.0.0",
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
|
||||
logger.warning(
|
||||
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
|
||||
)
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
|
||||
@@ -28,6 +28,7 @@ from diffusers.utils.import_utils import is_peft_available
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
is_flaky,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
skip_mps,
|
||||
@@ -45,6 +46,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
@is_flaky(max_attempts=10, description="very flaky class")
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
@@ -71,8 +73,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"base_dim": 3,
|
||||
"z_dim": 4,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
|
||||
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
|
||||
"latents_mean": torch.randn(4).numpy().tolist(),
|
||||
"latents_std": torch.randn(4).numpy().tolist(),
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
242
tests/modular_pipelines/test_conditional_pipeline_blocks.py
Normal file
242
tests/modular_pipelines/test_conditional_pipeline_blocks.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
AutoPipelineBlocks,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
ModularPipelineBlocks,
|
||||
)
|
||||
|
||||
|
||||
class TextToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "text2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "text-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "text2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ImageToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "img2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "image-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "img2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class InpaintBlock(ModularPipelineBlocks):
|
||||
model_name = "inpaint"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "inpaint workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "inpaint"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ConditionalImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Conditional image blocks for testing"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name
|
||||
|
||||
|
||||
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = None # no default; block can be skipped
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Optional conditional blocks (skippable)"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None
|
||||
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto image blocks for testing"
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksSelectBlock:
|
||||
def test_select_block_with_mask(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="something") == "inpaint"
|
||||
|
||||
def test_select_block_with_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(image="something") == "img2img"
|
||||
|
||||
def test_select_block_with_mask_and_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
def test_select_block_no_triggers_returns_none(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_select_block_explicit_none_values(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask=None, image=None) is None
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksWorkflowSelection:
|
||||
def test_default_workflow_when_no_triggers(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is not None
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_mask_trigger_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_image_trigger_selects_img2img(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
def test_mask_and_image_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True, image=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_skippable_block_returns_none(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is None
|
||||
|
||||
def test_skippable_block_still_selects_when_triggered(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksSelectBlock:
|
||||
def test_auto_select_mask(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m") == "inpaint"
|
||||
|
||||
def test_auto_select_image(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(image="i") == "img2img"
|
||||
|
||||
def test_auto_select_default(self):
|
||||
blocks = AutoImageBlocks()
|
||||
# No trigger -> returns None -> falls back to default (text2img)
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_auto_select_priority_order(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksWorkflowSelection:
|
||||
def test_auto_default_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_auto_mask_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_auto_image_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksStructure:
|
||||
def test_block_names_accessible(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
|
||||
|
||||
def test_sub_block_types(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert isinstance(sub["inpaint"], InpaintBlock)
|
||||
assert isinstance(sub["img2img"], ImageToImageBlock)
|
||||
assert isinstance(sub["text2img"], TextToImageBlock)
|
||||
|
||||
def test_description(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert "Conditional" in blocks.description
|
||||
@@ -5,16 +5,10 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import (
|
||||
ConditionalPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -25,7 +19,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
@@ -33,33 +26,6 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -388,39 +354,6 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
@@ -498,117 +431,6 @@ class ModularGuiderTesterMixin:
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
|
||||
@@ -24,14 +24,18 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
OutputParam,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
def _create_tiny_model_dir(model_dir):
|
||||
@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch
|
||||
|
||||
@@ -13,10 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -184,25 +182,6 @@ class DeprecateTester(unittest.TestCase):
|
||||
assert str(warning.warning) == "This message is better!!!"
|
||||
assert "diffusers/tests/others/test_utils.py" in warning.filename
|
||||
|
||||
def test_deprecate_testing_utils_module(self):
|
||||
import diffusers.utils.testing_utils
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
importlib.reload(diffusers.utils.testing_utils)
|
||||
|
||||
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
|
||||
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
|
||||
|
||||
messages = [str(w.message) for w in deprecation_warnings]
|
||||
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
|
||||
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
|
||||
)
|
||||
assert any(
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
|
||||
for msg in messages
|
||||
), f"Expected deprecation message substring not found, got: {messages}"
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
|
||||
class ExpectationsTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user