mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-31 04:51:31 +08:00
Compare commits
9 Commits
overhaul-r
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
771174ac68 | ||
|
|
d37402866c | ||
|
|
e1e7d58a4a | ||
|
|
d10190b1b7 | ||
|
|
a1c3e6ccbb | ||
|
|
a93f7f137a | ||
|
|
10ec3040a2 | ||
|
|
f2be8bd6b3 | ||
|
|
7da22b9db5 |
@@ -10,24 +10,34 @@ Strive to write code as simple and explicit as possible.
|
||||
|
||||
---
|
||||
|
||||
### Dependencies
|
||||
- No new mandatory dependency without discussion (e.g. `einops`)
|
||||
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
|
||||
|
||||
## Code formatting
|
||||
|
||||
- `make style` and `make fix-copies` should be run as the final step before opening a PR
|
||||
|
||||
### Copied Code
|
||||
|
||||
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
|
||||
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
|
||||
- Remove the header to intentionally break the link
|
||||
|
||||
### Models
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
|
||||
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
|
||||
|
||||
- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
|
||||
- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
|
||||
|
||||
### Pipelines & Schedulers
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
|
||||
|
||||
## Skills
|
||||
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
|
||||
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
|
||||
|
||||
- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
|
||||
- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).
|
||||
|
||||
76
.ai/models.md
Normal file
76
.ai/models.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Model conventions and rules
|
||||
|
||||
Shared reference for model-related conventions, patterns, and gotchas.
|
||||
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
|
||||
|
||||
## Coding style
|
||||
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
|
||||
- No new mandatory dependency without discussion (e.g. `einops`). Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`.
|
||||
|
||||
## Common model conventions
|
||||
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
|
||||
## Attention pattern
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
@@ -3,8 +3,8 @@
|
||||
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
|
||||
|
||||
Before reviewing, read and apply the guidelines in:
|
||||
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
|
||||
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
|
||||
- [AGENTS.md](AGENTS.md) — coding style, copied code
|
||||
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
|
||||
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
|
||||
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
|
||||
|
||||
|
||||
@@ -65,89 +65,19 @@ docs/source/en/api/
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test parity with reference implementation (see `parity-testing` skill)
|
||||
|
||||
### Attention pattern
|
||||
### Model conventions, attention pattern, and implementation rules
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
See [../../models.md](../../models.md) for the attention pattern, implementation rules, common conventions, dependencies, and gotchas. These apply to all model work.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
### Model integration specific rules
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Implementation rules
|
||||
|
||||
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
|
||||
**Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
|
||||
### Test setup
|
||||
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
|
||||
### Common diffusers conventions
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
|
||||
---
|
||||
|
||||
## Modular Pipeline Conversion
|
||||
|
||||
3
.github/workflows/claude_review.yml
vendored
3
.github/workflows/claude_review.yml
vendored
@@ -32,6 +32,9 @@ jobs:
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
with torch.enable_grad():
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
|
||||
@@ -470,8 +470,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
|
||||
if is_torchao_version("<", "0.15.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
|
||||
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
@@ -495,8 +495,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
if not is_torchao_version(">", "0.9.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
|
||||
if not is_torchao_version(">=", "0.15.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ if (
|
||||
is_torch_available()
|
||||
and is_torch_version(">=", "2.6.0")
|
||||
and is_torchao_available()
|
||||
and is_torchao_version(">=", "0.7.0")
|
||||
and is_torchao_version(">=", "0.15.0")
|
||||
):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
@@ -168,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
torchao_version = version.parse(importlib.metadata.version("torchao"))
|
||||
if torchao_version < version.parse("0.15.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
@@ -35,22 +34,26 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKL
|
||||
|
||||
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
init_dict = {
|
||||
return {
|
||||
"block_out_channels": block_out_channels,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -59,42 +62,32 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
image = torch.randn(batch_size, num_channels, *sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
# Bridge for AutoencoderTesterMixin which still uses the old interface
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
return self.get_init_dict(), self.get_dummy_inputs()
|
||||
|
||||
|
||||
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -168,17 +161,24 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
|
||||
|
||||
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKL."""
|
||||
|
||||
|
||||
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, AutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKL."""
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
class AutoencoderKLIntegrationTests:
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def tearDown(self):
|
||||
def teardown_method(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -341,10 +341,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(13,), (16,), (27,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
@@ -362,10 +359,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
@@ -44,9 +44,9 @@ class AutoencoderTesterMixin:
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
|
||||
@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _context_parallel_backward_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
|
||||
):
|
||||
"""Worker function for context parallel backward pass testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
# Create model in training mode
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward and backward pass
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Check that backward actually produced at least one valid gradient
|
||||
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
|
||||
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["has_valid_grads"] = bool(has_valid_grads)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
@@ -204,6 +262,51 @@ class ContextParallelTesterMixin:
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_backward_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_backward(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
|
||||
@@ -14,13 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -82,18 +80,17 @@ if is_torchao_available():
|
||||
Float8WeightOnlyConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8DynamicActivationIntxWeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
IntxWeightOnlyConfig,
|
||||
)
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
|
||||
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@@ -128,7 +125,7 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@@ -527,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
def test_aobase_config(self):
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
@@ -540,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@@ -650,7 +647,7 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -696,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -854,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user