mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-26 18:37:41 +08:00
Compare commits
10 Commits
bria-test-
...
fix-torcha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70067734a2 | ||
|
|
85ffcf1db2 | ||
|
|
cbf4d9a3c3 | ||
|
|
426daabad9 | ||
|
|
6125a4f540 | ||
|
|
d2666a9d0a | ||
|
|
9b9e2e17a6 | ||
|
|
1a959dc26f | ||
|
|
8797398d3b | ||
|
|
019a9deafb |
11
.ai/review-rules.md
Normal file
11
.ai/review-rules.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# PR Review Rules
|
||||
|
||||
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
|
||||
|
||||
Before reviewing, read and apply the guidelines in:
|
||||
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
|
||||
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
|
||||
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
|
||||
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
|
||||
|
||||
## Common mistakes (add new rules below this line)
|
||||
38
.github/workflows/claude_review.yml
vendored
Normal file
38
.github/workflows/claude_review.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: Claude PR Review
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
types: [created]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
if: |
|
||||
(
|
||||
github.event_name == 'issue_comment' &&
|
||||
github.event.issue.pull_request &&
|
||||
github.event.issue.state == 'open' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR')
|
||||
) || (
|
||||
github.event_name == 'pull_request_review_comment' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'MEMBER' ||
|
||||
github.event.comment.author_association == 'OWNER' ||
|
||||
github.event.comment.author_association == 'COLLABORATOR')
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
claude_args: |
|
||||
--append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."
|
||||
@@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be
|
||||
|
||||
The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
|
||||
|
||||
## Kernels
|
||||
|
||||
[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
|
||||
|
||||
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.
|
||||
|
||||
> [!TIP]
|
||||
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.
|
||||
|
||||
For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
## Dynamic quantization
|
||||
|
||||
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Set
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from ..utils import get_logger, is_accelerate_available, is_torchao_available
|
||||
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
@@ -35,6 +35,54 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
|
||||
if not is_torchao_available():
|
||||
return False
|
||||
from torchao.utils import TorchAOBaseTensor
|
||||
|
||||
return isinstance(tensor, TorchAOBaseTensor)
|
||||
|
||||
|
||||
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
|
||||
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
|
||||
cls = type(tensor)
|
||||
names = list(getattr(cls, "tensor_data_names", []))
|
||||
for attr_name in getattr(cls, "optional_tensor_data_names", []):
|
||||
if getattr(tensor, attr_name, None) is not None:
|
||||
names.append(attr_name)
|
||||
return names
|
||||
|
||||
|
||||
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
|
||||
|
||||
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
|
||||
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
|
||||
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
|
||||
that any dict keyed by `id(param)` remains valid.
|
||||
|
||||
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
|
||||
"""
|
||||
torch.utils.swap_tensors(param, source)
|
||||
|
||||
|
||||
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
|
||||
|
||||
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
|
||||
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
|
||||
`cpu_param_dict`).
|
||||
"""
|
||||
for attr_name in _get_torchao_inner_tensor_names(source):
|
||||
setattr(param, attr_name, getattr(source, attr_name))
|
||||
|
||||
|
||||
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
|
||||
"""Record stream for all internal tensors of a TorchAO parameter."""
|
||||
for attr_name in _get_torchao_inner_tensor_names(param):
|
||||
getattr(param, attr_name).record_stream(stream)
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
@@ -157,9 +205,16 @@ class ModuleGroup:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_swap_torchao_tensor(tensor, moved)
|
||||
else:
|
||||
tensor.data = moved
|
||||
if self.record_stream:
|
||||
tensor.data.record_stream(default_stream)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_record_stream_torchao_tensor(tensor, default_stream)
|
||||
else:
|
||||
tensor.data.record_stream(default_stream)
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
|
||||
for group_module in self.modules:
|
||||
@@ -245,18 +300,35 @@ class ModuleGroup:
|
||||
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
if _is_torchao_tensor(buffer):
|
||||
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
|
||||
else:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
if _is_torchao_tensor(param):
|
||||
moved = param.data.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(param, moved)
|
||||
else:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
if _is_torchao_tensor(buffer):
|
||||
moved = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(buffer, moved)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
|
||||
@@ -13,31 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BriaTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
def create_bria_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
@@ -58,8 +50,11 @@ def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=model.config["pooled_projection_dim"],
|
||||
@@ -78,36 +73,53 @@ def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
)
|
||||
|
||||
del sd
|
||||
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class BriaTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return BriaTransformer2DModel
|
||||
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.7, 0.7]
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -119,35 +131,11 @@ class BriaTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [0, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -155,6 +143,7 @@ class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
|
||||
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
|
||||
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
|
||||
|
||||
@@ -167,59 +156,26 @@ class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(output_1, output_2, atol=1e-5),
|
||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestBriaTransformerTraining(BriaTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"BriaTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestBriaTransformerCompile(BriaTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class TestBriaTransformerIPAdapter(BriaTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return FluxIPAdapterJointAttnProcessor2_0
|
||||
class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
torch.manual_seed(0)
|
||||
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||
return inputs_dict
|
||||
|
||||
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||
return create_bria_ip_adapter_state_dict(model)
|
||||
|
||||
|
||||
class TestBriaTransformerLoRA(BriaTransformerTesterConfig, LoraTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestBriaTransformerLoRAHotSwap(BriaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -13,50 +13,62 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BriaFiboTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return BriaFiboTransformer2DModel
|
||||
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaFiboTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_latent_channels = 48
|
||||
num_image_channels = 3
|
||||
height = width = 16
|
||||
sequence_length = 32
|
||||
embedding_dim = 64
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
|
||||
}
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.7, 0.7]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (256, 48)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
def output_shape(self):
|
||||
return (256, 48)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 48,
|
||||
"num_layers": 1,
|
||||
@@ -69,41 +81,9 @@ class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [0, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 48
|
||||
num_image_channels = 3
|
||||
height = width = 16
|
||||
sequence_length = 32
|
||||
embedding_dim = 64
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
|
||||
}
|
||||
|
||||
|
||||
class TestBriaFiboTransformer(BriaFiboTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestBriaFiboTransformerTraining(BriaFiboTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"BriaFiboTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestBriaFiboTransformerCompile(BriaFiboTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
242
tests/modular_pipelines/test_conditional_pipeline_blocks.py
Normal file
242
tests/modular_pipelines/test_conditional_pipeline_blocks.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
AutoPipelineBlocks,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
ModularPipelineBlocks,
|
||||
)
|
||||
|
||||
|
||||
class TextToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "text2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "text-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "text2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ImageToImageBlock(ModularPipelineBlocks):
|
||||
model_name = "img2img"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "image-to-image workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "img2img"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class InpaintBlock(ModularPipelineBlocks):
|
||||
model_name = "inpaint"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "inpaint workflow"
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.workflow = "inpaint"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class ConditionalImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Conditional image blocks for testing"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name
|
||||
|
||||
|
||||
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = None # no default; block can be skipped
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Optional conditional blocks (skippable)"
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None
|
||||
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto image blocks for testing"
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksSelectBlock:
|
||||
def test_select_block_with_mask(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="something") == "inpaint"
|
||||
|
||||
def test_select_block_with_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(image="something") == "img2img"
|
||||
|
||||
def test_select_block_with_mask_and_image(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
def test_select_block_no_triggers_returns_none(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_select_block_explicit_none_values(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert blocks.select_block(mask=None, image=None) is None
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksWorkflowSelection:
|
||||
def test_default_workflow_when_no_triggers(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is not None
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_mask_trigger_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_image_trigger_selects_img2img(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
def test_mask_and_image_selects_inpaint(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True, image=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_skippable_block_returns_none(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert execution is None
|
||||
|
||||
def test_skippable_block_still_selects_when_triggered(self):
|
||||
blocks = OptionalConditionalBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksSelectBlock:
|
||||
def test_auto_select_mask(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m") == "inpaint"
|
||||
|
||||
def test_auto_select_image(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(image="i") == "img2img"
|
||||
|
||||
def test_auto_select_default(self):
|
||||
blocks = AutoImageBlocks()
|
||||
# No trigger -> returns None -> falls back to default (text2img)
|
||||
assert blocks.select_block() is None
|
||||
|
||||
def test_auto_select_priority_order(self):
|
||||
blocks = AutoImageBlocks()
|
||||
assert blocks.select_block(mask="m", image="i") == "inpaint"
|
||||
|
||||
|
||||
class TestAutoPipelineBlocksWorkflowSelection:
|
||||
def test_auto_default_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks()
|
||||
assert isinstance(execution, TextToImageBlock)
|
||||
|
||||
def test_auto_mask_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(mask=True)
|
||||
assert isinstance(execution, InpaintBlock)
|
||||
|
||||
def test_auto_image_workflow(self):
|
||||
blocks = AutoImageBlocks()
|
||||
execution = blocks.get_execution_blocks(image=True)
|
||||
assert isinstance(execution, ImageToImageBlock)
|
||||
|
||||
|
||||
class TestConditionalPipelineBlocksStructure:
|
||||
def test_block_names_accessible(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
|
||||
|
||||
def test_sub_block_types(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
sub = dict(blocks.sub_blocks)
|
||||
assert isinstance(sub["inpaint"], InpaintBlock)
|
||||
assert isinstance(sub["img2img"], ImageToImageBlock)
|
||||
assert isinstance(sub["text2img"], TextToImageBlock)
|
||||
|
||||
def test_description(self):
|
||||
blocks = ConditionalImageBlocks()
|
||||
assert "Conditional" in blocks.description
|
||||
@@ -10,11 +10,6 @@ from huggingface_hub import hf_hub_download
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
from diffusers.modular_pipelines import (
|
||||
ConditionalPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks,
|
||||
SequentialPipelineBlocks,
|
||||
)
|
||||
from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -25,7 +20,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import (
|
||||
CaptureLogger,
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerator,
|
||||
@@ -498,117 +492,6 @@ class ModularGuiderTesterMixin:
|
||||
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
class TestModularModelCardContent:
|
||||
def create_mock_block(self, name="TestBlock", description="Test block description"):
|
||||
class MockBlock:
|
||||
|
||||
@@ -24,14 +24,18 @@ import torch
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
ConditionalPipelineBlocks,
|
||||
InputParam,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
OutputParam,
|
||||
PipelineState,
|
||||
SequentialPipelineBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
def _create_tiny_model_dir(model_dir):
|
||||
@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
|
||||
assert output_prompt.startswith("Modular diffusers + ")
|
||||
|
||||
|
||||
class TestCustomBlockRequirements:
|
||||
def get_dummy_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
# keep two arbitrary deps so that we can test warnings.
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
# keep two dependencies that will be available during testing.
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict(
|
||||
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
|
||||
)
|
||||
return pipe
|
||||
|
||||
def get_dummy_conditional_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
class DummyConditionalBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [DummyBlockOne, DummyBlockTwo]
|
||||
block_names = ["block_one", "block_two"]
|
||||
block_trigger_inputs = []
|
||||
|
||||
def select_block(self, **kwargs):
|
||||
return "block_one"
|
||||
|
||||
return DummyConditionalBlocks()
|
||||
|
||||
def get_dummy_loop_block_pipe(self):
|
||||
class DummyBlockOne:
|
||||
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
|
||||
|
||||
class DummyBlockTwo:
|
||||
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
|
||||
|
||||
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
|
||||
|
||||
def test_sequential_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
requirements = config["requirements"]
|
||||
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == requirements
|
||||
|
||||
def test_sequential_block_requirements_warnings(self, tmp_path):
|
||||
pipe = self.get_dummy_block_pipe()
|
||||
|
||||
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
|
||||
logger.setLevel(30)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
template = "{req} was specified in the requirements but wasn't found in the current environment"
|
||||
msg_xyz = template.format(req="xyz")
|
||||
msg_abc = template.format(req="abc")
|
||||
assert msg_xyz in str(cap_logger.out)
|
||||
assert msg_abc in str(cap_logger.out)
|
||||
|
||||
def test_conditional_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_conditional_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
def test_loop_block_requirements_save_load(self, tmp_path):
|
||||
pipe = self.get_dummy_loop_block_pipe()
|
||||
pipe.save_pretrained(str(tmp_path))
|
||||
|
||||
config_path = tmp_path / "modular_config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
assert "requirements" in config
|
||||
expected_requirements = {
|
||||
"xyz": ">=0.8.0",
|
||||
"abc": ">=10.0.0",
|
||||
"transformers": ">=4.44.0",
|
||||
"diffusers": ">=0.2.0",
|
||||
}
|
||||
assert expected_requirements == config["requirements"]
|
||||
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user