Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul
d8f6063c27 Merge branch 'main' into cog-tests 2026-03-30 09:01:58 +05:30
DN6
f7405f2b44 update 2026-03-26 16:41:25 +05:30
14 changed files with 328 additions and 423 deletions

View File

@@ -10,34 +10,24 @@ Strive to write code as simple and explicit as possible.
---
## Code formatting
### Dependencies
- No new mandatory dependency without discussion (e.g. `einops`)
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
## Code formatting
- `make style` and `make fix-copies` should be run as the final step before opening a PR
### Copied Code
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
- Remove the header to intentionally break the link
### Models
- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
### Pipelines & Schedulers
- Pipelines inherit from `DiffusionPipeline`
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
## Skills
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).

View File

@@ -1,76 +0,0 @@
# Model conventions and rules
Shared reference for model-related conventions, patterns, and gotchas.
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
## Coding style
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- No new mandatory dependency without discussion (e.g. `einops`). Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`.
## Common model conventions
- Models use `ModelMixin` with `register_to_config` for config serialization
## Attention pattern
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
```python
# transformer_mymodel.py
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.

View File

@@ -3,8 +3,8 @@
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
Before reviewing, read and apply the guidelines in:
- [AGENTS.md](AGENTS.md) — coding style, copied code
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)

View File

@@ -65,19 +65,89 @@ docs/source/en/api/
- [ ] Run `make style` and `make quality`
- [ ] Test parity with reference implementation (see `parity-testing` skill)
### Model conventions, attention pattern, and implementation rules
### Attention pattern
See [../../models.md](../../models.md) for the attention pattern, implementation rules, common conventions, dependencies, and gotchas. These apply to all model work.
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
### Model integration specific rules
```python
# transformer_mymodel.py
**Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
### Implementation rules
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
### Test setup
- Slow tests gated with `@slow` and `RUN_SLOW=1`
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
### Common diffusers conventions
- Pipelines inherit from `DiffusionPipeline`
- Models use `ModelMixin` with `register_to_config` for config serialization
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
---
## Modular Pipeline Conversion

View File

@@ -6,7 +6,6 @@ on:
- main
paths:
- "src/diffusers/**.py"
- "tests/**.py"
push:
branches:
- main

View File

@@ -6,7 +6,6 @@ on:
- main
paths:
- "src/diffusers/**.py"
- "tests/**.py"
push:
branches:
- main
@@ -27,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
pip install -e .
pip install torch pytest
pip install torch torchvision torchaudio pytest
- name: Check for soft dependencies
run: |
pytest tests/others/test_dependencies.py

View File

@@ -862,23 +862,23 @@ def _native_attention_backward_op(
key.requires_grad_(True)
value.requires_grad_(True)
with torch.enable_grad():
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
)
grad_out_t = grad_out.permute(0, 2, 1, 3)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
)
grad_query = grad_query_t.permute(0, 2, 1, 3)
grad_key = grad_key_t.permute(0, 2, 1, 3)

View File

@@ -5,13 +5,10 @@ import cv2
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
from ...utils import get_logger, is_torchvision_available, load_image
if is_torchvision_available():
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
from ...utils import get_logger, load_image
logger = get_logger(__name__)

View File

@@ -44,9 +44,9 @@ class AutoencoderTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):

View File

@@ -98,64 +98,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _context_parallel_backward_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
):
"""Worker function for context parallel backward pass testing."""
try:
# Set up distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Get device configuration
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]
# Initialize process group
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# Set device for this process
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")
# Create model in training mode
model = model_class(**init_dict)
model.to(device)
model.train()
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
# Run forward and backward pass
output = model(**inputs_on_device, return_dict=False)[0]
loss = output.sum()
loss.backward()
# Check that backward actually produced at least one valid gradient
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
# Only rank 0 reports results
if rank == 0:
return_dict["status"] = "success"
return_dict["has_valid_grads"] = bool(has_valid_grads)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
@@ -262,51 +204,6 @@ class ContextParallelTesterMixin:
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
# Find a free port for distributed communication
master_port = _find_free_port()
# Use multiprocessing manager for cross-process communication
manager = mp.Manager()
return_dict = manager.dict()
# Spawn worker processes
mp.spawn(
_context_parallel_backward_worker,
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
)
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_backward_batch_inputs(self, cp_type):
self.test_context_parallel_backward(cp_type, batch_size=2)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[

View File

@@ -13,59 +13,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.8]
# ======================== CogVideoX ========================
class CogVideoXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogVideoXTransformer3DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.8]
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
@@ -81,50 +75,66 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 2
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogVideoXTransformer(CogVideoXTransformerTesterConfig, ModelTesterMixin):
pass
class TestCogVideoXTransformerTraining(CogVideoXTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogVideoXTransformerCompile(CogVideoXTransformerTesterConfig, TorchCompileTesterMixin):
pass
# ======================== CogVideoX 1.5 ========================
class CogVideoX15TransformerTesterConfig(BaseModelTesterConfig):
@property
def input_shape(self):
def model_class(self):
return CogVideoXTransformer3DModel
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def output_shape(self):
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
@@ -141,9 +151,29 @@ class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
"max_text_seq_length": 8,
"use_rotary_positional_embeddings": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogVideoX15Transformer(CogVideoX15TransformerTesterConfig, ModelTesterMixin):
pass
class TestCogVideoX15TransformerCompile(CogVideoX15TransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -13,63 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.6, 0.6]
class CogView3PlusTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView3PlusTransformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
@@ -82,9 +69,37 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
"pos_embed_max_size": 8,
"sample_size": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogView3PlusTransformer(CogView3PlusTransformerTesterConfig, ModelTesterMixin):
pass
class TestCogView3PlusTransformerTraining(CogView3PlusTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView3PlusTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogView3PlusTransformerCompile(CogView3PlusTransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -12,59 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView4Transformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView4Transformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class CogView4TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView4Transformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def output_shape(self) -> tuple:
return (4, 8, 8)
@property
def input_shape(self) -> tuple:
return (4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
@@ -75,9 +62,37 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
"time_embed_dim": 8,
"condition_dim": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
}
class TestCogView4Transformer(CogView4TransformerTesterConfig, ModelTesterMixin):
pass
class TestCogView4TransformerTraining(CogView4TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView4Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogView4TransformerCompile(CogView4TransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -13,14 +13,16 @@
# limitations under the License.
import inspect
import unittest
from importlib import import_module
import pytest
class TestDependencies:
class DependencyTester(unittest.TestCase):
def test_diffusers_import(self):
import diffusers # noqa: F401
try:
import diffusers # noqa: F401
except ImportError:
assert False
def test_backend_registration(self):
import diffusers
@@ -50,36 +52,3 @@ class TestDependencies:
if hasattr(diffusers.pipelines, cls_name):
pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
_ = import_module(pipeline_folder_module, str(cls_name))
def test_pipeline_module_imports(self):
"""Import every pipeline submodule whose dependencies are satisfied,
to catch unguarded optional-dep imports (e.g., torchvision).
Uses inspect.getmembers to discover classes that the lazy loader can
actually resolve (same self-filtering as test_pipeline_imports), then
imports the full module path instead of truncating to the folder level.
"""
import diffusers
import diffusers.pipelines
failures = []
all_classes = inspect.getmembers(diffusers, inspect.isclass)
for cls_name, cls_module in all_classes:
if not hasattr(diffusers.pipelines, cls_name):
continue
if "dummy_" in cls_module.__module__:
continue
full_module_path = cls_module.__module__
try:
import_module(full_module_path)
except ImportError as e:
failures.append(f"{full_module_path}: {e}")
except Exception:
# Non-import errors (e.g., missing config) are fine; we only
# care about unguarded import statements.
pass
if failures:
pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))