Compare commits

..

1 Commits

Author SHA1 Message Date
DN6
19df302d13 update 2026-03-20 11:11:42 +05:30
7 changed files with 14 additions and 179 deletions

View File

@@ -1,34 +0,0 @@
# PR Review Rules
Rules for Claude to check during PR reviews. Focus on correctness — style is handled by ruff.
## Code style
- Inline logic — minimize small helper/utility functions. A reader should follow the full flow without jumping between functions.
- No defensive code or unused code paths — no fallback paths, safety checks, or config options "just in case".
- No silent fallbacks — raise a concise error for unsupported cases rather than guessing user intent.
## Dependencies
- No new mandatory dependencies without prior discussion.
- Optional deps must be guarded with `is_X_available()` and have a dummy in `utils/dummy_*.py`.
- Never use `einops` — rewrite with native PyTorch (`reshape`, `permute`, `unflatten`).
## Models
- All layer calls must be visible directly in `forward()` — no helper functions hiding `nn.Module` calls.
- No NumPy operations in `forward()` — breaks `torch.compile` with `fullgraph=True`.
- No hardcoded dtypes (e.g. `torch.float32`, `torch.bfloat16`) in forward — use input tensor dtype or `self.dtype`.
- Attention must use `dispatch_attention_fn`, not `F.scaled_dot_product_attention` directly.
- Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`.
- New classes must be registered in `__init__.py` with lazy imports (both `_import_structure` and `_lazy_modules`).
## Pipelines
- Must inherit from `DiffusionPipeline`.
- `@torch.no_grad()` on pipeline `__call__` — forgetting this causes OOM from gradient accumulation.
- Do NOT subclass an existing pipeline for a variant (e.g. don't subclass `FluxPipeline` for `FluxImg2ImgPipeline`).
- Support `output_type="latent"` for skipping VAE decode.
- Support `generator` parameter for reproducibility.
## Copied code
- Never edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source.
- Remove the `# Copied from` header to intentionally break the sync link.
## Common mistakes (add new rules below this line)

View File

@@ -1,27 +0,0 @@
name: Claude PR Review
on:
issue_comment:
types: [created]
permissions:
contents: write
pull-requests: write
issues: read
jobs:
claude-review:
if: |
github.event.issue.pull_request &&
github.event.issue.state == 'open' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'MEMBER' ||
github.event.comment.author_association == 'OWNER' ||
github.event.comment.author_association == 'COLLABORATOR')
runs-on: ubuntu-latest
steps:
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/."

View File

@@ -143,7 +143,6 @@ Refer to the table below for a complete list of available attention backends and
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

View File

@@ -229,7 +229,6 @@ class AttentionBackendName(str, Enum):
FLASH_HUB = "flash_hub"
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -359,11 +358,6 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -527,7 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
AttentionBackendName.SAGE_HUB,
AttentionBackendName.FLASH_4_HUB,
]:
if not is_kernels_available():
raise RuntimeError(
@@ -538,11 +531,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
@@ -2688,37 +2676,6 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_4_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _flash_attention_4_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
)
if isinstance(out, tuple):
return (out[0], out[1]) if return_lse else out[0]
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -324,18 +324,17 @@ class AudioLDM2Pipeline(DiffusionPipeline):
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
if hasattr(self.language_model, "_get_initial_cache_position"):
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs

View File

@@ -28,6 +28,7 @@ from diffusers.utils.import_utils import is_peft_available
from ..testing_utils import (
floats_tensor,
is_flaky,
require_peft_backend,
require_peft_version_greater,
skip_mps,
@@ -45,6 +46,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
@@ -71,8 +73,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"base_dim": 3,
"z_dim": 4,
"dim_mult": [1, 1, 1, 1],
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
"latents_mean": torch.randn(4).numpy().tolist(),
"latents_std": torch.randn(4).numpy().tolist(),
"num_res_blocks": 1,
"temperal_downsample": [False, True, True],
}

View File

@@ -5,7 +5,6 @@ from typing import Callable
import pytest
import torch
from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
@@ -33,33 +32,6 @@ from ..testing_utils import (
)
def _get_specified_components(path_or_repo_id, cache_dir=None):
if os.path.isdir(path_or_repo_id):
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
else:
try:
config_path = hf_hub_download(
repo_id=path_or_repo_id,
filename="modular_model_index.json",
local_dir=cache_dir,
)
except Exception:
return None
with open(config_path) as f:
config = json.load(f)
components = set()
for k, v in config.items():
if isinstance(v, (str, int, float, bool)):
continue
for entry in v:
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
components.add(k)
break
return components
class ModularPipelineTesterMixin:
"""
It provides a set of common tests for each modular pipeline,
@@ -388,39 +360,6 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_load_expected_components_from_pretrained(self, tmp_path):
pipe = self.get_pipeline()
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
if not expected:
pytest.skip("Skipping test as we couldn't fetch the expected components.")
actual = {
name
for name in pipe.components
if getattr(pipe, name, None) is not None
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
def test_load_expected_components_from_save_pretrained(self, tmp_path):
pipe = self.get_pipeline()
save_dir = str(tmp_path / "saved-pipeline")
pipe.save_pretrained(save_dir)
expected = _get_specified_components(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
loaded_pipe.load_components(torch_dtype=torch.float32)
actual = {
name
for name in loaded_pipe.components
if getattr(loaded_pipe, name, None) is not None
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, (
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
)
def test_modular_index_consistency(self, tmp_path):
pipe = self.get_pipeline()
components_spec = pipe._component_specs