mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-01 05:16:39 +08:00
Compare commits
9 Commits
unet-model
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39d7b1aa41 | ||
|
|
e231b433a3 | ||
|
|
7e463ea4cc | ||
|
|
7f2b34bced | ||
|
|
e1e7d58a4a | ||
|
|
a93f7f137a | ||
|
|
10ec3040a2 | ||
|
|
f2be8bd6b3 | ||
|
|
7da22b9db5 |
@@ -10,24 +10,34 @@ Strive to write code as simple and explicit as possible.
|
||||
|
||||
---
|
||||
|
||||
### Dependencies
|
||||
- No new mandatory dependency without discussion (e.g. `einops`)
|
||||
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
|
||||
|
||||
## Code formatting
|
||||
|
||||
- `make style` and `make fix-copies` should be run as the final step before opening a PR
|
||||
|
||||
### Copied Code
|
||||
|
||||
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
|
||||
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
|
||||
- Remove the header to intentionally break the link
|
||||
|
||||
### Models
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
|
||||
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
|
||||
|
||||
- See [models.md](models.md) for model conventions, attention pattern, implementation rules, dependencies, and gotchas.
|
||||
- See the [model-integration](./skills/model-integration/SKILL.md) skill for the full integration workflow, file structure, test setup, and other details.
|
||||
|
||||
### Pipelines & Schedulers
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
- Don't subclass an existing pipeline for a variant — DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`)
|
||||
|
||||
## Skills
|
||||
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
|
||||
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents. Available skills include:
|
||||
|
||||
- [model-integration](./skills/model-integration/SKILL.md) (adding/converting pipelines)
|
||||
- [parity-testing](./skills/parity-testing/SKILL.md) (debugging numerical parity).
|
||||
|
||||
76
.ai/models.md
Normal file
76
.ai/models.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Model conventions and rules
|
||||
|
||||
Shared reference for model-related conventions, patterns, and gotchas.
|
||||
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.
|
||||
|
||||
## Coding style
|
||||
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
|
||||
- No new mandatory dependency without discussion (e.g. `einops`). Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`.
|
||||
|
||||
## Common model conventions
|
||||
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
|
||||
## Attention pattern
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
@@ -3,8 +3,8 @@
|
||||
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
|
||||
|
||||
Before reviewing, read and apply the guidelines in:
|
||||
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
|
||||
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
|
||||
- [AGENTS.md](AGENTS.md) — coding style, copied code
|
||||
- [models.md](models.md) — model conventions, attention pattern, implementation rules, dependencies, gotchas
|
||||
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
|
||||
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
|
||||
|
||||
|
||||
@@ -65,89 +65,19 @@ docs/source/en/api/
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test parity with reference implementation (see `parity-testing` skill)
|
||||
|
||||
### Attention pattern
|
||||
### Model conventions, attention pattern, and implementation rules
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
See [../../models.md](../../models.md) for the attention pattern, implementation rules, common conventions, dependencies, and gotchas. These apply to all model work.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
### Model integration specific rules
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Implementation rules
|
||||
|
||||
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
|
||||
**Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
|
||||
### Test setup
|
||||
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
|
||||
### Common diffusers conventions
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
|
||||
---
|
||||
|
||||
## Modular Pipeline Conversion
|
||||
|
||||
3
.github/workflows/claude_review.yml
vendored
3
.github/workflows/claude_review.yml
vendored
@@ -32,6 +32,9 @@ jobs:
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
|
||||
@@ -161,6 +161,8 @@
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
title: Methods
|
||||
- local: training/nemo_automodel
|
||||
title: NeMo Automodel
|
||||
title: Training
|
||||
- isExpanded: false
|
||||
sections:
|
||||
|
||||
378
docs/source/en/training/nemo_automodel.md
Normal file
378
docs/source/en/training/nemo_automodel.md
Normal file
@@ -0,0 +1,378 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# NeMo Automodel
|
||||
|
||||
[NeMo Automodel](https://github.com/NVIDIA-NeMo/Automodel) is a PyTorch DTensor-native training library from NVIDIA for fine-tuning and pretraining diffusion models at scale. It is Hugging Face native — train any Diffusers-format model from the Hub with no checkpoint conversion. The same YAML recipe and hackable training script runs on any scale from 1 GPU to hundreds of nodes, with [FSDP2](https://pytorch.org/docs/stable/fsdp.html) distributed training, multiresolution bucketed dataloading, and pre-encoded latent space training for maximum GPU utilization. It uses [flow matching](https://huggingface.co/papers/2210.02747) for training and is fully open source (Apache 2.0), NVIDIA-supported, and actively maintained.
|
||||
|
||||
NeMo Automodel integrates directly with Diffusers. It loads pretrained models from the Hugging Face Hub using Diffusers model classes and generates outputs with the [`DiffusionPipeline`].
|
||||
|
||||
The typical workflow is to install NeMo Automodel (pip or Docker), prepare your data by encoding it into `.meta` files, configure a YAML recipe, launch training with `torchrun`, and run inference with the resulting checkpoint.
|
||||
|
||||
## Supported models
|
||||
|
||||
| Model | Hugging Face ID | Task | Parameters | Use case |
|
||||
|-------|----------------|------|------------|----------|
|
||||
| Wan 2.1 T2V 1.3B | [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) | Text-to-Video | 1.3B | video generation on limited hardware (fits on single 40GB A100) |
|
||||
| FLUX.1-dev | [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) | Text-to-Image | 12B | high-quality image generation |
|
||||
| HunyuanVideo 1.5 | [hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v) | Text-to-Video | 13B | high-quality video generation |
|
||||
|
||||
## Installation
|
||||
|
||||
### Hardware requirements
|
||||
|
||||
| Component | Minimum | Recommended |
|
||||
|-----------|---------|-------------|
|
||||
| GPU | A100 40GB | A100 80GB / H100 |
|
||||
| GPUs | 4 | 8+ |
|
||||
| RAM | 128 GB | 256 GB+ |
|
||||
| Storage | 500 GB SSD | 2 TB NVMe |
|
||||
|
||||
Install NeMo Automodel with pip. For the full set of installation methods (including from source), see the [NeMo Automodel installation guide](https://docs.nvidia.com/nemo/automodel/latest/guides/installation.html).
|
||||
|
||||
```bash
|
||||
pip3 install nemo-automodel
|
||||
```
|
||||
|
||||
Alternatively, use the pre-built Docker container which includes all dependencies.
|
||||
|
||||
```bash
|
||||
docker pull nvcr.io/nvidia/nemo-automodel:26.02.00
|
||||
docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/nemo-automodel:26.02.00
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Checkpoints are lost when the container exits unless you bind-mount the checkpoint directory to the host. For example, add `-v /host/path/checkpoints:/workspace/checkpoints` to the `docker run` command.
|
||||
|
||||
|
||||
## Data preparation
|
||||
|
||||
NeMo Automodel trains diffusion models in latent space. Raw images or videos must be preprocessed into `.meta` files containing VAE latents and text embeddings before training. This avoids re-encoding on every training step.
|
||||
|
||||
Use the built-in preprocessing tool to encode your data. The tool automatically distributes work across all available GPUs.
|
||||
|
||||
<hfoptions id="data-prep">
|
||||
<hfoption id="video preprocessing">
|
||||
|
||||
The video preprocessing command is the same for both Wan 2.1 and HunyuanVideo, but the flags differ. Wan 2.1 uses `--processor wan` with `--resolution_preset` and `--caption_format sidecar`, while HunyuanVideo uses `--processor hunyuan` with `--target_frames` to set the frame count and `--caption_format meta_json`.
|
||||
|
||||
**Wan 2.1:**
|
||||
|
||||
```bash
|
||||
python -m tools.diffusion.preprocessing_multiprocess video \
|
||||
--video_dir /data/videos \
|
||||
--output_dir /cache \
|
||||
--processor wan \
|
||||
--resolution_preset 512p \
|
||||
--caption_format sidecar
|
||||
```
|
||||
|
||||
**HunyuanVideo:**
|
||||
|
||||
```bash
|
||||
python -m tools.diffusion.preprocessing_multiprocess video \
|
||||
--video_dir /data/videos \
|
||||
--output_dir /cache \
|
||||
--processor hunyuan \
|
||||
--target_frames 121 \
|
||||
--caption_format meta_json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="image preprocessing">
|
||||
|
||||
```bash
|
||||
python -m tools.diffusion.preprocessing_multiprocess image \
|
||||
--image_dir /data/images \
|
||||
--output_dir /cache \
|
||||
--processor flux \
|
||||
--resolution_preset 512p
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Output format
|
||||
|
||||
Preprocessing produces a cache directory organized by resolution bucket. NeMo Automodel supports multi-resolution training through bucketed sampling. Samples are grouped by spatial resolution so each batch contains same-size samples, avoiding padding waste.
|
||||
|
||||
```
|
||||
/cache/
|
||||
├── 512x512/ # Resolution bucket
|
||||
│ ├── <hash1>.meta # VAE latents + text embeddings
|
||||
│ ├── <hash2>.meta
|
||||
│ └── ...
|
||||
├── 832x480/ # Another resolution bucket
|
||||
│ └── ...
|
||||
├── metadata.json # Global config (processor, model, total items)
|
||||
└── metadata_shard_0000.json # Per-sample metadata (paths, resolutions, captions)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> See the [Diffusion Dataset Preparation](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/dataset.html) guide for caption formats, input data requirements, and all available preprocessing arguments.
|
||||
|
||||
## Training configuration
|
||||
|
||||
Fine-tuning is driven by two components:
|
||||
|
||||
1. A recipe script ([finetune.py](https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/diffusion/finetune/finetune.py)) is a Python entry point that contains the training loop: loading the model, building the dataloader, running forward/backward passes, computing the flow matching loss, checkpointing, and logging.
|
||||
2. A YAML configuration file specifies all settings the recipe uses: which model to fine-tune, where the data lives, optimizer hyperparameters, parallelism strategy, and more. You customize training by editing this file rather than modifying code, allowing you to scale from 1 to hundreds of GPUs.
|
||||
|
||||
Any YAML field can also be overridden from the CLI:
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node=8 examples/diffusion/finetune/finetune.py \
|
||||
-c examples/diffusion/finetune/wan2_1_t2v_flow.yaml \
|
||||
--optim.learning_rate 1e-5 \
|
||||
--step_scheduler.num_epochs 50
|
||||
```
|
||||
|
||||
Below is the annotated config for fine-tuning Wan 2.1 T2V 1.3B, with each section explained.
|
||||
|
||||
```yaml
|
||||
seed: 42
|
||||
|
||||
# ── Experiment tracking (optional) ──────────────────────────────────────────
|
||||
# Weights & Biases integration for logging metrics, losses, and learning rates.
|
||||
# Set mode: "disabled" to turn off.
|
||||
wandb:
|
||||
project: wan-t2v-flow-matching
|
||||
mode: online
|
||||
name: wan2_1_t2v_fm
|
||||
|
||||
# ── Model ───────────────────────────────────────────────────────────────────
|
||||
# pretrained_model_name_or_path: any Hugging Face model ID or local path.
|
||||
# mode: "finetune" loads pretrained weights; "pretrain" trains from scratch.
|
||||
model:
|
||||
pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
mode: finetune
|
||||
|
||||
# ── Training schedule ───────────────────────────────────────────────────────
|
||||
# global_batch_size: effective batch across all GPUs.
|
||||
# Gradient accumulation is computed automatically: global / (local × num_gpus).
|
||||
step_scheduler:
|
||||
global_batch_size: 8
|
||||
local_batch_size: 1
|
||||
ckpt_every_steps: 1000 # Save a checkpoint every N steps
|
||||
num_epochs: 100
|
||||
log_every: 2 # Log metrics every N steps
|
||||
|
||||
# ── Data ────────────────────────────────────────────────────────────────────
|
||||
# _target_: the dataloader factory function.
|
||||
# Use build_video_multiresolution_dataloader for video models (Wan, HunyuanVideo).
|
||||
# Use build_text_to_image_multiresolution_dataloader for image models (FLUX).
|
||||
# model_type: "wan" or "hunyuan" (selects the correct latent format).
|
||||
# base_resolution: target resolution for multiresolution bucketing.
|
||||
data:
|
||||
dataloader:
|
||||
_target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader
|
||||
cache_dir: PATH_TO_YOUR_DATA
|
||||
model_type: wan
|
||||
base_resolution: [512, 512]
|
||||
dynamic_batch_size: false # When true, adjusts batch per bucket to maintain constant memory
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
num_workers: 0
|
||||
|
||||
# ── Optimizer ───────────────────────────────────────────────────────────────
|
||||
# learning_rate: 5e-6 is a good starting point for fine-tuning.
|
||||
# Adjust weight_decay and betas for your dataset.
|
||||
optim:
|
||||
learning_rate: 5e-6
|
||||
optimizer:
|
||||
weight_decay: 0.01
|
||||
betas: [0.9, 0.999]
|
||||
|
||||
# ── Learning rate scheduler ─────────────────────────────────────────────────
|
||||
# Supports cosine, linear, and constant schedules.
|
||||
lr_scheduler:
|
||||
lr_decay_style: cosine
|
||||
lr_warmup_steps: 0
|
||||
min_lr: 1e-6
|
||||
|
||||
# ── Flow matching ───────────────────────────────────────────────────────────
|
||||
# adapter_type: model-specific adapter — must match the model:
|
||||
# "simple" for Wan 2.1, "flux" for FLUX.1-dev, "hunyuan" for HunyuanVideo.
|
||||
# timestep_sampling: "uniform" for Wan, "logit_normal" for FLUX and HunyuanVideo.
|
||||
# flow_shift: shifts the flow schedule (model-dependent).
|
||||
# i2v_prob: probability of image-to-video conditioning during training (video models).
|
||||
flow_matching:
|
||||
adapter_type: "simple"
|
||||
adapter_kwargs: {}
|
||||
timestep_sampling: "uniform"
|
||||
logit_mean: 0.0
|
||||
logit_std: 1.0
|
||||
flow_shift: 3.0
|
||||
num_train_timesteps: 1000
|
||||
i2v_prob: 0.3
|
||||
use_loss_weighting: true
|
||||
|
||||
# ── FSDP2 distributed training ──────────────────────────────────────────────
|
||||
# dp_size: number of GPUs for data parallelism (typically = total GPUs on node).
|
||||
# tp_size, cp_size, pp_size: tensor, context, and pipeline parallelism.
|
||||
# For most fine-tuning, dp_size is all you need; leave others at 1.
|
||||
fsdp:
|
||||
tp_size: 1
|
||||
cp_size: 1
|
||||
pp_size: 1
|
||||
dp_replicate_size: 1
|
||||
dp_size: 8
|
||||
|
||||
# ── Checkpointing ──────────────────────────────────────────────────────────
|
||||
# checkpoint_dir: where to save checkpoints (use a persistent path with Docker).
|
||||
# restore_from: path to resume training from a previous checkpoint.
|
||||
checkpoint:
|
||||
enabled: true
|
||||
checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
|
||||
model_save_format: torch_save
|
||||
save_consolidated: false
|
||||
restore_from: null
|
||||
```
|
||||
|
||||
### Config field reference
|
||||
|
||||
The table below lists the minimal required configs. See the [NeMo Automodel examples](https://github.com/NVIDIA-NeMo/Automodel/tree/main/examples/diffusion/finetune) have full example configs for all models.
|
||||
|
||||
| Section | Required? | What to Change |
|
||||
|---------|-----------|----------------|
|
||||
| `model` | Yes | Set `pretrained_model_name_or_path` to the Hugging Face model ID. Set `mode: finetune` or `mode: pretrain`. |
|
||||
| `step_scheduler` | Yes | `global_batch_size` is the effective batch size across all GPUs. `ckpt_every_steps` controls checkpoint frequency. Gradient accumulation is computed automatically. |
|
||||
| `data` | Yes | Set `cache_dir` to the path containing your preprocessed `.meta` files. Change `_target_` and `model_type` for different models. |
|
||||
| `optim` | Yes | `learning_rate: 5e-6` is a good default for fine-tuning. Adjust for your dataset and model. |
|
||||
| `lr_scheduler` | Yes | Choose `cosine`, `linear`, or `constant` for `lr_decay_style`. Set `lr_warmup_steps` for gradual warmup. |
|
||||
| `flow_matching` | Yes | `adapter_type` must match the model (`simple` for Wan, `flux` for FLUX, `hunyuan` for HunyuanVideo). See model-specific configs for `adapter_kwargs`. |
|
||||
| `fsdp` | Yes | Set `dp_size` to the number of GPUs. For multi-node, set to total GPUs across all nodes. |
|
||||
| `checkpoint` | Recommended | Set `checkpoint_dir` to a persistent path, especially in Docker. Use `restore_from` to resume from a previous checkpoint. |
|
||||
| `wandb` | Optional | Configure to enable Weights & Biases experiment tracking. Set `mode: disabled` to turn off. |
|
||||
|
||||
|
||||
|
||||
## Launch training
|
||||
|
||||
<hfoptions id="launch-training">
|
||||
<hfoption id="single-node">
|
||||
|
||||
```bash
|
||||
torchrun --nproc-per-node=8 \
|
||||
examples/diffusion/finetune/finetune.py \
|
||||
-c examples/diffusion/finetune/wan2_1_t2v_flow.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="multi-node">
|
||||
|
||||
Run the following on each node, setting `NODE_RANK` accordingly:
|
||||
|
||||
```bash
|
||||
export MASTER_ADDR=node0.hostname
|
||||
export MASTER_PORT=29500
|
||||
export NODE_RANK=0 # 0 on master, 1 on second node, etc.
|
||||
|
||||
torchrun \
|
||||
--nnodes=2 \
|
||||
--nproc-per-node=8 \
|
||||
--node_rank=${NODE_RANK} \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
|
||||
examples/diffusion/finetune/finetune.py \
|
||||
-c examples/diffusion/finetune/wan2_1_t2v_flow_multinode.yaml
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For multi-node training, set `fsdp.dp_size` in the YAML to the **total** number of GPUs across all nodes (e.g., 16 for 2 nodes with 8 GPUs each).
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Generation
|
||||
|
||||
After training, generate videos or images from text prompts using the fine-tuned checkpoint.
|
||||
|
||||
<hfoptions id="generation">
|
||||
<hfoption id="Wan 2.1">
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_wan.yaml
|
||||
```
|
||||
|
||||
With a fine-tuned checkpoint:
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_wan.yaml \
|
||||
--model.checkpoint ./checkpoints/step_1000 \
|
||||
--inference.prompts '["A dog running on a beach"]'
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="FLUX">
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_flux.yaml
|
||||
```
|
||||
|
||||
With a fine-tuned checkpoint:
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_flux.yaml \
|
||||
--model.checkpoint ./checkpoints/step_1000 \
|
||||
--inference.prompts '["A dog running on a beach"]'
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="HunyuanVideo">
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_hunyuan.yaml
|
||||
```
|
||||
|
||||
With a fine-tuned checkpoint:
|
||||
|
||||
```bash
|
||||
python examples/diffusion/generate/generate.py \
|
||||
-c examples/diffusion/generate/configs/generate_hunyuan.yaml \
|
||||
--model.checkpoint ./checkpoints/step_1000 \
|
||||
--inference.prompts '["A dog running on a beach"]'
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Diffusers integration
|
||||
|
||||
NeMo Automodel is built on top of Diffusers and uses it as the backbone for model loading and inference. It loads models directly from the Hugging Face Hub using Diffusers model classes such as [`WanTransformer3DModel`], [`FluxTransformer2DModel`], and [`HunyuanVideoTransformer3DModel`], and generates outputs via Diffusers pipelines like [`WanPipeline`] and [`FluxPipeline`].
|
||||
|
||||
This integration provides several benefits for Diffusers users:
|
||||
|
||||
- **No checkpoint conversion**: pretrained weights from the Hub work out of the box. Point `pretrained_model_name_or_path` at any Diffusers-format model ID and start training immediately.
|
||||
- **Day-0 model support**: when a new diffusion model is added to Diffusers and uploaded to the Hub, it can be fine-tuned with NeMo Automodel without waiting for a dedicated training script.
|
||||
- **Pipeline-compatible outputs**: fine-tuned checkpoints are saved in a format that can be loaded directly back into Diffusers pipelines for inference, sharing on the Hub, or further optimization with tools like quantization and compilation.
|
||||
- **Scalable training for Diffusers models**: NeMo Automodel adds distributed training capabilities (FSDP2, multi-node, multiresolution bucketing) that go beyond what the built-in Diffusers training scripts provide, while keeping the same model and pipeline interfaces.
|
||||
- **Shared ecosystem**: any model, LoRA adapter, or pipeline component from the Diffusers ecosystem remains compatible throughout the training and inference workflow.
|
||||
|
||||
## NVIDIA Team
|
||||
|
||||
- Pranav Prashant Thombre, pthombre@nvidia.com
|
||||
- Linnan Wang, linnanw@nvidia.com
|
||||
- Alexandros Koumparoulis, akoumparouli@nvidia.com
|
||||
|
||||
## Resources
|
||||
|
||||
- [NeMo Automodel GitHub](https://github.com/NVIDIA-NeMo/Automodel)
|
||||
- [Diffusion Fine-Tuning Guide](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/finetune.html)
|
||||
- [Diffusion Dataset Preparation](https://docs.nvidia.com/nemo/automodel/latest/guides/diffusion/dataset.html)
|
||||
- [Diffusion Model Coverage](https://docs.nvidia.com/nemo/automodel/latest/model-coverage/diffusion.html)
|
||||
- [NeMo Automodel for Transformers (LLM/VLM fine-tuning)](https://huggingface.co/docs/transformers/en/community_integrations/nemo_automodel_finetuning)
|
||||
@@ -347,16 +347,17 @@ When LoRA was first adapted from language models to diffusion models, it was app
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out,norm_out.linear,norm_out.proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
> [!NOTE]
|
||||
In FLUX2, the q, k, and v projections are fused into a single linear layer named attn.to_qkv_mlp_proj within the single transformer block. Also, the attention output is just attn.to_out, not attn.to_out.0 — it’s no longer a ModuleList like in transformer block.
|
||||
|
||||
## Training Image-to-Image
|
||||
|
||||
|
||||
@@ -1256,7 +1256,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1206,7 +1206,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1249,7 +1249,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -1200,7 +1200,13 @@ def main(args):
|
||||
if args.lora_layers is not None:
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
else:
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
# target_modules = ["to_k", "to_q", "to_v", "to_out.0"] # just train transformer_blocks
|
||||
|
||||
# train transformer_blocks and single_transformer_blocks
|
||||
target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
|
||||
"to_qkv_mlp_proj",
|
||||
*[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
|
||||
@@ -862,23 +862,23 @@ def _native_attention_backward_op(
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
with torch.enable_grad():
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
|
||||
@@ -470,8 +470,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
if is_torchao_version("<=", "0.9.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
|
||||
if is_torchao_version("<", "0.15.0"):
|
||||
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
|
||||
|
||||
from torchao.quantization.quant_api import AOBaseConfig
|
||||
|
||||
@@ -495,8 +495,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
|
||||
"""Create configuration from a dictionary."""
|
||||
if not is_torchao_version(">", "0.9.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
|
||||
if not is_torchao_version(">=", "0.15.0"):
|
||||
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
|
||||
config_dict = config_dict.copy()
|
||||
quant_type = config_dict.pop("quant_type")
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ if (
|
||||
is_torch_available()
|
||||
and is_torch_version(">=", "2.6.0")
|
||||
and is_torchao_available()
|
||||
and is_torchao_version(">=", "0.7.0")
|
||||
and is_torchao_version(">=", "0.15.0")
|
||||
):
|
||||
_update_torch_safe_globals()
|
||||
|
||||
@@ -168,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
|
||||
raise ImportError(
|
||||
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
|
||||
)
|
||||
torchao_version = version.parse(importlib.metadata.version("torch"))
|
||||
if torchao_version < version.parse("0.7.0"):
|
||||
torchao_version = version.parse(importlib.metadata.version("torchao"))
|
||||
if torchao_version < version.parse("0.15.0"):
|
||||
raise RuntimeError(
|
||||
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
|
||||
)
|
||||
|
||||
self.offload = False
|
||||
|
||||
@@ -13,24 +13,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKLWan
|
||||
|
||||
def get_autoencoder_kl_wan_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
@@ -39,54 +44,51 @@ class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.T
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
image = torch.randn(batch_size, num_channels, num_frames, *sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def dummy_input_tiling(self):
|
||||
# Bridge for AutoencoderTesterMixin which still uses the old interface
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return self.get_init_dict(), self.get_dummy_inputs()
|
||||
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (128, 128)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
return self.get_init_dict(), {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input_tiling
|
||||
return init_dict, inputs_dict
|
||||
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderKLWan."""
|
||||
|
||||
@unittest.skip("Gradient checkpointing has not been implemented yet")
|
||||
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_forward_with_norm_groups(self):
|
||||
|
||||
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKLWan."""
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, AutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKLWan."""
|
||||
|
||||
@@ -44,9 +44,9 @@ class AutoencoderTesterMixin:
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
|
||||
@@ -465,8 +465,7 @@ class UNetTesterMixin:
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
@@ -481,9 +480,9 @@ class UNetTesterMixin:
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
|
||||
@@ -287,9 +287,8 @@ class ModelTesterMixin:
|
||||
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
)
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = model(**inputs_dict, return_dict=False)[0]
|
||||
new_image = new_model(**inputs_dict, return_dict=False)[0]
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@@ -309,9 +308,8 @@ class ModelTesterMixin:
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = model(**inputs_dict, return_dict=False)[0]
|
||||
new_image = new_model(**inputs_dict, return_dict=False)[0]
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@@ -339,9 +337,8 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
first = model(**inputs_dict, return_dict=False)[0]
|
||||
second = model(**inputs_dict, return_dict=False)[0]
|
||||
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
@@ -398,9 +395,8 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
@@ -527,10 +523,8 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
|
||||
@@ -569,10 +563,8 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
|
||||
@@ -622,10 +614,8 @@ class ModelTesterMixin:
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
|
||||
|
||||
@@ -92,6 +92,9 @@ class TorchCompileTesterMixin:
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
|
||||
@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _context_parallel_backward_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
|
||||
):
|
||||
"""Worker function for context parallel backward pass testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
# Create model in training mode
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.train()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward and backward pass
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
# Check that backward actually produced at least one valid gradient
|
||||
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
|
||||
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["has_valid_grads"] = bool(has_valid_grads)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
@@ -204,6 +262,51 @@ class ContextParallelTesterMixin:
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_backward_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_backward_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_backward(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -24,39 +26,64 @@ from ...testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
_LAYERWISE_CASTING_XFAIL_REASON = (
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
)
|
||||
|
||||
|
||||
class UNet1DTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet1DModel testing (standard variant)."""
|
||||
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet1DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (14, 16)
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
def test_output(self):
|
||||
super().test_output()
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (8, 8, 16, 16),
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
@@ -70,40 +97,18 @@ class UNet1DTesterConfig(BaseModelTesterConfig):
|
||||
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
|
||||
"act_fn": "swish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
|
||||
"timestep": torch.tensor([10] * batch_size).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Not implemented yet for this UNet")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
|
||||
)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -126,7 +131,12 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
|
||||
# fmt: on
|
||||
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_unet_1d_maestro(self):
|
||||
@@ -147,29 +157,98 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
assert (output_sum - 224.0896).abs() < 0.5
|
||||
assert (output_max - 0.0607).abs() < 4e-4
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
|
||||
# =============================================================================
|
||||
# UNet1D RL (Value Function) Model Tests
|
||||
# =============================================================================
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
|
||||
class UNet1DRLTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet1DModel testing (RL value function variant)."""
|
||||
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet1DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1,)
|
||||
return (4, 14, 1)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function is different output shape
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
|
||||
@@ -185,54 +264,18 @@ class UNet1DRLTesterConfig(BaseModelTesterConfig):
|
||||
"time_embedding_type": "positional",
|
||||
"act_fn": "mish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
|
||||
"timestep": torch.tensor([10] * batch_size).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Not implemented yet for this UNet")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function with different output shape (batch, 1)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert output is not None
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
)
|
||||
assert value_function is not None
|
||||
assert len(vf_loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(value_function)
|
||||
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
|
||||
|
||||
value_function.to(torch_device)
|
||||
image = value_function(**self.get_dummy_inputs())
|
||||
image = value_function(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -256,4 +299,31 @@ class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([165.25] * seq_len)
|
||||
# fmt: on
|
||||
assert torch.allclose(output, expected_output_slice, rtol=1e-3)
|
||||
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
|
||||
import gc
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DModel
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -30,40 +31,39 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Standard UNet2D Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNet2DTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for standard UNet2DModel testing."""
|
||||
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 2,
|
||||
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
|
||||
@@ -74,22 +74,11 @@ class UNet2DTesterConfig(BaseModelTesterConfig):
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
def test_mid_block_attn_groups(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["add_attention"] = True
|
||||
init_dict["attn_norm_num_groups"] = 4
|
||||
@@ -98,11 +87,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
assert model.mid_block.attentions[0].group_norm is not None, (
|
||||
"Mid block Attention group norm should exist but does not."
|
||||
self.assertIsNotNone(
|
||||
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
|
||||
)
|
||||
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
|
||||
"Mid block Attention group norm does not have the expected number of groups."
|
||||
self.assertEqual(
|
||||
model.mid_block.attentions[0].group_norm.num_groups,
|
||||
init_dict["attn_norm_num_groups"],
|
||||
"Mid block Attention group norm does not have the expected number of groups.",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -111,15 +102,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_mid_block_none(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
mid_none_init_dict = self.get_init_dict()
|
||||
mid_none_inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
mid_none_init_dict["mid_block_type"] = None
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -130,7 +119,7 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
mid_none_model.to(torch_device)
|
||||
mid_none_model.eval()
|
||||
|
||||
assert mid_none_model.mid_block is None, "Mid block should not exist."
|
||||
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
@@ -144,10 +133,8 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
if isinstance(mid_none_output, dict):
|
||||
mid_none_output = mid_none_output.to_tuple()[0]
|
||||
|
||||
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
|
||||
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
|
||||
|
||||
|
||||
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"AttnUpBlock2D",
|
||||
@@ -156,32 +143,41 @@ class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
|
||||
"UpBlock2D",
|
||||
"DownBlock2D",
|
||||
}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
attention_head_dim = 8
|
||||
block_out_channels = (16, 32)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UNet2D LDM Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DModel LDM variant testing."""
|
||||
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 32,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
@@ -191,34 +187,17 @@ class UNet2DLDMTesterConfig(BaseModelTesterConfig):
|
||||
"down_block_types": ("DownBlock2D", "DownBlock2D"),
|
||||
"up_block_types": ("UpBlock2D", "UpBlock2D"),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs()).sample
|
||||
image = model(**self.dummy_input).sample
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -226,7 +205,7 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
def test_from_pretrained_accelerate(self):
|
||||
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs()).sample
|
||||
image = model(**self.dummy_input).sample
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -286,31 +265,44 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
attention_head_dim = 32
|
||||
block_out_channels = (32, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NCSN++ Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NCSNppTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DModel NCSN++ variant testing."""
|
||||
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
def dummy_input(self, sizes=(32, 32)):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": [32, 64, 64, 64],
|
||||
"in_channels": 3,
|
||||
"layers_per_block": 1,
|
||||
@@ -332,71 +324,17 @@ class NCSNppTesterConfig(BaseModelTesterConfig):
|
||||
"SkipUpBlock2D",
|
||||
],
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_keep_in_fp32_modules(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_from_save_pretrained_dtype_inference(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.dummy_input
|
||||
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
|
||||
inputs["sample"] = noise
|
||||
image = model(**inputs)
|
||||
@@ -423,7 +361,7 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
def test_output_pretrained_ve_large(self):
|
||||
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
|
||||
@@ -444,4 +382,35 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# not required for this model
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
|
||||
block_out_channels = (32, 64, 64, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
|
||||
|
||||
@unittest.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@@ -20,7 +20,6 @@ import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from parameterized import parameterized
|
||||
@@ -53,24 +52,17 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
IPAdapterTesterMixin,
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
UNetTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
|
||||
from ..testing_utils.lora import check_if_lora_correctly_set
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -90,6 +82,16 @@ def get_unet_lora_config():
|
||||
return unet_lora_config
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Checks if the LoRA layers are correctly set with peft
|
||||
"""
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
@@ -352,28 +354,34 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DConditionModel testing."""
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
main_input_name = "sample"
|
||||
# We override the items here because the unet under consideration is small.
|
||||
model_split_percents = [0.5, 0.34, 0.4]
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DConditionModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int, int]:
|
||||
def input_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list[float]:
|
||||
return [0.5, 0.34, 0.4]
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
"""Return UNet2D model initialization arguments."""
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
@@ -385,24 +393,26 @@ class UNet2DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
"""Return dummy inputs for UNet2D model."""
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
}
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
def test_model_with_attention_head_dim_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -417,13 +427,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_use_linear_projection(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["use_linear_projection"] = True
|
||||
|
||||
@@ -437,13 +446,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["cross_attention_dim"] = (8, 8)
|
||||
|
||||
@@ -457,13 +465,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_simple_projection(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
@@ -482,13 +489,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
@@ -508,287 +514,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
|
||||
# maybe it's fine that this only works for the unclip use-case.
|
||||
@mark.skip(
|
||||
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
|
||||
)
|
||||
def test_model_xattn_padding(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
cond = inputs_dict["encoder_hidden_states"]
|
||||
with torch.no_grad():
|
||||
full_cond_out = model(**inputs_dict).sample
|
||||
assert full_cond_out is not None
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
|
||||
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
|
||||
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
|
||||
def test_pickle(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
def test_asymmetrical_unet(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# Add asymmetry to configs
|
||||
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
|
||||
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
output = model(**inputs_dict).sample
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
|
||||
# Check if input and output shapes are the same
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
|
||||
"""Hub checkpoint loading tests for UNet2DConditionModel."""
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
|
||||
)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
|
||||
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for UNet2DConditionModel."""
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
"""Test that deprecated load_attn_procs method raises FutureWarning."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
"""Test that deprecated save_attn_procs method raises FutureWarning."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
|
||||
model.save_attn_procs(os.path.join(tmpdirname))
|
||||
|
||||
|
||||
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for UNet2DConditionModel."""
|
||||
|
||||
|
||||
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for UNet2DConditionModel."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for UNet2DConditionModel."""
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -813,7 +544,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
assert output is not None
|
||||
|
||||
def test_model_sliceable_head_dim(self):
|
||||
init_dict = self.get_init_dict()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -831,6 +562,21 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
for module in model.children():
|
||||
check_sliceable_dim_attr(module)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
attention_head_dim = (8, 16)
|
||||
block_out_channels = (16, 32)
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_special_attn_proc(self):
|
||||
class AttnEasyProc(torch.nn.Module):
|
||||
def __init__(self, num):
|
||||
@@ -872,8 +618,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
return hidden_states
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -900,8 +645,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
]
|
||||
)
|
||||
def test_model_xattn_mask(self, mask_dtype):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
|
||||
model.to(torch_device)
|
||||
@@ -931,13 +675,39 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
|
||||
)
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
|
||||
# maybe it's fine that this only works for the unclip use-case.
|
||||
@mark.skip(
|
||||
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
|
||||
)
|
||||
def test_model_xattn_padding(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
"""Custom Diffusion processor tests for UNet2DConditionModel."""
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
cond = inputs_dict["encoder_hidden_states"]
|
||||
with torch.no_grad():
|
||||
full_cond_out = model(**inputs_dict).sample
|
||||
assert full_cond_out is not None
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
|
||||
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
|
||||
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -963,8 +733,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
assert (sample1 - sample2).abs().max() < 3e-3
|
||||
|
||||
def test_custom_diffusion_save_load(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -984,7 +754,7 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname, safe_serialization=False)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
@@ -1003,8 +773,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_custom_diffusion_xformers_on_off(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -1028,28 +798,41 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
def test_pickle(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for UNet2DConditionModel."""
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
return create_ip_adapter_state_dict(model)
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
|
||||
# for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
|
||||
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
|
||||
return inputs_dict
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
def test_asymmetrical_unet(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
# Add asymmetry to configs
|
||||
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
|
||||
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
output = model(**inputs_dict).sample
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
|
||||
# Check if input and output shapes are the same
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -1122,8 +905,7 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_ip_adapter_plus(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -1195,16 +977,185 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
|
||||
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for UNet2DConditionModel."""
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
|
||||
)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
|
||||
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for UNet2DConditionModel."""
|
||||
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -18,44 +18,47 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import UNet3DConditionModel
|
||||
from diffusers.models import ModelMixin, UNet3DConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
)
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet3DConditionModel testing."""
|
||||
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet3DConditionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet3DConditionModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 4, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 4, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": (
|
||||
@@ -70,25 +73,27 @@ class UNet3DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (16, 16)
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
}
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
# Overriding to set `norm_num_groups` needs to be different for this model.
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
@@ -102,74 +107,39 @@ class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
# Overriding since the UNet3D outputs a different structure.
|
||||
@torch.no_grad()
|
||||
def test_determinism(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
with torch.no_grad():
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps" and isinstance(model, ModelMixin):
|
||||
model(**self.dummy_input)
|
||||
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
assert max_diff <= 1e-5
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
assert output.shape == output_2.shape, "Shape doesn't match"
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
|
||||
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = 8
|
||||
@@ -192,3 +162,22 @@ class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterM
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
@@ -13,42 +13,59 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetControlNetXSModel testing."""
|
||||
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetControlNetXSModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNetControlNetXSModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
|
||||
conditioning_scale = 1
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"controlnet_cond": controlnet_cond,
|
||||
"conditioning_scale": conditioning_scale,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 16,
|
||||
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
@@ -63,23 +80,11 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
"ctrl_max_norm_num_groups": 2,
|
||||
"ctrl_conditioning_embedding_out_channels": (2, 2),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
|
||||
"conditioning_scale": 1,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_unet(self):
|
||||
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
|
||||
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
return UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
@@ -94,16 +99,10 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
)
|
||||
|
||||
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
|
||||
"""Build the ControlNetXS-Adapter from a UNet."""
|
||||
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
|
||||
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
|
||||
|
||||
|
||||
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
|
||||
pass
|
||||
|
||||
def test_from_unet(self):
|
||||
unet = self.get_dummy_unet()
|
||||
controlnet = self.get_dummy_controlnet_from_unet(unet)
|
||||
@@ -116,7 +115,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
|
||||
|
||||
# # check unet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
"time_embedding",
|
||||
"conv_in",
|
||||
@@ -153,7 +152,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
|
||||
|
||||
# # check controlnet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_controlnet = {
|
||||
"controlnet_cond_embedding": "controlnet_cond_embedding",
|
||||
"conv_in": "ctrl_conv_in",
|
||||
@@ -194,12 +193,12 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
for p in module.parameters():
|
||||
assert p.requires_grad
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = UNetControlNetXSModel(**init_dict)
|
||||
model.freeze_unet_params()
|
||||
|
||||
# # check unet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
model.base_time_embedding,
|
||||
model.base_conv_in,
|
||||
@@ -237,7 +236,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
assert_frozen(u.upsamplers)
|
||||
|
||||
# # check controlnet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_controlnet = [
|
||||
model.controlnet_cond_embedding,
|
||||
model.ctrl_conv_in,
|
||||
@@ -268,6 +267,16 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
for u in model.up_blocks:
|
||||
assert_unfrozen(u.ctrl_to_base)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@is_flaky
|
||||
def test_forward_no_control(self):
|
||||
unet = self.get_dummy_unet()
|
||||
@@ -278,7 +287,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
unet = unet.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ = self.get_dummy_inputs()
|
||||
input_ = self.dummy_input
|
||||
|
||||
control_specific_input = ["controlnet_cond", "conditioning_scale"]
|
||||
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
|
||||
@@ -303,7 +312,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
model = model.to(torch_device)
|
||||
model_mix_time = model_mix_time.to(torch_device)
|
||||
|
||||
input_ = self.get_dummy_inputs()
|
||||
input_ = self.dummy_input
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**input_).sample
|
||||
@@ -311,14 +320,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
|
||||
assert output.shape == output_mix_time.shape
|
||||
|
||||
|
||||
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
|
||||
pass
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNetSpatioTemporalConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -28,34 +28,45 @@ from ...testing_utils import (
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
|
||||
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetSpatioTemporalConditionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNetSpatioTemporalConditionModel
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 2, 4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return 6
|
||||
@@ -72,8 +83,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
def addition_time_embed_dim(self):
|
||||
return 32
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (32, 64),
|
||||
"down_block_types": (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
@@ -92,23 +103,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
|
||||
"addition_time_embed_dim": self.addition_time_embed_dim,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def _get_add_time_ids(self, do_classifier_free_guidance=True):
|
||||
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
|
||||
@@ -128,15 +124,43 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
|
||||
return add_time_ids
|
||||
|
||||
|
||||
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Number of Norm Groups is not configurable")
|
||||
@unittest.skip("Number of Norm Groups is not configurable")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Deprecated functionality")
|
||||
def test_model_attention_slicing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_use_linear_projection(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_simple_projection(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
def test_model_with_num_attention_heads_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["num_attention_heads"] = (8, 16)
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -149,13 +173,12 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["cross_attention_dim"] = (32, 32)
|
||||
|
||||
@@ -169,13 +192,27 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"TransformerSpatioTemporalModel",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"UNetMidBlockSpatioTemporal",
|
||||
}
|
||||
num_attention_heads = (8, 16)
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, num_attention_heads=num_attention_heads
|
||||
)
|
||||
|
||||
def test_pickle(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["num_attention_heads"] = (8, 16)
|
||||
|
||||
@@ -188,33 +225,3 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
|
||||
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
|
||||
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"TransformerSpatioTemporalModel",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"UNetMidBlockSpatioTemporal",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@@ -14,13 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -82,18 +80,17 @@ if is_torchao_available():
|
||||
Float8WeightOnlyConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8DynamicActivationIntxWeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
IntxWeightOnlyConfig,
|
||||
)
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
|
||||
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoConfigTest(unittest.TestCase):
|
||||
def test_to_dict(self):
|
||||
"""
|
||||
@@ -128,7 +125,7 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
@@ -527,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
def test_aobase_config(self):
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
@@ -540,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoSerializationTest(unittest.TestCase):
|
||||
model_name = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@@ -650,7 +647,7 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self._check_serialization_expected_slice(quant_type, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -696,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoTests(unittest.TestCase):
|
||||
@@ -854,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@require_torchao_version_greater_or_equal("0.14.0")
|
||||
@require_torchao_version_greater_or_equal("0.15.0")
|
||||
@slow
|
||||
@nightly
|
||||
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user