Compare commits

..

9 Commits

Author SHA1 Message Date
Sayak Paul
d79b88ae8d Merge branch 'main' into fix-modules-no-convert-torchao 2026-03-04 16:34:08 +05:30
jiqing-feng
88798242bc cogvideo example: Distribute VAE video encoding across processes in CogVideoX LoRA training (#13207)
* Distribute VAE video encoding across processes in CogVideoX LoRA training

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Apply style fixes

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-03-04 15:09:01 +05:30
Sayak Paul
4a2833c1c2 [Modular] implement requirements validation for custom blocks (#12196)
* feat: implement requirements validation for custom blocks.

* up

* unify.

* up

* add tests

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* reviewer feedback.

* [docs] validation for custom blocks (#13156)

validation

* move to tmp_path fixture.

* propagate to conditional and loopsequential blocks.

* up

* remove collected tests

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-04 12:19:08 +05:30
YiYi Xu
1fe688a651 [modular] not pass trust_remote_code to external repos (#13204)
* add

* update warn

* add a test

* updaqte

* update_component with custom model

* add more tests

* Apply suggestion from @DN6

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2026-03-03 02:36:36 -10:00
Sayak Paul
faed0087d3 Merge branch 'main' into fix-modules-no-convert-torchao 2026-02-13 19:53:58 +05:30
Sayak Paul
ed734a0e63 Merge branch 'main' into fix-modules-no-convert-torchao 2026-02-10 15:49:51 +05:30
sayakpaul
d676b03490 fix torchao/. 2026-02-10 15:32:41 +05:30
sayakpaul
e117274aa5 fix bnb modules_to_convert. 2026-02-10 13:49:05 +05:30
sayakpaul
a1804cfa80 make modules_to_not_convert actually run. 2026-02-05 09:47:15 +05:30
8 changed files with 390 additions and 32 deletions

View File

@@ -332,4 +332,49 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
</hfoption>
</hfoptions>
</hfoptions>
## Dependencies
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and returns a warning if a package is missing or incompatible.
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
```py
from diffusers.modular_pipelines import PipelineBlock
class MyCustomBlock(PipelineBlock):
_requirements = {
"transformers": ">=4.44.0",
"sentencepiece": ">=0.2.0"
}
```
When there are blocks with different requirements, Diffusers merges their requirements.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class BlockA(PipelineBlock):
_requirements = {"transformers": ">=4.44.0"}
# ...
class BlockB(PipelineBlock):
_requirements = {"sentencepiece": ">=0.2.0"}
# ...
pipe = SequentialPipelineBlocks.from_blocks_dict({
"block_a": BlockA,
"block_b": BlockB,
})
```
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
```md
# missing package
xyz-package was specified in the requirements but wasn't found in the current environment.
# version mismatch
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
```

View File

@@ -1232,22 +1232,49 @@ def main(args):
id_token=args.id_token,
)
def encode_video(video, bar):
bar.update(1)
def encode_video(video):
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist
# Distribute video encoding across processes: each process only encodes its own shard
num_videos = len(train_dataset.instance_videos)
num_procs = accelerator.num_processes
local_rank = accelerator.process_index
local_count = len(range(local_rank, num_videos, num_procs))
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
range(local_count),
desc="Encoding videos",
disable=not accelerator.is_local_main_process,
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
encoded_videos = [None] * num_videos
for i, video in enumerate(train_dataset.instance_videos):
if i % num_procs == local_rank:
encoded_videos[i] = encode_video(video)
progress_encode_bar.update(1)
progress_encode_bar.close()
# Broadcast encoded latent distributions so every process has the full set
if num_procs > 1:
import torch.distributed as dist
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
ref_params = next(v for v in encoded_videos if v is not None).parameters
for i in range(num_videos):
src = i % num_procs
if encoded_videos[i] is not None:
params = encoded_videos[i].parameters.contiguous()
else:
params = torch.empty_like(ref_params)
dist.broadcast(params, src=src)
encoded_videos[i] = DiagonalGaussianDistribution(params)
train_dataset.instance_videos = encoded_videos
def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
prompts = [example["instance_prompt"] for example in examples]

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -47,6 +47,7 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
combine_inputs,
combine_outputs,
format_components,
@@ -297,6 +298,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: dict[str, str] | None = None
_workflow_map = None
@classmethod
@@ -411,6 +413,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -435,8 +440,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -658,6 +668,15 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
# used for `__repr__`
def _get_trigger_inputs(self) -> set:
"""
@@ -1247,6 +1266,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""
@@ -1385,6 +1412,15 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def outputs(self) -> list[str]:
return next(reversed(self.sub_blocks.values())).intermediate_outputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
def __init__(self):
sub_blocks = InsertableDict()
for block_name, block in zip(self.block_names, self.block_classes):

View File

@@ -22,10 +22,12 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
import PIL.Image
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -1020,6 +1022,89 @@ def make_doc_string(
return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current

View File

@@ -21,11 +21,8 @@ import torch
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -59,13 +56,6 @@ if is_bitsandbytes_available():
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
class LoRALayer(torch.nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
@@ -132,14 +122,14 @@ class QuantizationTesterMixin:
def _verify_if_layer_quantized(self, name, module, config_kwargs):
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
def _is_module_quantized(self, module):
def _is_module_quantized(self, module, quant_config_kwargs=None):
"""
Check if a module is quantized. Returns True if quantized, False otherwise.
Default implementation tries _verify_if_layer_quantized and catches exceptions.
Subclasses can override for more efficient checking.
"""
try:
self._verify_if_layer_quantized("", module, {})
self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
return True
except (AssertionError, AttributeError):
return False
@@ -273,7 +263,9 @@ class QuantizationTesterMixin:
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
)
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
def _test_quantization_modules_to_not_convert(
self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
):
"""
Test that modules specified in modules_to_not_convert are not quantized.
@@ -283,7 +275,7 @@ class QuantizationTesterMixin:
"""
# Create config with modules_to_not_convert
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
@@ -295,7 +287,7 @@ class QuantizationTesterMixin:
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module), (
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
f"Module {name} should not be quantized but was found to be quantized"
)
@@ -307,7 +299,7 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_not_convert):
if self._is_module_quantized(module):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
found_quantized = True
break
@@ -612,7 +604,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
)
@pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
@@ -826,7 +818,14 @@ class TorchAoConfigMixin:
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
# Check if the weight is actually quantized
weight = module.weight
is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
@@ -922,9 +921,39 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
)
# Custom implementation for torchao that skips memory footprint check
# because get_memory_footprint() doesn't accurately reflect torchao quantization
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
# Find a module that should NOT be quantized
found_excluded = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is in the exclusion list
if any(excluded in name for excluded in modules_to_exclude):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
f"Module {name} should not be quantized but was found to be quantized"
)
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
# Find a module that SHOULD be quantized (not in exclusion list)
found_quantized = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_exclude):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
found_quantized = True
break
assert found_quantized, "No quantized layers found outside of excluded modules"
def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""

View File

@@ -318,6 +318,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""
@@ -330,10 +334,18 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def pretrained_model_kwargs(self):
return {}
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
@@ -402,6 +414,10 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTes
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
"""ModelOpt quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
"""ModelOpt + compile tests for Flux Transformer."""

View File

@@ -10,6 +10,11 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -19,7 +24,13 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
)
from diffusers.utils import logging
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
torch_device,
)
class ModularPipelineTesterMixin:
@@ -429,6 +440,117 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(tmp_path)
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(tmp_path)
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(tmp_path)
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(tmp_path)
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock: