Compare commits

..

3 Commits

Author SHA1 Message Date
DN6
52bd90a3b3 update 2026-03-11 13:40:14 +05:30
DN6
b28e9204f7 update 2026-03-11 13:32:16 +05:30
Dhruv Nair
897aed72fa [Quantization] Deprecate Quanto (#13180)
* update

* update
2026-03-11 09:26:46 +05:30
7 changed files with 298 additions and 575 deletions

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
from packaging import version
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
modules_to_not_convert: list[str] | None = None,
**kwargs,
):
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
deprecate("QuantoConfig", "1.0.0", deprecation_message)
self.quant_method = QuantizationMethod.QUANTO
self.weights_dtype = weights_dtype
self.modules_to_not_convert = modules_to_not_convert

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from diffusers.utils.import_utils import is_optimum_quanto_version
from ...utils import (
deprecate,
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,59 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTXVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
def output_shape(self) -> tuple[int, int]:
return (512, 4)
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple[int, int]:
return (512, 4)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -75,16 +62,57 @@ class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
return {
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
}
class TestLTXTransformer(LTXTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX Video Transformer."""
class TestLTXTransformerMemory(LTXTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX Video Transformer."""
class TestLTXTransformerTraining(LTXTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for LTX Video Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
super().test_gradient_checkpointing_is_applied(expected_set={"LTXVideoTransformer3DModel"})
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for LTX Video Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX Video Transformer."""

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,77 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import LTX2VideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTX2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTX2VideoTransformer3DModel
@property
def dummy_input(self):
# Common
batch_size = 2
def output_shape(self) -> tuple[int, int]:
return (512, 4)
# Video
num_frames = 2
num_channels = 4
height = 16
width = 16
@property
def input_shape(self) -> tuple[int, int]:
return (512, 4)
# Audio
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
@property
def main_input_name(self) -> str:
return "hidden_states"
# Text
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
torch_device
)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"hidden_states": hidden_states,
"audio_hidden_states": audio_hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"audio_encoder_hidden_states": audio_encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
@@ -101,122 +72,80 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
"caption_channels": 16,
"rope_double_precision": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_frames = 2
num_channels = 4
height = 16
width = 16
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
embedding_dim = 16
sequence_length = 16
return {
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"audio_hidden_states": randn_tensor(
(batch_size, audio_num_frames, audio_num_channels * num_mel_bins),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"audio_encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": (randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs() * 1000),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX2 Video Transformer."""
class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX2 Video Transformer."""
class TestLTX2TransformerTraining(LTX2TransformerTesterConfig, TrainingTesterMixin):
"""Training tests for LTX2 Video Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTX2VideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
# torch.manual_seed(seed)
# init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# # Calculate dummy inputs in a custom manner to ensure compatibility with original code
# batch_size = 2
# num_frames = 9
# latent_frames = 2
# text_embedding_dim = 16
# text_seq_len = 16
# fps = 25.0
# sampling_rate = 16000.0
# hop_length = 160.0
# sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
# timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
# num_channels = 4
# latent_height = 4
# latent_width = 4
# hidden_states = torch.randn(
# (batch_size, num_channels, latent_frames, latent_height, latent_width),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify video latents (with patch_size (1, 1, 1))
# hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
# hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
# encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# audio_num_channels = 2
# num_mel_bins = 2
# latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
# audio_hidden_states = torch.randn(
# (batch_size, audio_num_channels, latent_length, num_mel_bins),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify audio latents
# audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
# audio_encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# inputs_dict = {
# "hidden_states": hidden_states.to(device=torch_device),
# "audio_hidden_states": audio_hidden_states.to(device=torch_device),
# "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
# "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
# "timestep": timestep,
# "num_frames": latent_frames,
# "height": latent_height,
# "width": latent_width,
# "audio_num_frames": num_frames,
# "fps": 25.0,
# }
# model = self.model_class.from_pretrained(
# "diffusers-internal-dev/dummy-ltx2",
# subfolder="transformer",
# device_map="cpu",
# )
# # torch.manual_seed(seed)
# # model = self.model_class(**init_dict)
# model.to(torch_device)
# model.eval()
# with attention_backend("native"):
# with torch.no_grad():
# output = model(**inputs_dict)
# video_output, audio_output = output.to_tuple()
# self.assertIsNotNone(video_output)
# self.assertIsNotNone(audio_output)
# # input & output have to have the same shape
# video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
# self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
# audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
# self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
# # Check against expected slice
# # fmt: off
# video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
# audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
# # fmt: on
# video_output_flat = video_output.cpu().flatten().float()
# video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
# self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
# audio_output_flat = audio_output.cpu().flatten().float()
# audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
# self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
super().test_gradient_checkpointing_is_applied(expected_set={"LTX2VideoTransformer3DModel"})
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
class TestLTX2TransformerAttention(LTX2TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for LTX2 Video Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()
@pytest.mark.skip(
"LTX2Attention does not set is_cross_attention, so fuse_projections tries to fuse Q+K+V together even for cross-attention modules with different input dimensions."
)
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
pass
class TestLTX2TransformerCompile(LTX2TransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for LTX2 Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub
# class TestLTX2TransformerBitsAndBytes(LTX2TransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX2 Video Transformer."""
# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub
# class TestLTX2TransformerTorchAo(LTX2TransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX2 Video Transformer."""

View File

@@ -1,242 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ConditionalPipelineBlocks,
InputParam,
ModularPipelineBlocks,
)
class TextToImageBlock(ModularPipelineBlocks):
model_name = "text2img"
@property
def inputs(self):
return [InputParam(name="prompt")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "text-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "text2img"
self.set_block_state(state, block_state)
return components, state
class ImageToImageBlock(ModularPipelineBlocks):
model_name = "img2img"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "image-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "img2img"
self.set_block_state(state, block_state)
return components, state
class InpaintBlock(ModularPipelineBlocks):
model_name = "inpaint"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "inpaint workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "inpaint"
self.set_block_state(state, block_state)
return components, state
class ConditionalImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return "Conditional image blocks for testing"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = None # no default; block can be skipped
@property
def description(self):
return "Optional conditional blocks (skippable)"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None
class AutoImageBlocks(AutoPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image", None]
@property
def description(self):
return "Auto image blocks for testing"
class TestConditionalPipelineBlocksSelectBlock:
def test_select_block_with_mask(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="something") == "inpaint"
def test_select_block_with_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(image="something") == "img2img"
def test_select_block_with_mask_and_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
def test_select_block_no_triggers_returns_none(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block() is None
def test_select_block_explicit_none_values(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask=None, image=None) is None
class TestConditionalPipelineBlocksWorkflowSelection:
def test_default_workflow_when_no_triggers(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks()
assert execution is not None
assert isinstance(execution, TextToImageBlock)
def test_mask_trigger_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_image_trigger_selects_img2img(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
def test_mask_and_image_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True, image=True)
assert isinstance(execution, InpaintBlock)
def test_skippable_block_returns_none(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks()
assert execution is None
def test_skippable_block_still_selects_when_triggered(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestAutoPipelineBlocksSelectBlock:
def test_auto_select_mask(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m") == "inpaint"
def test_auto_select_image(self):
blocks = AutoImageBlocks()
assert blocks.select_block(image="i") == "img2img"
def test_auto_select_default(self):
blocks = AutoImageBlocks()
# No trigger -> returns None -> falls back to default (text2img)
assert blocks.select_block() is None
def test_auto_select_priority_order(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
class TestAutoPipelineBlocksWorkflowSelection:
def test_auto_default_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks()
assert isinstance(execution, TextToImageBlock)
def test_auto_mask_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_auto_image_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestConditionalPipelineBlocksStructure:
def test_block_names_accessible(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
def test_sub_block_types(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert isinstance(sub["inpaint"], InpaintBlock)
assert isinstance(sub["img2img"], ImageToImageBlock)
assert isinstance(sub["text2img"], TextToImageBlock)
def test_description(self):
blocks = ConditionalImageBlocks()
assert "Conditional" in blocks.description

View File

@@ -9,6 +9,11 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -19,6 +24,7 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -431,6 +437,117 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:

View File

@@ -24,18 +24,14 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
ConditionalPipelineBlocks,
InputParam,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
OutputParam,
PipelineState,
SequentialPipelineBlocks,
WanModularPipeline,
)
from diffusers.utils import logging
from ..testing_utils import CaptureLogger, nightly, require_torch, slow
from ..testing_utils import nightly, require_torch, slow
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -358,117 +354,6 @@ class TestModularCustomBlocks:
assert output_prompt.startswith("Modular diffusers + ")
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
@slow
@nightly
@require_torch