Compare commits

..

7 Commits

Author SHA1 Message Date
Sayak Paul
4ade16db58 Merge branch 'main' into tests-conditional-pipeline-blocks 2026-03-12 11:19:09 +05:30
Alvaro Bartolome
81c354d879 Add PRXPipeline in AUTO_TEXT2IMAGE_PIPELINES_MAPPING (#13257) 2026-03-11 14:39:24 -03:00
Miguel Martin
0a2c26d0a4 Update Documentation for NVIDIA Cosmos (#13251)
* fix docs

* update main example
2026-03-11 09:14:56 -07:00
Dhruv Nair
07c5ba8eee [Context Parallel] Add support for custom device mesh (#13064)
* add custom mesh support

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-11 16:42:11 +05:30
Dhruv Nair
897aed72fa [Quantization] Deprecate Quanto (#13180)
* update

* update
2026-03-11 09:26:46 +05:30
sayakpaul
58c304595d remove 2026-03-10 18:25:02 +05:30
sayakpaul
55c563281a implement test suite for conditional blocks. 2026-03-10 18:24:49 +05:30
14 changed files with 533 additions and 172 deletions

View File

@@ -532,8 +532,6 @@
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/ddim
title: DDIM
- local: api/pipelines/ddpm
@@ -677,6 +675,8 @@
title: CogVideoX
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/helios

View File

@@ -21,29 +21,31 @@
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Loading original format checkpoints
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
## Basic usage
```python
import torch
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
model_id = "nvidia/Cosmos-Predict2.5-2B"
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
output = pipe(
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
image=None,
video=None,
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=93,
generator=torch.Generator().manual_seed(1),
).frames[0]
export_to_video(output, "text2world.mp4", fps=16)
```
## Cosmos2_5_TransferPipeline

View File

@@ -44,6 +44,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
| [ControlNet-XS](controlnetxs) | text2image |
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
| [Cosmos](cosmos) | text2video, video2video |
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
| [DDIM](ddim) | unconditional image generation |
| [DDPM](ddpm) | unconditional image generation |

View File

@@ -60,6 +60,16 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -68,6 +78,7 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -124,7 +135,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh._flatten()
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -95,6 +95,7 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
@@ -185,6 +186,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
("z-image-omni", ZImageOmniPipeline),
("ovis", OvisImagePipeline),
("prx", PRXPipeline),
]
)

View File

@@ -82,13 +82,16 @@ EXAMPLE_DOC_STRING = """
```python
>>> import cv2
>>> import numpy as np
>>> from PIL import Image
>>> import torch
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
>>> controlnet = AutoModel.from_pretrained(
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
... )
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
... )

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
from packaging import version
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
modules_to_not_convert: list[str] | None = None,
**kwargs,
):
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
deprecate("QuantoConfig", "1.0.0", deprecation_message)
self.quant_method = QuantizationMethod.QUANTO
self.weights_dtype = weights_dtype
self.modules_to_not_convert = modules_to_not_convert

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from diffusers.utils.import_utils import is_optimum_quanto_version
from ...utils import (
deprecate,
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -13,8 +13,8 @@
# limitations under the License.
import gc
import unittest
import pytest
import torch
from diffusers.hooks import HookRegistry, ModelHook
@@ -134,18 +134,20 @@ class SkipLayerHook(ModelHook):
return output
class TestHooks:
class HookTests(unittest.TestCase):
in_features = 4
hidden_features = 8
out_features = 4
num_layers = 2
def setup_method(self):
def setUp(self):
params = self.get_module_parameters()
self.model = DummyModel(**params)
self.model.to(torch_device)
def teardown_method(self):
def tearDown(self):
super().tearDown()
del self.model
gc.collect()
free_memory()
@@ -169,20 +171,20 @@ class TestHooks:
registry_repr = repr(registry)
expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
assert len(registry.hooks) == 2
assert registry._hook_order == ["add_hook", "multiply_hook"]
assert registry_repr == expected_repr
self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
self.assertEqual(registry_repr, expected_repr)
registry.remove_hook("add_hook")
assert len(registry.hooks) == 1
assert registry._hook_order == ["multiply_hook"]
self.assertEqual(len(registry.hooks), 1)
self.assertEqual(registry._hook_order, ["multiply_hook"])
def test_stateful_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
assert registry.hooks["stateful_add_hook"].increment == 0
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
num_repeats = 3
@@ -192,13 +194,13 @@ class TestHooks:
if i == 0:
output1 = result
assert registry.get_hook("stateful_add_hook").increment == num_repeats
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
registry.reset_stateful_hooks()
output2 = self.model(input)
assert registry.get_hook("stateful_add_hook").increment == 1
assert torch.allclose(output1, output2)
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
self.assertTrue(torch.allclose(output1, output2))
def test_inference(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -216,9 +218,9 @@ class TestHooks:
new_input = input * 2 + 1
output3 = self.model(new_input).mean().detach().cpu().item()
assert output1 == pytest.approx(output2, abs=5e-6)
assert output1 == pytest.approx(output3, abs=5e-6)
assert output2 == pytest.approx(output3, abs=5e-6)
self.assertAlmostEqual(output1, output2, places=5)
self.assertAlmostEqual(output1, output3, places=5)
self.assertAlmostEqual(output2, output3, places=5)
def test_skip_layer_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -226,29 +228,30 @@ class TestHooks:
input = torch.zeros(1, 4, device=torch_device)
output = self.model(input).mean().detach().cpu().item()
assert output == 0.0
self.assertEqual(output, 0.0)
registry.remove_hook("skip_layer_hook")
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
assert output != 0.0
self.assertNotEqual(output, 0.0)
def test_skip_layer_internal_block(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
input = torch.zeros(1, 4, device=torch_device)
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
with self.assertRaises(RuntimeError) as cm:
self.model(input).mean().detach().cpu().item()
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
registry.remove_hook("skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
assert output != 0.0
self.assertNotEqual(output, 0.0)
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
assert output != 0.0
self.assertNotEqual(output, 0.0)
def test_invocation_order_stateful_first(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -275,7 +278,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
@@ -286,7 +289,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
def test_invocation_order_stateful_middle(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -313,7 +316,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
@@ -324,7 +327,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook_2")
with CaptureLogger(logger) as cap_logger:
@@ -333,7 +336,7 @@ class TestHooks:
expected_invocation_order_log = (
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
def test_invocation_order_stateful_last(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -360,7 +363,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
@@ -371,4 +374,4 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)

View File

@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -0,0 +1,242 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ConditionalPipelineBlocks,
InputParam,
ModularPipelineBlocks,
)
class TextToImageBlock(ModularPipelineBlocks):
model_name = "text2img"
@property
def inputs(self):
return [InputParam(name="prompt")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "text-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "text2img"
self.set_block_state(state, block_state)
return components, state
class ImageToImageBlock(ModularPipelineBlocks):
model_name = "img2img"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "image-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "img2img"
self.set_block_state(state, block_state)
return components, state
class InpaintBlock(ModularPipelineBlocks):
model_name = "inpaint"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "inpaint workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "inpaint"
self.set_block_state(state, block_state)
return components, state
class ConditionalImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return "Conditional image blocks for testing"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = None # no default; block can be skipped
@property
def description(self):
return "Optional conditional blocks (skippable)"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None
class AutoImageBlocks(AutoPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image", None]
@property
def description(self):
return "Auto image blocks for testing"
class TestConditionalPipelineBlocksSelectBlock:
def test_select_block_with_mask(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="something") == "inpaint"
def test_select_block_with_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(image="something") == "img2img"
def test_select_block_with_mask_and_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
def test_select_block_no_triggers_returns_none(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block() is None
def test_select_block_explicit_none_values(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask=None, image=None) is None
class TestConditionalPipelineBlocksWorkflowSelection:
def test_default_workflow_when_no_triggers(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks()
assert execution is not None
assert isinstance(execution, TextToImageBlock)
def test_mask_trigger_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_image_trigger_selects_img2img(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
def test_mask_and_image_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True, image=True)
assert isinstance(execution, InpaintBlock)
def test_skippable_block_returns_none(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks()
assert execution is None
def test_skippable_block_still_selects_when_triggered(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestAutoPipelineBlocksSelectBlock:
def test_auto_select_mask(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m") == "inpaint"
def test_auto_select_image(self):
blocks = AutoImageBlocks()
assert blocks.select_block(image="i") == "img2img"
def test_auto_select_default(self):
blocks = AutoImageBlocks()
# No trigger -> returns None -> falls back to default (text2img)
assert blocks.select_block() is None
def test_auto_select_priority_order(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
class TestAutoPipelineBlocksWorkflowSelection:
def test_auto_default_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks()
assert isinstance(execution, TextToImageBlock)
def test_auto_mask_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_auto_image_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestConditionalPipelineBlocksStructure:
def test_block_names_accessible(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
def test_sub_block_types(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert isinstance(sub["inpaint"], InpaintBlock)
assert isinstance(sub["img2img"], ImageToImageBlock)
assert isinstance(sub["text2img"], TextToImageBlock)
def test_description(self):
blocks = ConditionalImageBlocks()
assert "Conditional" in blocks.description

View File

@@ -9,11 +9,6 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -24,7 +19,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -437,117 +431,6 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:

View File

@@ -24,14 +24,18 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
ConditionalPipelineBlocks,
InputParam,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
OutputParam,
PipelineState,
SequentialPipelineBlocks,
WanModularPipeline,
)
from diffusers.utils import logging
from ..testing_utils import nightly, require_torch, slow
from ..testing_utils import CaptureLogger, nightly, require_torch, slow
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -354,6 +358,117 @@ class TestModularCustomBlocks:
assert output_prompt.startswith("Modular diffusers + ")
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
@slow
@nightly
@require_torch