mirror of https://github.com/huggingface/diffusers.git
synced 2026-03-03 07:10:34 +08:00

Compare commits: modular-no... → modular-sa... (7 commits)

| SHA1 |
|---|
| db9483f781 |
| 369e123bde |
| 20995862b0 |
| a9269d2cf2 |
| c167fe335e |
| e340b52a92 |
| 71ce634d1e |
.claude/CLAUDE.md (new file, 100 lines)
@@ -0,0 +1,100 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Build, Lint, and Test Commands

```bash
# Install in development mode
pip install -e ".[dev]"

# Run the full test suite (requires a beefy machine)
make test
# Or directly:
python -m pytest -n auto --dist=loadfile -s -v ./tests/

# Run a single test file
python -m pytest tests/<TEST_FILE>.py

# Run slow tests (downloads many GBs of models)
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/

# Format code (ruff + doc-builder)
make style

# Check code quality without modifying files
make quality

# Fast fixup for modified files only (recommended before commits)
make fixup

# Fix copied code snippets and dummy objects
make fix-copies

# Check repository consistency (dummies, inits, repo structure)
make repo-consistency
```
## Code Architecture

Diffusers is built on three core component types that work together:

### Pipelines (`src/diffusers/pipelines/`)
- End-to-end inference workflows combining models and schedulers
- Base class: `DiffusionPipeline` (in `pipeline_utils.py`)
- Follow the **single-file policy**: each pipeline lives in its own directory
- Loaded via `DiffusionPipeline.from_pretrained()`, which reads `model_index.json`
- Components registered via `register_modules()` become pipeline attributes
- ~99 pipeline implementations (Stable Diffusion, SDXL, Flux, etc.)

### Models (`src/diffusers/models/`)
- Configurable neural network architectures extending PyTorch's `nn.Module`
- Base classes: `ModelMixin` + `ConfigMixin` (in `modeling_utils.py`)
- **Do NOT follow the single-file policy**: they use shared building blocks (`attention.py`, `embeddings.py`, `resnet.py`)
- Key subdirectories:
  - `autoencoders/`: VAEs for latent-space compression
  - `unets/`: Diffusion model architectures (`UNet2DConditionModel`, etc.)
  - `transformers/`: Transformer-based models (Flux, SD3, etc.)
  - `controlnets/`: ControlNet variants

### Schedulers (`src/diffusers/schedulers/`)
- Guide the denoising process during inference
- Base classes: `SchedulerMixin` + `ConfigMixin` (in `scheduling_utils.py`)
- Follow the **single-file policy**: one scheduler per file
- Key API: `set_timesteps()`, `step()`, and the `timesteps` attribute
- Easily swappable via `ConfigMixin.from_config()`
- ~55 scheduler algorithms (DDPM, DDIM, Euler, DPM-Solver, etc.)
### Supporting Systems

- **Loaders** (`src/diffusers/loaders/`): Mixins for LoRA, IP-Adapter, textual inversion, and single-file loading
- **Quantizers** (`src/diffusers/quantizers/`): BitsAndBytes, GGUF, TorchAO, and Quanto support
- **Hooks** (`src/diffusers/hooks/`): Runtime optimizations (offloading, layer skipping, caching)
- **Guiders** (`src/diffusers/guiders/`): Guidance algorithms (CFG, PAG, etc.)
## Configuration System

All components use `ConfigMixin` for serialization:
- Constructor arguments are stored via `register_to_config(**kwargs)`
- Instantiate from config: `Component.from_config(config_dict)`
- Configs are saved/loaded as JSON files
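The capture-and-rebuild mechanics can be illustrated with a small standalone sketch (plain Python, not the actual `diffusers` implementation; `TinyComponent` and this `register_to_config` are made up for illustration):

```python
import inspect
import json


def register_to_config(init):
    """Capture __init__ arguments into self._config (simplified sketch)."""
    sig = inspect.signature(init)

    def wrapper(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        # Drop `self`; everything else becomes the serializable config.
        self._config = {k: v for k, v in bound.arguments.items() if k != "self"}
        init(self, *args, **kwargs)

    return wrapper


class TinyComponent:
    @register_to_config
    def __init__(self, in_channels: int = 3, hidden_dim: int = 64):
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim

    @classmethod
    def from_config(cls, config: dict) -> "TinyComponent":
        return cls(**config)


comp = TinyComponent(hidden_dim=128)
print(json.dumps(comp._config))   # {"in_channels": 3, "hidden_dim": 128}
rebuilt = TinyComponent.from_config(comp._config)
print(rebuilt.hidden_dim)         # 128
```

This is the round trip the bullet points above describe: defaults and overrides are captured at construction time, serialized as JSON, and used to reconstruct an equivalent component.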
## Key Design Principles

1. **Usability over performance**: Models load at float32 on CPU by default
2. **Simple over easy**: Explicit > implicit; expose complexity rather than hide it
3. **Single-file policy**: Pipelines and schedulers are self-contained; models share building blocks
4. **Copy-paste over abstraction**: Prefer duplicated code over hasty abstractions for contributor-friendliness

## Code Style

- Uses `ruff` for linting and formatting (line length: 119)
- Documentation follows [Google style](https://google.github.io/styleguide/pyguide.html)
- Use the `# Copied from` mechanism for sharing code between similar files
- Avoid lambda functions and advanced PyTorch operators for readability

## Testing

- Tests use `pytest` with `pytest-xdist` for parallelization
- Slow tests are gated by the `RUN_SLOW=yes` environment variable
- Test dependencies: `pip install -e ".[test]"`
_modular_model_index.json (new file, 75 lines)
@@ -0,0 +1,75 @@
{
  "_blocks_class_name": "SequentialPipelineBlocks",
  "_class_name": "Flux2ModularPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler",
    {
      "repo": "hf-internal-testing/tiny-flux2",
      "revision": null,
      "subfolder": "scheduler",
      "type_hint": [
        "diffusers",
        "FlowMatchEulerDiscreteScheduler"
      ],
      "variant": null
    }
  ],
  "text_encoder": [
    "transformers",
    "Mistral3ForConditionalGeneration",
    {
      "repo": "hf-internal-testing/tiny-flux2",
      "revision": null,
      "subfolder": "text_encoder",
      "type_hint": [
        "transformers",
        "Mistral3ForConditionalGeneration"
      ],
      "variant": null
    }
  ],
  "tokenizer": [
    "transformers",
    "AutoProcessor",
    {
      "repo": "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor",
      "revision": null,
      "subfolder": "",
      "type_hint": [
        "transformers",
        "AutoProcessor"
      ],
      "variant": null
    }
  ],
  "transformer": [
    "diffusers",
    "Flux2Transformer2DModel",
    {
      "repo": "hf-internal-testing/tiny-flux2",
      "revision": null,
      "subfolder": "transformer",
      "type_hint": [
        "diffusers",
        "Flux2Transformer2DModel"
      ],
      "variant": null
    }
  ],
  "vae": [
    "diffusers",
    "AutoencoderKLFlux2",
    {
      "repo": "hf-internal-testing/tiny-flux2",
      "revision": null,
      "subfolder": "vae",
      "type_hint": [
        "diffusers",
        "AutoencoderKLFlux2"
      ],
      "variant": null
    }
  ]
}
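Each component entry in such an index is a `[library, class_name, loading_spec]` triple. A minimal sketch of how a loader might walk an index of this shape (illustrative only, not diffusers' actual loading code; the inline index below is abridged):

```python
import json

# Abridged index for illustration; real modular_model_index.json files
# carry more components and fields.
index = json.loads("""
{
  "_class_name": "Flux2ModularPipeline",
  "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler",
                {"repo": "hf-internal-testing/tiny-flux2", "subfolder": "scheduler"}],
  "vae": ["diffusers", "AutoencoderKLFlux2",
          {"repo": "hf-internal-testing/tiny-flux2", "subfolder": "vae"}]
}
""")

components = {}
for name, value in index.items():
    if name.startswith("_"):
        continue  # metadata keys such as _class_name, not components
    library, class_name, spec = value
    components[name] = {"library": library, "class_name": class_name, **spec}

print(sorted(components))               # ['scheduler', 'vae']
print(components["vae"]["class_name"])  # AutoencoderKLFlux2
```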
custom_model_automodel_guide.md (new file, 239 lines)
@@ -0,0 +1,239 @@
# Loading Custom Models with `AutoModel` and `trust_remote_code`

This guide shows how to create a custom model class that lives outside the `diffusers` library and load it via `AutoModel` with `trust_remote_code=True`.

## How It Works

When `AutoModel.from_pretrained()` (or `from_config()`) is called with `trust_remote_code=True`, it:

1. Loads the `config.json` from the model repository.
2. Checks for an `"auto_map"` key in the config that maps `"AutoModel"` to a `"<module_file>.<ClassName>"` reference.
3. Downloads the referenced Python module from the repository.
4. Dynamically imports and instantiates the class from that module.

This lets anyone define and share completely custom model architectures without changing the `diffusers` library itself.
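The resolution steps above can be sketched end-to-end with the standard library (a simplified stand-in for the real loader, which also handles downloading and caching; `DemoModel` and the file names are made up):

```python
import importlib.util
import json
import pathlib
import tempfile

# Build a tiny fake "model repository" with a config.json and a module file.
repo = pathlib.Path(tempfile.mkdtemp())
(repo / "modeling_demo.py").write_text(
    "class DemoModel:\n"
    "    def __init__(self, scale=2):\n"
    "        self.scale = scale\n"
)
(repo / "config.json").write_text(
    json.dumps({"auto_map": {"AutoModel": "modeling_demo.DemoModel"}, "scale": 3})
)

# Step 1-2: read the config and resolve the auto_map reference.
config = json.loads((repo / "config.json").read_text())
module_name, class_name = config["auto_map"]["AutoModel"].split(".")

# Step 3-4: import the module from its file path and instantiate the class.
spec = importlib.util.spec_from_file_location(module_name, repo / f"{module_name}.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
cls = getattr(module, class_name)

model = cls(scale=config["scale"])
print(type(model).__name__, model.scale)  # DemoModel 3
```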
## Step 1: Define Your Custom Model

Create a Python file (e.g., `modeling_my_model.py`) that defines your model class. The class must inherit from `ModelMixin` and `ConfigMixin` and use the `@register_to_config` decorator on `__init__`.

```python
# modeling_my_model.py

import torch
from torch import nn

from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config


class MyCustomModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```

Key requirements:

- **`ModelMixin`** provides `save_pretrained()` / `from_pretrained()` for weight serialization.
- **`ConfigMixin`** provides `save_config()` / `from_config()` and the `config.json` machinery.
- **`@register_to_config`** automatically captures all `__init__` parameters into `config.json` so the model can be reconstructed from the config alone.
## Step 2: Save the Model Locally

```python
from modeling_my_model import MyCustomModel

model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3)
model.save_pretrained("./my-custom-model")
```

This creates a directory with:

```
my-custom-model/
├── config.json
└── diffusion_pytorch_model.safetensors
```

The generated `config.json` will look like:

```json
{
  "_class_name": "MyCustomModel",
  "_diffusers_version": "0.32.0",
  "in_channels": 3,
  "hidden_dim": 128,
  "out_channels": 3
}
```
## Step 3: Add the `auto_map` and Model File to the Repository

To make `AutoModel` aware of your custom class, you need to:

1. **Copy `modeling_my_model.py` into the saved model directory.**
2. **Add an `"auto_map"` entry to `config.json`** that points `AutoModel` to your class.

The `auto_map` value format is `"<module_name_without_.py>.<ClassName>"`:

```json
{
  "_class_name": "MyCustomModel",
  "_diffusers_version": "0.32.0",
  "in_channels": 3,
  "hidden_dim": 128,
  "out_channels": 3,
  "auto_map": {
    "AutoModel": "modeling_my_model.MyCustomModel"
  }
}
```

Your final directory structure should be:

```
my-custom-model/
├── config.json                          # with auto_map added
├── diffusion_pytorch_model.safetensors
└── modeling_my_model.py                 # your custom model code
```
## Step 4: Load with `AutoModel`

### From a Local Directory

```python
from diffusers import AutoModel

model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True)
print(model)
```

### From the Hugging Face Hub

First, push the model directory to a Hub repository:

```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("your-username/my-custom-model", exist_ok=True)
api.upload_folder(
    folder_path="./my-custom-model",
    repo_id="your-username/my-custom-model",
)
```

Then load it:

```python
from diffusers import AutoModel

model = AutoModel.from_pretrained(
    "your-username/my-custom-model",
    trust_remote_code=True,
)
```

### Initializing from Config (Random Weights)

```python
from diffusers import AutoModel

model = AutoModel.from_config("./my-custom-model", trust_remote_code=True)
```
## Complete Example

```python
import torch
from torch import nn

from diffusers import ModelMixin, ConfigMixin, AutoModel
from diffusers.configuration_utils import register_to_config


# 1. Define (the same class definition should also live in modeling_my_model.py,
#    since step 3 copies that file into the model directory)
class MyCustomModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, in_channels: int = 3, hidden_dim: int = 64, out_channels: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# 2. Save
model = MyCustomModel(in_channels=3, hidden_dim=128, out_channels=3)
model.save_pretrained("./my-custom-model")

# 3. Manually add auto_map to config.json and copy the modeling file
import json
import shutil

config_path = "./my-custom-model/config.json"
with open(config_path) as f:
    config = json.load(f)

config["auto_map"] = {"AutoModel": "modeling_my_model.MyCustomModel"}

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)

shutil.copy("modeling_my_model.py", "./my-custom-model/modeling_my_model.py")

# 4. Load via AutoModel
loaded_model = AutoModel.from_pretrained("./my-custom-model", trust_remote_code=True)

# 5. Verify
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    out_original = model(x)
    out_loaded = loaded_model(x)

assert torch.allclose(out_original, out_loaded)
print("Models produce identical outputs!")
```
## Using Relative Imports in Custom Code

If your custom model depends on additional modules, you can use relative imports. For example, if your model uses a custom attention layer defined in a separate file:

```
my-custom-model/
├── config.json
├── diffusion_pytorch_model.safetensors
├── modeling_my_model.py    # imports from .my_attention
└── my_attention.py         # custom attention implementation
```

In `modeling_my_model.py`:

```python
from .my_attention import MyAttention
```

The dynamic module loader will automatically resolve and download all relatively imported files.
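A simplified illustration of how such resolution can work: scan the module source for `from .<module> import ...` statements to learn which sibling files are needed (a sketch of the idea only; the actual diffusers loader is more thorough):

```python
import re

# Find the sibling modules a file pulls in via relative imports, so a
# loader knows which additional files to fetch from the repository.
RELATIVE_IMPORT_RE = re.compile(r"^\s*from\s+\.(\w+)\s+import\s+", re.MULTILINE)

source = """
from .my_attention import MyAttention
from .utils import helper
import torch
"""

deps = RELATIVE_IMPORT_RE.findall(source)
print(deps)  # ['my_attention', 'utils']
```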
## Security Note

`trust_remote_code=True` executes arbitrary Python code from the model repository. Only use it with repositories you trust. You can globally disable remote code execution by setting the environment variable:

```bash
export DIFFUSERS_DISABLE_REMOTE_CODE=1
```
@@ -97,32 +97,5 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
> )
> ```

### Saving custom models

Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.

```py
# my_model.py
from diffusers import ModelMixin, ConfigMixin


class MyCustomModel(ModelMixin, ConfigMixin):
    ...


MyCustomModel.register_for_auto_class("AutoModel")

model = MyCustomModel(...)
model.save_pretrained("./my_model")
```

The saved `config.json` will include the `auto_map` field:

```json
{
  "auto_map": {
    "AutoModel": "my_model.MyCustomModel"
  }
}
```

> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
example.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    ContextParallelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)


enable_full_determinism()


class QwenImageTransformerTesterConfig:
    model_class = QwenImageTransformer2DModel
    pretrained_model_name_or_path = ""
    pretrained_model_kwargs = {"subfolder": "transformer"}

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int]]:
        # __init__ parameters:
        #   patch_size: int = 2
        #   in_channels: int = 64
        #   out_channels: Optional[int] = 16
        #   num_layers: int = 60
        #   attention_head_dim: int = 128
        #   num_attention_heads: int = 24
        #   joint_attention_dim: int = 3584
        #   guidance_embeds: bool = False
        #   axes_dims_rope: Tuple[int, int, int] = <complex>
        return {}

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        # forward() parameters:
        #   hidden_states: torch.Tensor
        #   encoder_hidden_states: torch.Tensor
        #   encoder_hidden_states_mask: torch.Tensor
        #   timestep: torch.LongTensor
        #   img_shapes: Optional[List[Tuple[int, int, int]]]
        #   txt_seq_lens: Optional[List[int]]
        #   guidance: torch.Tensor
        #   attention_kwargs: Optional[Dict[str, Any]]
        #   controlnet_block_samples
        #   return_dict: bool = True
        # TODO: Fill in dummy inputs
        return {}

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (1, 1)

    @property
    def output_shape(self) -> tuple[int, ...]:
        return (1, 1)


class TestQwenImageTransformerModel(QwenImageTransformerTesterConfig, ModelTesterMixin):
    pass


class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
    pass


class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
    pass


class TestQwenImageTransformerTorchCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}


class TestQwenImageTransformerLora(QwenImageTransformerTesterConfig, LoraTesterMixin):
    pass


class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
    pass


class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
    pass


class TestQwenImageTransformerLoraHotSwappingForModel(
    QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin
):
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        # TODO: Implement dynamic input generation
        return {}
modular_model_index.json (new file, 73 lines)
@@ -0,0 +1,73 @@
{
  "_blocks_class_name": "SequentialPipelineBlocks",
  "_class_name": "Flux2ModularPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "scheduler": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev",
      "revision": null,
      "subfolder": "scheduler",
      "type_hint": [
        "diffusers",
        "FlowMatchEulerDiscreteScheduler"
      ],
      "variant": null
    }
  ],
  "text_encoder": [
    null,
    null,
    {
      "revision": null,
      "subfolder": "text_encoder",
      "type_hint": [
        "transformers",
        "Mistral3ForConditionalGeneration"
      ],
      "variant": null
    }
  ],
  "tokenizer": [
    null,
    null,
    {
      "revision": null,
      "subfolder": "tokenizer",
      "type_hint": [
        "transformers",
        "AutoProcessor"
      ],
      "variant": null
    }
  ],
  "transformer": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "diffusers/FLUX.2-dev-bnb-4bit",
      "revision": null,
      "subfolder": "transformer",
      "type_hint": [
        "diffusers",
        "Flux2Transformer2DModel"
      ],
      "variant": null
    }
  ],
  "vae": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev",
      "revision": null,
      "subfolder": "vae",
      "type_hint": [
        "diffusers",
        "AutoencoderKLFlux2"
      ],
      "variant": null
    }
  ]
}
pr_review/12498.md (new file, 56 lines)
@@ -0,0 +1,56 @@
## Code Review: GGUF fix for unquantized types

### 1. Summary of Changes

The PR fixes a bug in the `_fused_mul_mat_gguf` function (lines 79-105) where unquantized GGUF tensor types (F32, F16, BF16) were incorrectly handled.

**Before:** When `qweight_type` was an unquantized type, the code performed the matrix multiplication directly: `x @ qweight.T`

**After:** It now calls `dequantize_gguf_tensor(qweight)` first, then performs the matrix multiplication: `x @ weight.T`

The issue was that even "unquantized" GGUF tensors are stored in an 8-bit tensor format and need to be converted to their proper data type representation before use.
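The storage issue can be seen with a tiny standalone example: a float16 value serialized to raw bytes is indistinguishable from 8-bit integer data until it is reinterpreted with the correct dtype (an illustration only, not the GGUF code path):

```python
import struct

value = 1.5
raw = struct.pack("<e", value)          # float16 stored as 2 raw bytes

as_uint8 = list(raw)                    # what 8-bit storage "sees"
restored = struct.unpack("<e", raw)[0]  # correct reinterpretation

print(as_uint8)   # [0, 62]
print(restored)   # 1.5
```

Matrix-multiplying against the `[0, 62]` view instead of the restored `1.5` is exactly the class of error the dequantize call prevents.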
### 2. Potential Issues or Bugs

**None identified.** The fix is correct and addresses a real bug:

- The `dequantize_gguf_tensor` function (lines 509-527) checks whether the tensor has a `quant_type` attribute and performs the appropriate conversion
- For BF16 specifically, there is a dedicated `dequantize_blocks_BF16` function (lines 428-429) that properly converts the 8-bit storage format
- The fix aligns with how the native path already works in `forward_native` (lines 593-599), which always calls `dequantize_gguf_tensor`

### 3. Code Quality Observations

**Strengths:**
- The fix is minimal and surgical: it changes only what is necessary
- It maintains consistency with the `forward_native` path, which already uses `dequantize_gguf_tensor`
- The variable naming (`weight` instead of reusing `qweight`) makes it clear that a transformation occurred

**Minor observation:**
- The comment on line 80, "there is no need to call any kernel for fp16/bf16", is now slightly misleading, since dequantization logic DOES need to run. Consider updating it to something like: "no need to call a specialized GGUF kernel for fp16/bf16, but we still need to dequantize from 8-bit storage"

### 4. Security Considerations

**No security concerns.** The change:
- Doesn't introduce any external input handling
- Doesn't modify control flow in a way that could bypass security checks
- Only fixes a data type conversion issue

### 5. Suggestions for Improvement

1. **Update the comment** on line 80 in `src/diffusers/quantizers/gguf/utils.py`:
```python
# unquantized types still need dequantization from 8-bit storage, but don't need specialized kernels
if qweight_type in UNQUANTIZED_TYPES:
    weight = dequantize_gguf_tensor(qweight)
    return x @ weight.T
```

2. **Consider adding a test** to prevent regression. The test should verify that unquantized GGUF tensors produce correct output shapes and values.

3. **Documentation:** The PR description mentions torch 2.8/2.9 build availability. This might be worth tracking in a GitHub issue if not already done.

### Verdict

**Approve, with the minor comment-update suggestion.** The fix correctly addresses a real shape mismatch bug where GGUF's 8-bit storage format wasn't being converted for unquantized types. The logic is sound and aligns with the existing native implementation path.
pr_review/12744.md (new file, 186 lines)
@@ -0,0 +1,186 @@
I'll provide a comprehensive code review of this MagCache PR.

## Summary of Changes

This PR implements MagCache (Magnitude-aware Cache), a training-free inference acceleration technique for diffusion transformers. The implementation:

- Adds a `MagCacheConfig` class for configuration
- Implements `MagCacheHeadHook` and `MagCacheBlockHook` following the existing ModelHook pattern
- Includes a calibration mode to compute magnitude ratios for any transformer model
- Provides pre-computed `FLUX_MAG_RATIOS` for Flux models
- Adds comprehensive documentation and tests

## Potential Issues and Bugs

### 1. **Critical: Missing Hook Removal in `disable_cache()`**
```python
# In cache_utils.py, line ~127
elif isinstance(self._cache_config, MagCacheConfig):
    registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
```

**Issue**: The code only removes the leader/head hook, not the block hooks (`_MAG_CACHE_BLOCK_HOOK`). This leaves hooks attached when the cache is disabled.

**Fix**: Add removal of the block hooks:
```python
elif isinstance(self._cache_config, MagCacheConfig):
    registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
    registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
```

### 2. **Shape Mismatch Handling Logic Issue**
In `mag_cache.py` (lines 224-248), the shape-mismatch handling has a potential issue:

```python
elif (
    output.ndim == 3
    and res.ndim == 3
    and output.shape[0] == res.shape[0]
    and output.shape[2] == res.shape[2]
):
    diff = output.shape[1] - res.shape[1]
    if diff > 0:
        output = output.clone()
        output[:, diff:, :] = output[:, diff:, :] + res
```

**Issue**: This assumes text tokens come first and image tokens come last, which may not hold for every model (some models interleave tokens differently).

**Suggestion**: Add a comment explaining this assumption, or add configuration to specify the concatenation strategy.

### 3. **Residual Calculation Fallback Is Unsafe**
In `mag_cache.py` (line 343):

```python
else:
    # Fallback for completely mismatched shapes
    residual = out_hidden
```

**Issue**: This fallback doesn't compute a residual at all; it just uses the output, which will cause incorrect behavior in subsequent steps.

**Suggestion**: Either raise an error or warn that calibration is required for this model architecture.

### 4. **Device Mismatch Handling Is Incomplete**
```python
if res.device != output.device:
    res = res.to(output.device)
```

**Issue**: This handles device mismatches for the residual but not dtype mismatches, which can occur with mixed precision.

**Suggestion**: Handle dtype as well:
```python
if res.device != output.device or res.dtype != output.dtype:
    res = res.to(device=output.device, dtype=output.dtype)
```

### 5. **Calibration Logging Could Be Missed**
The calibration results are printed to stdout (line 380) and logged. If the user has logging disabled or redirected, they may miss this critical information.

**Suggestion**: Consider returning calibration results from the pipeline or raising a more visible notification.

### 6. **Test Suite Is Skipped**
```python
@unittest.skip("MagCache unit tests are skipped.")
class MagCacheTests(unittest.TestCase):
```

**Issue**: All unit tests are skipped, so the core logic isn't validated in CI.

**Action Required**: Remove the skip decorator before merging, or add a comment explaining why it's temporarily skipped.

## Code Quality Observations

### Strengths:
1. **Well-structured**: Follows existing patterns (ModelHook, StateManager) consistently
2. **Good documentation**: Comprehensive docstrings and inline comments
3. **Calibration mode**: Clever design allowing model-agnostic usage
4. **Error handling**: Validates configuration upfront
5. **Interpolation logic**: Smart handling of different step counts via `nearest_interp()`
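For reference, nearest-neighbor resampling of a ratio schedule can be sketched as follows (a plausible reading of what `nearest_interp()` does; the PR's actual helper may differ):

```python
def nearest_interp(values: list[float], num_steps: int) -> list[float]:
    """Resample a calibrated schedule to num_steps entries by nearest-neighbor lookup."""
    if num_steps <= 1:
        return values[:1] * num_steps
    n = len(values)
    # Map each target step onto the closest calibrated index.
    return [values[min(n - 1, round(i * (n - 1) / (num_steps - 1)))] for i in range(num_steps)]


calibrated = [1.0, 0.98, 0.95, 0.90]  # magnitude ratios recorded over 4 steps
print(nearest_interp(calibrated, 8))  # [1.0, 1.0, 0.98, 0.98, 0.95, 0.95, 0.9, 0.9]
```

This is what lets ratios calibrated at one inference step count be reused when the user runs with a different number of steps.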
### Areas for Improvement:
|
||||
|
||||
1. **Magic Numbers**: Several hardcoded values could be constants:
|
||||
```python
|
||||
eps = 1e-8 # Line 335 in _perform_calibration_step
|
||||
expected_atol = 0.1 # Line 2989 in test
|
||||
```
|
||||
|
||||
2. **Code Duplication**: The logic for handling tuple returns appears multiple times. Consider extracting to a helper method.
|
||||
|
||||
3. **Type Hints**: Some methods lack return type hints (e.g., `nearest_interp`)
|
||||
|
||||
4. **Compiler Disable Decorator**: The `@torch.compiler.disable` decorator is used but not explained. Add a comment about why compilation is disabled.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Low Risk:
|
||||
- No external network calls
|
||||
- No file system access beyond logging
|
||||
- No execution of arbitrary code
|
||||
- Tensor operations are standard PyTorch
|
||||
|
||||
### Observations:
|
||||
1. **Device Transfer**: The `.to(device)` calls are safe but could consume unexpected memory if tensors are large
|
||||
2. **State Management**: The state is properly isolated and reset between inference runs
|
||||
|
||||
## Suggestions for Improvement

### 1. Add Configuration Validation
```python
def __post_init__(self):
    # Existing checks...

    # Add bounds checking
    if not 0.0 <= self.retention_ratio <= 1.0:
        raise ValueError(f"retention_ratio must be in [0, 1], got {self.retention_ratio}")
    if self.max_skip_steps < 1:
        raise ValueError(f"max_skip_steps must be >= 1, got {self.max_skip_steps}")
    if self.threshold <= 0:
        raise ValueError(f"threshold must be positive, got {self.threshold}")
```

### 2. Add Metrics/Statistics
Consider adding optional statistics collection:
- How many blocks were skipped per step
- Average accumulated error
- Total compute savings

This would help users optimize their thresholds.
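A sketch of what such optional statistics collection might look like (all names here are illustrative, not part of the PR):

```python
from dataclasses import dataclass, field

@dataclass
class CacheStats:
    """Illustrative statistics container for a caching hook (not from the PR)."""
    skipped_steps: int = 0
    computed_steps: int = 0
    accumulated_errors: list = field(default_factory=list)

    def record(self, skipped: bool, accumulated_error: float) -> None:
        if skipped:
            self.skipped_steps += 1
        else:
            self.computed_steps += 1
        self.accumulated_errors.append(accumulated_error)

    @property
    def skip_ratio(self) -> float:
        """Fraction of steps served from cache, a rough proxy for compute savings."""
        total = self.skipped_steps + self.computed_steps
        return self.skipped_steps / total if total else 0.0
```

Reporting `skip_ratio` and the mean of `accumulated_errors` after a run would give users concrete feedback for tuning the threshold.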
### 3. Improve Documentation Example
The documentation example could show expected speedup or quality metrics to set user expectations.

### 4. Add Gradient Mode Check
```python
if torch.is_grad_enabled():
    logger.warning("MagCache is designed for inference only. Gradients are enabled but will not flow correctly through cached blocks.")
```

### 5. Consider Memory Cleanup
The `previous_residual` is held in state indefinitely. Consider adding explicit cleanup:
```python
def cleanup(self):
    if self.previous_residual is not None:
        del self.previous_residual
        self.previous_residual = None
```

## Minor Issues

1. **Line 26**: Unused import, or it should be used in logger initialization
2. **Line 332**: Comment says "Fallback to matching tail" but the logic is unclear
3. **Documentation**: The TIP about batched CFG could include more detail about why this works

## Conclusion

This is a **well-implemented feature** with good design patterns and documentation. The main concerns are:

1. **Critical**: Fix the missing block hook removal in `disable_cache()` (Line 127)
2. **Important**: Unskip and fix the unit tests
3. **Recommended**: Improve shape mismatch handling with better error messages

The implementation is production-ready once these issues are addressed. The calibration mode is particularly clever and makes this genuinely model-agnostic.

**Recommendation**: Request changes for items #1 and #2, then approve once fixed.
pr_review/13028.md (new file, 99 lines)
# PR #13028: [Modular] add explicit workflow support

**Author:** @yiyixuxu
**Branch:** `modular-workflow` -> `main`
**Files changed:** `modular_pipeline.py`, `modular_pipeline_utils.py`, `qwenimage/modular_blocks_qwenimage.py`
**+298 / -165**

---

## Summary

This PR adds a `_workflow_map` class attribute to `SequentialPipelineBlocks` that maps named workflows (e.g., `"text2image"`, `"inpainting"`) to their trigger inputs. Users can then call `get_workflow("text2image")` to get the execution blocks for that workflow. The PR also refactors `get_execution_blocks` into `ConditionalPipelineBlocks` and `SequentialPipelineBlocks`, moves `combine_inputs`/`combine_outputs` to module-level functions, and improves docstrings.

## Main Concern: "Workflow" as a New Concept

Modular Diffusers already requires users to learn: **Pipelines**, **Blocks** (Sequential, Conditional, Auto, Loop), **Steps**, **Components**, **Inputs/Outputs**, **Trigger Inputs**, **Execution Blocks**, **PipelineState**, and **BlockState**. Adding "workflow" as yet another term increases cognitive overhead.

The underlying feature is useful — named presets for trigger inputs are genuinely helpful for discoverability. But "workflow" may not be the right label:

1. **Overloaded term**: "Workflow" is heavily used in the AI/ML ecosystem (ComfyUI workflows, orchestration workflows, CI/CD workflows). Users may expect something more complex than what this is.

2. **It's really a task/mode, not a workflow**: `"text2image"`, `"inpainting"`, `"image2image"` are *tasks* or *modes*. The rest of diffusers already uses "task" terminology — `AutoPipelineForText2Image`, `AutoPipelineForInpainting`, etc. Calling the same concept "workflow" in Modular Diffusers creates inconsistency.

3. **It's a thin wrapper**: `get_workflow("text2image")` is just `get_execution_blocks(prompt=True)`. Users still need to understand `get_execution_blocks` and trigger inputs to do anything beyond the predefined workflows. The abstraction doesn't save much complexity.

**Suggestion**: Consider `_task_map` / `get_task()` / `task_names` to align with existing diffusers terminology, or `_mode_map` / `get_mode()` / `mode_names` for something more neutral. The existing `auto_pipeline.py` already uses "task" internally — `_get_task_class()` maps pipeline class names to task-specific variants (text2image, image2image, inpainting), and the public API follows the `AutoPipelineFor<Task>` naming pattern. These are the exact same concepts this PR calls "workflows." Alternatively, this could simply be better documentation on `get_execution_blocks` with named examples, rather than a new API surface.
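To make the thin-wrapper relationship concrete, here is a toy sketch of the pattern under discussion (the class and map contents are invented for illustration; this is not the actual Modular Diffusers API):

```python
class BlocksSketch:
    """Toy stand-in for SequentialPipelineBlocks, illustration only."""

    # Named presets mapping a task/workflow name to its trigger inputs.
    _workflow_map = {
        "text2image": {"prompt": True},
        "inpainting": {"prompt": True, "image": True, "mask_image": True},
    }

    def get_execution_blocks(self, **trigger_inputs):
        # The real method traverses blocks; here we just report active triggers.
        return sorted(name for name, active in trigger_inputs.items() if active)

    def get_workflow(self, name):
        # The new API surface amounts to a dictionary lookup plus a delegated call.
        if name not in self._workflow_map:
            raise ValueError(f"Unknown workflow {name!r}, available: {sorted(self._workflow_map)}")
        return self.get_execution_blocks(**self._workflow_map[name])
```

`get_workflow("text2image")` resolves to the same call a user could already make directly, which is the core of the thin-wrapper concern.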
## Code Issues

### Behavioral change: `outputs` -> `intermediate_outputs` in traversal

`modular_pipeline.py` — In `SequentialPipelineBlocks.get_execution_blocks`, the old `_traverse_trigger_blocks` tracked `block.outputs` to propagate available values to downstream blocks. The new code tracks `block.intermediate_outputs` instead:

```python
# Old
if hasattr(block, "outputs"):
    for out in block.outputs:
        active_inputs[out.name] = True

# New
if hasattr(block, "intermediate_outputs"):
    for out in block.intermediate_outputs:
        active_inputs[out.name] = True
```

`intermediate_outputs` and `outputs` can differ — `intermediate_outputs` includes values passed between blocks in the pipeline state, while `outputs` are the final outputs. This could change which downstream conditional blocks get triggered. If this is intentional, it should be called out explicitly in the PR description since it affects existing behavior.

### `_workflow_map` on base class, implementations only on `SequentialPipelineBlocks`

`_workflow_map = None` is defined on `ModularPipelineBlocks` (the base class), but `workflow_names` and `get_workflow()` are only implemented on `SequentialPipelineBlocks`. The base class stubs raise `NotImplementedError`. This is misleading — it suggests workflows *could* be implemented for other block types. If workflows are intentionally only for `SequentialPipelineBlocks`, define `_workflow_map` there and don't add stubs to the base class.

### `get_execution_blocks` no longer filters None values

Old code:
```python
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
```

New code:
```python
active_inputs = dict(kwargs)
```

This is a behavioral change to the public `get_execution_blocks` API. The old code explicitly stripped `None` values so users could write `get_execution_blocks(prompt="a cat", image=None)` and `image` wouldn't trigger anything. The new code passes `None` through. It happens to still work because `select_block` checks `is not None` internally, but callers can no longer rely on the documented filtering behavior. This should be noted.

### `default_block_name` changed from property to instance attribute

In `AutoPipelineBlocks`, `default_block_name` was a `@property` that derived the default from `block_trigger_inputs` on every access. It's now set as an instance attribute in `__init__`. This is mostly fine, but the new code also adds a check that raises an error if `default_block_name` is already non-None before it's set — so subclasses that accidentally set `default_block_name` as a class attribute will now break. This is a stricter contract that should be documented.
### Typo

`modular_pipeline.py` — `# currentlyonly ConditionalPipelineBlocks` should be `# currently only`.

### `_get_trigger_inputs()` called multiple times in `__repr__`

In `SequentialPipelineBlocks.__repr__`, `self._get_trigger_inputs()` is called 3 times (condition check, trigger inputs display, example input). This recursively traverses all blocks each time. It should be computed once and reused.

### Duplicate `format_workflow` calls in `__repr__` and `doc`

Both `SequentialPipelineBlocks.__repr__` and `SequentialPipelineBlocks.doc` build the description + workflow string independently with identical logic:

```python
description = self.description
if self._workflow_map is not None:
    workflow_str = format_workflow(self._workflow_map)
    description = f"{self.description}\n\n{workflow_str}"
```

This should be extracted into a property or helper.
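A sketch of that extraction (hypothetical names; `format_workflow` is stubbed out here since its real body isn't shown in this review):

```python
def format_workflow(workflow_map):
    # Stub standing in for the real helper in modular_pipeline_utils.py.
    return "Workflows: " + ", ".join(sorted(workflow_map))

class BlocksDescriptionSketch:
    """Illustration only: both __repr__ and doc would read full_description."""

    description = "Sequential blocks"
    _workflow_map = {"text2image": {"prompt": True}}

    @property
    def full_description(self):
        # Single source of truth for the description + workflow string.
        if self._workflow_map is None:
            return self.description
        return f"{self.description}\n\n{format_workflow(self._workflow_map)}"
```

With this property in place, `__repr__` and `doc` would each read `self.full_description` instead of duplicating the conditional.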
### No tests

The PR description mentions "I will add a test suite for this too!" but there are no tests included. Workflow resolution, edge cases (empty workflow map, missing workflow name, workflows with overlapping triggers), and the `get_execution_blocks` refactoring should all be tested before merge.

## Refactoring Quality

The refactoring of `get_execution_blocks` from a monolithic method on `SequentialPipelineBlocks` into separate implementations on `ConditionalPipelineBlocks` and `SequentialPipelineBlocks` is a good separation of concerns. Moving `combine_inputs`/`combine_outputs` to module-level functions is also reasonable since they don't depend on instance state.

The improved `AutoPipelineBlocks` docstring with the example is a significant documentation improvement.
pr_review/13075.md (new file, 97 lines)
I'll review this PR that addresses PyTorch version compatibility for distributed operations.

## Summary of Changes

The PR refactors the `gather_size_by_comm` function in `_modeling_parallel.py` to handle PyTorch versions prior to 2.6 that don't have the `torch.accelerator` API. The changes replace a single ternary expression with a multi-level conditional that:

1. First checks if "cpu" is in the backend string
2. Then checks if `torch.accelerator` exists (PyTorch >= 2.6)
3. Falls back to CUDA as a default device

## Potential Issues or Bugs

**1. Device Type Inconsistency**
The original code returns the string `"cpu"` but the new code returns `torch.device("cuda")` objects. This inconsistency could cause issues:

```python
gather_device = "cpu"  # str
# vs
gather_device = torch.device("cuda")  # torch.device object
```

**Recommendation:** Use `torch.device()` consistently:
```python
if "cpu" in comm_backends:
    gather_device = torch.device("cpu")
elif hasattr(torch, "accelerator"):
    acc = torch.accelerator.current_accelerator()
    gather_device = torch.device(acc if acc is not None else "cuda")
else:
    gather_device = torch.device("cuda")
```

**2. Unclear Accelerator Return Behavior**
The comment states "Fall back to CUDA when no accelerator is returned" but it's unclear when `torch.accelerator.current_accelerator()` would return `None`. This should be verified or documented.

**3. Missing Type Information**
What type does `torch.accelerator.current_accelerator()` return? If it returns a string like `"cuda"` or `"mps"`, the code should handle it consistently. If it returns a device object, the logic might need adjustment.
## Code Quality Observations

**Positive:**
- Clear comments explaining the fallback logic
- Proper use of `hasattr()` for backward compatibility
- Addresses the reported issue #13074

**Areas for Improvement:**

1. **Device type consistency** (mentioned above)

2. **Consider alternative hardware accelerators:** The fallback to CUDA might not be appropriate for all systems (e.g., MPS on macOS, XPU on Intel). Consider:
```python
else:
    # Fallback for PyTorch < 2.6
    if torch.cuda.is_available():
        gather_device = torch.device("cuda")
    else:
        gather_device = torch.device("cpu")
```

3. **Code style:** The expanded conditional is more readable but could benefit from extraction into a helper function if this pattern appears elsewhere:
```python
def _get_gather_device(comm_backends: str) -> torch.device:
    """Determine device for distributed gather operations."""
    # ... implementation
```

## Security Considerations

No significant security issues identified. This is primarily a compatibility fix for internal device selection logic.
## Suggestions for Improvement

1. **Add a test case** to verify behavior on PyTorch < 2.6 (if not already covered)

2. **Document the behavior** more explicitly:
```python
# Determine gather device based on backend and PyTorch version
# Priority: CPU backend > torch.accelerator (>= 2.6) > CUDA fallback (< 2.6)
```

3. **Consider this more defensive approach:**
```python
if "cpu" in comm_backends:
    gather_device = torch.device("cpu")
elif hasattr(torch, "accelerator"):
    acc = torch.accelerator.current_accelerator()
    gather_device = torch.device(acc if acc else "cuda")
elif torch.cuda.is_available():
    gather_device = torch.device("cuda")
else:
    # Fallback to CPU if no GPU available
    gather_device = torch.device("cpu")
```

## Verdict

The PR addresses the compatibility issue but has a **type inconsistency bug** that should be fixed before merging. The string vs `torch.device` object mismatch could cause runtime errors. Once that's addressed, the change is sound for backward compatibility.
pr_review/13116.md (new file, 66 lines)
# PR #13116: [tests] tests for `modules_to_not_convert`

**Author:** @sayakpaul
**Branch:** `fix-modules-no-convert-torchao` -> `main`
**Files changed:** `tests/models/testing_utils/quantization.py`, `tests/models/transformers/test_models_transformer_flux.py`

---

## Summary

This PR fixes the `modules_to_not_convert` tests that were effectively dead code. They existed in the base `QuantizationTesterMixin` but never ran because no test class defined `modules_to_not_convert_for_test`. The PR activates these tests for Flux and fixes several underlying bugs that would have caused them to fail.

## Key Changes

1. **BnB config key fix**: `BitsAndBytesConfig` uses `llm_int8_skip_modules`, not `modules_to_not_convert`. The base test was setting the wrong key, so modules were never actually excluded.

2. **TorchAO `_verify_if_layer_quantized` fix**: Previously only checked `isinstance(module, torch.nn.Linear)`, which is always true for TorchAO (it doesn't replace the module class). Now properly checks weight tensor types (`AffineQuantizedTensor`, `LinearActivationQuantizedTensor`).

3. **`_is_module_quantized` fix**: Now passes `quant_config_kwargs` to `_verify_if_layer_quantized`. Previously it passed `{}`, which caused BnB to always check for `Int8Params` even on 4-bit models.

4. **Cleanup**: Removes unused guard blocks (`is_gguf_available`, `is_torchao_available`) that only contained `pass`.

5. **Activates tests**: Adds `modules_to_not_convert_for_test` returning `["norm_out.linear"]` to BnB, Quanto, TorchAo, and ModelOpt Flux test classes.
## Issues

### `to_not_convert_key` parameter pollutes the base class interface

`quantization.py:271-273` — The new `to_not_convert_key` parameter on `_test_quantization_modules_to_not_convert` exists solely for BnB's naming quirk (`llm_int8_skip_modules` vs `modules_to_not_convert`). Every other backend uses the default. This leaks a BnB-specific detail into the shared base method.

BnB already has its own `test_bnb_modules_to_not_convert` that could handle the key translation locally — either by building the correct `config_kwargs` with `llm_int8_skip_modules` before calling `_create_quantized_model` directly, or by overriding the test. This keeps the base method clean and isolates BnB's naming quirk in `BitsAndBytesTesterMixin` where it belongs.

### Code duplication in TorchAO `test_torchao_modules_to_not_convert`

`quantization.py:915-950` — The TorchAO test inlines ~30 lines from `_test_quantization_modules_to_not_convert` to skip the memory footprint comparison. If the base method is updated in the future, this copy won't get the fix. Consider parameterizing the base method instead:

```python
def _test_quantization_modules_to_not_convert(
    self, config_kwargs, modules_to_not_convert, check_memory_footprint=True,
):
    # ... existing module-walking logic ...

    if check_memory_footprint:
        # Compare memory footprint with fully quantized model
        ...
```

Then TorchAO could simply call:
```python
self._test_quantization_modules_to_not_convert(
    TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude,
    check_memory_footprint=False,
)
```

### TorchAO imports inside method body

`quantization.py:822-823` — The `torchao` imports are placed inside `_verify_if_layer_quantized`. While functional (this avoids import errors when torchao isn't installed), they could be placed at module level under the existing `is_torchao_available()` guard for consistency with how `bnb` and `QLinear` imports are handled. Minor style point.

### `_is_module_quantized` callers not updated

`quantization.py:368` — The `_test_dequantize` method still calls `self._is_module_quantized(module)` without `quant_config_kwargs`. This happens to work correctly (for BnB, checking `Int8Params` after dequantization correctly returns False; for TorchAO, the weight won't be an `AffineQuantizedTensor`), but it means BnB dequantize for 4-bit models asserts the weight is not `Int8Params` rather than asserting it's not `Params4bit`. Consider updating for correctness.

### Missing GGUF test coverage

GGUF's `GGUFTesterMixin` doesn't have a `test_gguf_modules_to_not_convert` method. If GGUF is expected to support `modules_to_not_convert`, a test should be added. If not, a comment explaining why would be helpful.
pr_review/pr_12700_flashpack.md (new file, 144 lines)
# PR #12700 — FlashPack Integration Review

**URL**: https://github.com/huggingface/diffusers/pull/12700
**State**: OPEN
**Branch**: `flashpack` → `main`

## Summary

Adds FlashPack as a new weight serialization format for faster model loading. FlashPack packs model weights into a single contiguous file (`model.flashpack`) that can be loaded efficiently, especially for larger models. The PR integrates it across `ModelMixin` (save/load), `DiffusionPipeline` (save/load/download), and supporting utilities.

## Files Changed

- `setup.py` / `dependency_versions_table.py` — add `flashpack` dependency
- `src/diffusers/utils/constants.py` — `FLASHPACK_WEIGHTS_NAME`, `FLASHPACK_FILE_EXTENSION`
- `src/diffusers/utils/import_utils.py` — `is_flashpack_available()`
- `src/diffusers/utils/__init__.py` — re-exports
- `src/diffusers/models/model_loading_utils.py` — `load_flashpack_checkpoint()`, dispatch in `load_state_dict()`
- `src/diffusers/models/modeling_utils.py` — `save_pretrained(use_flashpack=...)`, `from_pretrained(use_flashpack=..., flashpack_kwargs=...)`
- `src/diffusers/pipelines/pipeline_utils.py` — pipeline-level `save_pretrained`, `from_pretrained`, `download` with `use_flashpack`
- `src/diffusers/pipelines/pipeline_loading_utils.py` — `load_sub_model`, `_get_ignore_patterns`, `get_class_obj_and_candidates`, `filter_model_files`

---

## Issues
### 1. `use_flashpack=True` default in `DiffusionPipeline.download()`

```python
# pipeline_utils.py, in download()
use_flashpack = kwargs.pop("use_flashpack", True)
```

This defaults to `True`, meaning `download()` will always try to download FlashPack files by default. Every other call site defaults to `False`. This looks like a bug — it would change download behavior for all users even if they never asked for FlashPack. Should be `False`.

### 2. `load_flashpack_checkpoint` is unused in the `from_pretrained` path

`load_flashpack_checkpoint()` is added to `model_loading_utils.py` and wired into `load_state_dict()`. However, in `ModelMixin.from_pretrained`, when `use_flashpack=True`, the code **early-returns** after calling `flashpack.mixin.assign_from_file()` directly — it never goes through `load_state_dict()`. So `load_flashpack_checkpoint` is dead code in the `from_pretrained` flow. Either:
- Remove it if FlashPack always uses its own assign path, or
- Use it consistently (load state dict → assign to model, like safetensors/pickle)

### 3. `resolved_model_file` may be undefined when `use_flashpack=True` and file fetch fails

```python
# modeling_utils.py, from_pretrained
elif use_flashpack:
    try:
        resolved_model_file = _get_model_file(...)
    except IOError as e:
        logger.error(...)
        if not allow_pickle:
            raise
        logger.warning("Defaulting to unsafe serialization...")
```

If the `IOError` is caught and `allow_pickle` is truthy, `resolved_model_file` is never set but is used later at `flashpack.mixin.assign_from_file(model=model, path=resolved_model_file[0], ...)`. This would crash with `NameError` or `UnboundLocalError`. The fallback logic (copied from the safetensors block) doesn't make sense for FlashPack — there's no pickle fallback for FlashPack. The `except` block should just re-raise unconditionally.

### 4. `resolved_model_file[0]` assumes a list, but `_get_model_file` returns a string

```python
flashpack.mixin.assign_from_file(
    model=model,
    path=resolved_model_file[0],  # indexing into a string
    ...
)
```

`_get_model_file` returns a single file path (string), not a list. `resolved_model_file[0]` would give the first character of the path. Should be just `resolved_model_file`.
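A one-line demonstration of the indexing bug (the path here is made up for illustration):

```python
resolved_model_file = "models/unet/model.flashpack"  # _get_model_file returns a plain string
# Indexing a str yields its first character, not a list element:
first = resolved_model_file[0]
```

Passing that single character as `path` would fail at open time rather than being caught where the mistake was made.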
### 5. `device_map` handling assumes `device_map[""]` exists

```python
flashpack_device = device_map[""]
```

`device_map` can be a dict with arbitrary keys (layer names, module names), not just `{"": device}`. This would raise `KeyError` for any non-trivial device map. Should handle the general case or document the constraint.

### 6. `FlashPack` prefix stripping in `get_class_obj_and_candidates` is unexplained

```python
if class_name.startswith("FlashPack"):
    class_name = class_name.removeprefix("FlashPack")
```

This is injected into a general-purpose utility function with no explanation of when/why a class name would have a `FlashPack` prefix. This seems like it handles a specific config format, but there's no corresponding code that writes `FlashPack`-prefixed class names. If this is for some external convention, it should be documented. If not needed, remove it.

### 7. Duplicated availability check pattern

The `is_flashpack_available()` check + import + error message pattern is repeated 3 times:
- `load_flashpack_checkpoint()` in `model_loading_utils.py`
- `save_pretrained()` in `modeling_utils.py`
- `from_pretrained()` in `modeling_utils.py`

Each has slightly different wording. Should be consolidated — e.g., a helper or just use a single `require_flashpack()` function, consistent with how other optional deps are handled.

### 8. `save_pretrained` error message says "load" instead of "save"

```python
# modeling_utils.py, save_pretrained, use_flashpack=True branch
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
```

This is in the **save** path, but the message says "load". Should say "save".

### 9. No `config.json` saved alongside FlashPack weights in `save_pretrained`

When `use_flashpack=True` in `ModelMixin.save_pretrained`, the model config is saved normally at the top of the method, but the FlashPack branch calls `flashpack.serialization.pack_to_file()` with `target_dtype=self.dtype`. It's not clear if FlashPack's own `config.json` (mentioned in the benchmark script as `flashpack_config.json`) is the same as diffusers' `config.json`. If they're different files, loading back with `from_pretrained(use_flashpack=True)` might fail to reconstruct the model architecture since `from_config` needs the diffusers config.
### 10. `output_loading_info` warning placement

```python
if output_loading_info:
    logger.warning("`output_loading_info` is not supported with FlashPack.")
    return model, {}
```

This returns an empty dict silently. The warning is fine, but returning `{}` instead of a proper `loading_info` structure (with `missing_keys`, `unexpected_keys`, etc.) could break code that destructures the result.

### 11. No tests included

The PR has no test files. At minimum there should be:
- Unit tests for `load_flashpack_checkpoint` (mocking `flashpack`)
- Unit tests for save/load roundtrip with `use_flashpack=True`
- Integration test for pipeline save/load

### 12. FlashPack doesn't support sharding

The `save_pretrained` FlashPack branch ignores `max_shard_size` entirely and always saves a single file. This is fine for the format but should either:
- Log a warning if `max_shard_size` is explicitly set alongside `use_flashpack=True`
- Document this limitation

---

## Minor Issues

- The benchmark in the PR description shows FlashPack is actually **slower** for fp16 SD v1.5 (0.95x). The claimed benefit is only for bf16. This should be prominently noted.
- `FLASHPACK_WEIGHTS_NAME = "model.flashpack"` breaks the diffusers naming convention (`diffusion_pytorch_model.*` for other formats).
- The PR modifies `_get_ignore_patterns` but doesn't handle the case where both `use_safetensors` and `use_flashpack` are True.
- `filter_model_files` adds `FLASHPACK_WEIGHTS_NAME` to the known weights list but there are no corresponding tests for this filtering.

---

## Verdict

The PR needs significant work before it's mergeable. The critical issues are the `use_flashpack=True` default in `download()`, the `resolved_model_file[0]` indexing bug, the dead code path with `load_flashpack_checkpoint`, and the lack of tests. The integration pattern also doesn't feel consistent with how other formats (safetensors, GGUF) are integrated — FlashPack bypasses the standard state dict loading path entirely via its own `assign_from_file`, making it a special case that's harder to maintain.
pr_review/teacache_pr_12652_review.md (new file, 286 lines)
# TeaCache PR #12652 Review Notes

## PR Overview

- **PR**: https://github.com/huggingface/diffusers/pull/12652
- **Title**: Implement TeaCache
- **Author**: LawJarp-A (Prajwal A)
- **Status**: Open
- **Changes**: +1335 / -22 lines across 6 files

### What is TeaCache?

[TeaCache](https://huggingface.co/papers/2411.19108) (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by **1.5x-2.6x** by reusing transformer block computations when consecutive timestep embeddings are similar.

### Algorithm

1. Extract modulated input from first transformer block (after norm1 + timestep embedding)
2. Compute relative L1 distance vs previous timestep
3. Apply model-specific polynomial rescaling: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]`
4. Accumulate rescaled distance across timesteps
5. If accumulated < threshold → Reuse cached residual (FAST)
6. If accumulated >= threshold → Full transformer pass (SLOW, update cache)
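The steps above can be condensed into a small numeric sketch (pure Python; the coefficients and threshold are placeholders, and the real implementation operates on tensors):

```python
def poly_rescale(x, coeffs):
    """Evaluate c0*x**4 + c1*x**3 + c2*x**2 + c3*x + c4 via Horner's rule."""
    result = 0.0
    for c in coeffs:
        result = result * x + c
    return result

def teacache_step(rel_l1, accumulated, threshold, coeffs):
    """Return (compute_full_pass, new_accumulated) for one timestep."""
    accumulated += poly_rescale(rel_l1, coeffs)
    if accumulated < threshold:
        return False, accumulated  # reuse cached residual (fast path)
    return True, 0.0               # full transformer pass, reset accumulator
```

With identity coefficients `[0, 0, 0, 1, 0]` and a threshold of 0.25, consecutive distances of 0.1 and then 0.2 first take the fast path, then trigger a full pass and reset the accumulator.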
---

## The Mid-Forward Intercept Problem

### Why TeaCache is Model-Specific

TeaCache needs to intercept **within** a model's forward method, not just at module boundaries:

```
┌─────────────────────────────────────────────────────────────┐
│ Model Forward                                               │
│                                                             │
│ PREPROCESSING (must always run)                             │
│   ├── x_embedder(hidden_states)                             │
│   ├── time_text_embed(timestep, ...)                        │
│   └── context_embedder(encoder_hidden_states)               │
│                                                             │
│ ═══════════════════════════════════════════════════════════ │
│ DECISION POINT ◄── TeaCache needs to intercept HERE         │
│   └── Extract: transformer_blocks[0].norm1(hs, temb)[0]     │
│ ═══════════════════════════════════════════════════════════ │
│                                                             │
│ CACHEABLE REGION (can be skipped if cached)                 │
│   ├── for block in transformer_blocks: ...                  │
│   └── for block in single_transformer_blocks: ...           │
│                                                             │
│ POSTPROCESSING (must always run)                            │
│   ├── norm_out(hidden_states, temb)                         │
│   └── proj_out(hidden_states)                               │
└─────────────────────────────────────────────────────────────┘
```

PyTorch hooks only intercept at **module boundaries** (before/after `forward()`), not within a forward method. The `for` loop over blocks is Python control flow - there's no hook point to skip it.

### Workaround: Custom Forward Replacement

The PR replaces the entire model forward with a custom implementation that has cache logic inserted at the right point. This works but requires maintaining separate forward functions for each model.
||||
|
||||
---
|
||||
|
||||
## Comparison of Caching Approaches
|
||||
|
||||
### TeaCache vs FirstBlockCache vs FasterCache
|
||||
|
||||
| Aspect | TeaCache | FirstBlockCache | FasterCache |
|
||||
|--------|----------|-----------------|-------------|
|
||||
| **Hook target** | Model forward | Transformer blocks | Attention layers |
|
||||
| **Decision signal** | Modulated input (norm1 output) | Block output residual | Iteration count |
|
||||
| **Where signal is** | Inside first block | Block boundary | Attention output |
|
||||
| **Model-specific needs** | norm1 structure | Block output format | Attention class type |
|
||||
| **Model-agnostic?** | ❌ No | ✅ Yes | ✅ Yes |
|
||||

### Why FirstBlockCache is Model-Agnostic

FirstBlockCache uses the **first block's output residual** as its signal:

```python
# FirstBlockCache: hooks individual blocks
def new_forward(self, module, *args, **kwargs):
    original_hidden_states = args[0]
    output = self.fn_ref.original_forward(*args, **kwargs)  # Run block fully
    residual = output - original_hidden_states  # Signal from OUTPUT
    should_compute = self._compare_residual(residual)
    ...
```

It doesn't need to understand block internals - just input and output.

### Why FasterCache is Model-Agnostic

FasterCache hooks **attention layers** (not blocks) using class type checking:

```python
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)

for name, submodule in module.named_modules():
    if isinstance(submodule, _ATTENTION_CLASSES):
        # Hook this attention module
        ...
```

All transformer models use standardized attention classes.
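As a toy illustration of the same pattern (the classes below are hypothetical stand-ins, not the real diffusers types), an `isinstance` scan against a shared base class finds every attention submodule regardless of model architecture:

```python
# Hypothetical stand-ins for the real attention classes; the point is only
# that isinstance() against a shared base finds every attention submodule
# without knowing anything else about the model.
class Attention: ...
class MochiAttention(Attention): ...

class Model:
    def __init__(self):
        self.mods = {
            "blocks.0.attn": Attention(),
            "blocks.0.mlp": object(),
            "blocks.1.attn": MochiAttention(),
        }

    def named_modules(self):
        return self.mods.items()

_ATTENTION_CLASSES = (Attention,)

hooked = [name for name, m in Model().named_modules()
          if isinstance(m, _ATTENTION_CLASSES)]
print(hooked)  # ['blocks.0.attn', 'blocks.1.attn']
```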

---

## Model Architecture Analysis

### Models That Fit TeaCache Pattern

Models with `norm1(hidden_states, temb)` returning modulated input:

| Model | norm1 Signature | Modulation Location | Single Residual |
|-------|----------------|---------------------|-----------------|
| FLUX 1 | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ |
| FLUX Kontext | `norm1(hs, emb=temb) → (tensor, gate)` | Inside norm1 | ✅ |
| Mochi | `norm1(hs, temb) → (tensor, g, s, g)` | Inside norm1 | ✅ |
| Lumina2 | `norm1(hs, temb) → (tensor, gate)` | Inside norm1 | ✅ |

### Models That DON'T Fit Pattern

| Model | norm1 Signature | Modulation Location | Issue |
|-------|----------------|---------------------|-------|
| **FLUX 2** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm |
| **Wan** | `norm1(hs) → tensor` | Outside norm1 | Plain LayerNorm |
| **ZImage** | `attention_norm1(x) → tensor` | Outside norm1 | Plain LayerNorm |
| **CogVideoX** | N/A (uses `emb` directly) | N/A | Dual residual needed |

### FLUX 1 vs FLUX 2 Architecture Difference

**FLUX 1** (AdaLayerNorm - modulation inside):
```python
class FluxTransformerBlock:
    self.norm1 = AdaLayerNormZero(dim)  # Takes temb!

    def forward(self, hidden_states, temb, ...):
        norm_hs, gate = self.norm1(hidden_states, emb=temb)  # Modulation inside
```

**FLUX 2** (Plain LayerNorm - modulation outside):
```python
class Flux2TransformerBlock:
    self.norm1 = nn.LayerNorm(dim)  # NO temb!

    def forward(self, hidden_states, temb_mod_params_img, ...):
        (shift_msa, scale_msa, gate_msa), ... = temb_mod_params_img
        norm_hs = self.norm1(hidden_states)  # Plain norm
        norm_hs = (1 + scale_msa) * norm_hs + shift_msa  # Modulation outside
```

FLUX 2 follows the Wan/ZImage pattern and would need a separate custom forward.

---

## CogVideoX: The Architectural Outlier

CogVideoX has two unique requirements that don't fit the pattern:

### 1. Different Modulated Input Source

```python
# Other models: extract from norm1
modulated_inp = block.norm1(hidden_states, temb)[0]

# CogVideoX: uses timestep embedding directly
modulated_inp = emb  # Just the embedding, computed before blocks!
```

### 2. Dual Residual Caching

CogVideoX blocks return and modify TWO tensors:
```python
def forward(self, hidden_states, encoder_hidden_states, temb, ...):
    # Both are modified!
    return hidden_states, encoder_hidden_states
```

Requires caching two residuals:
```python
state.previous_residual = hs_output - hs_input
state.previous_residual_encoder = enc_output - enc_input  # Extra!
```
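On a cache hit both residuals have to be re-applied to the current inputs. A pure-Python sketch of that dual-residual reuse (scalars stand in for tensors; all names here are illustrative, not the PR's actual code):

```python
# Scalars stand in for tensors; the structure mirrors dual-residual caching.
class State:
    previous_residual = None
    previous_residual_encoder = None

def run_blocks(hs, enc):
    # Stand-in for the full (expensive) transformer block loop.
    return hs + 10.0, enc + 2.0

def step(state, hs, enc, use_cache):
    if use_cache:
        # Skip the blocks: re-apply BOTH cached residuals.
        return hs + state.previous_residual, enc + state.previous_residual_encoder
    hs_out, enc_out = run_blocks(hs, enc)
    state.previous_residual = hs_out - hs
    state.previous_residual_encoder = enc_out - enc
    return hs_out, enc_out

s = State()
print(step(s, 1.0, 0.5, use_cache=False))  # (11.0, 2.5): full compute, residuals cached
print(step(s, 2.0, 1.0, use_cache=True))   # (12.0, 3.0): blocks skipped
```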

---

## Recommendations

### Simplification: FLUX-Only Support

Given the architectural diversity, recommend supporting only FLUX 1 and FLUX Kontext initially:

```python
_MODEL_CONFIG = {
    "FluxKontext": {
        "forward_func": _flux_teacache_forward,
        "coefficients": [-1.04655119e03, 3.12563399e02, -1.69500694e01, 4.10995971e-01, 3.74537863e-02],
    },
    "Flux": {
        "forward_func": _flux_teacache_forward,
        "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
    },
}
```
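In the TeaCache algorithm these per-model coefficients define a polynomial (ordered highest degree first) that rescales the raw relative L1 distance before it is accumulated against the threshold. A dependency-free sketch of that evaluation using Horner's rule (the function name and sample coefficients are illustrative):

```python
def rescale_distance(coefficients, rel_l1):
    # Evaluate the fitted polynomial at rel_l1 via Horner's rule;
    # coefficients are ordered highest degree first, as in _MODEL_CONFIG.
    acc = 0.0
    for c in coefficients:
        acc = acc * rel_l1 + c
    return acc

# 2x^2 + 3x + 1 at x=2 -> 2*4 + 3*2 + 1 = 15
print(rescale_distance([2.0, 3.0, 1.0], 2.0))  # 15.0
# At x=0 only the constant term survives.
print(rescale_distance([2.0, 3.0, 1.0], 0.0))  # 1.0
```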

### What to Remove from PR

1. **CogVideoX support** - Dual residual architecture doesn't fit
2. **Mochi support** - Can be added later if needed
3. **Lumina2 support** - Can be added later if needed
4. **FLUX 2 support** - Different architecture (plain LayerNorm)

### Estimated Code Reduction

| Component | Original (PR) | FLUX-Only |
|-----------|---------------|-----------|
| Forward functions | 4 (~400 lines) | 1 (~100 lines) |
| Model configs | 10 entries | 2 entries |
| State fields | 8 | 5 |
| Utility functions | 6 | 3 |
| **Total teacache.py** | ~900 lines | ~350 lines |

### Simplified State

```python
class TeaCacheState(BaseState):
    def __init__(self):
        self.cnt = 0
        self.num_steps = 0
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None
        # Removed: previous_residual_encoder (CogVideoX)
        # Removed: cache_dict (Lumina2)
        # Removed: uncond_seq_len (Lumina2)
```

---

## Why Custom Forwards Are Necessary

Despite the maintenance burden, custom forwards are the pragmatic approach for TeaCache because:

1. **Mid-forward intercept required** - Need to access `norm1` output before blocks run
2. **Architectural diversity** - Models differ in where/how modulation happens
3. **Block-level hooks insufficient** - Can't extract modulated input from block hooks
4. **Algorithm requirements** - TeaCache paper specifically uses modulated input as signal
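Put together, a custom forward skeleton looks roughly like the following. This is a pure-Python sketch with scalars standing in for tensors; the threshold value, names, and the toy rel-L1 formula are illustrative assumptions, not the PR's actual code:

```python
REL_L1_THRESHOLD = 0.3  # illustrative value

class State:
    def __init__(self):
        self.accumulated = 0.0
        self.previous_modulated_input = None
        self.previous_residual = None

def teacache_forward(state, hidden_states, modulated_input):
    # 1. Decision point: compare the modulated input with the previous step's.
    should_compute = True
    if state.previous_modulated_input is not None:
        prev = state.previous_modulated_input
        rel_l1 = abs(modulated_input - prev) / (abs(prev) + 1e-8)
        state.accumulated += rel_l1
        if state.accumulated < REL_L1_THRESHOLD:
            should_compute = False          # cache hit: skip the blocks
        else:
            state.accumulated = 0.0
    state.previous_modulated_input = modulated_input

    # 2. Cacheable region.
    if should_compute:
        out = hidden_states + 10.0          # stand-in for the block loop
        state.previous_residual = out - hidden_states
    else:
        out = hidden_states + state.previous_residual

    # 3. Postprocessing always runs (norm_out / proj_out in the real model).
    return out * 1.0

s = State()
print(teacache_forward(s, 1.0, 5.0))    # first step always computes -> 11.0
print(teacache_forward(s, 2.0, 5.01))   # tiny change -> cached residual reused -> 12.0
```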

### Alternative Approaches Considered

| Approach | Works? | Issue |
|----------|--------|-------|
| Block-level hooks (like FirstBlockCache) | ❌ | Can't access modulated input inside block |
| Attention-level hooks (like FasterCache) | ❌ | Different algorithm, not TeaCache |
| Hook norm1 directly | ⚠️ | norm1 interface varies per model |
| Hybrid (FirstBlockCache signal + TeaCache algorithm) | ⚠️ | Loses "optimal" signal per paper |

---

## PR Code Quality Issues (From Review)

1. **torch.compile incompatibility** - `.item()` calls in `_compute_rel_l1_distance` create graph breaks
2. **Boundary check bug** - `state.cnt == state.num_steps - 1` when `num_steps=0` evaluates to `-1`
3. **Incomplete Lumina2 state reset** - `cache_dict` and `uncond_seq_len` not reset
4. **Model auto-detection fragility** - Substring matching relies on iteration order
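Issue 2 in isolation: when `num_steps` is never set by the caller (stays `0`), the end-of-run check compares the counter against `-1`, so it can never fire and the state is never reset. A minimal hypothetical reproduction, with one possible guard:

```python
# Minimal repro of the boundary-check bug (names mirror the state fields
# described above; the guard shown is one possible fix, not the PR's).
class State:
    def __init__(self):
        self.cnt = 0
        self.num_steps = 0  # caller forgot to set this

state = State()
is_last_step = state.cnt == state.num_steps - 1  # 0 == -1
print(is_last_step)  # False -> the end-of-run reset logic never runs

def check_last_step(state):
    # Fail fast instead of silently comparing against -1.
    if state.num_steps <= 0:
        raise ValueError("num_steps must be set before inference")
    return state.cnt == state.num_steps - 1
```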

---

## Extension Path

If support for additional models is needed later:

1. **Mochi** - Same pattern as FLUX, just add coefficients and reuse `_flux_teacache_forward` or create similar
2. **Lumina2** - Same pattern but needs per-sequence-length caching for CFG
3. **FLUX 2 / Wan / ZImage** - Need separate forwards that extract modulated input differently
4. **CogVideoX** - Needs dual residual support, significant additional complexity

---

## Summary

- **TeaCache requires custom forwards** due to mid-forward intercept requirement
- **FLUX 1 + FLUX Kontext only** is the recommended scope for initial implementation
- **~60% code reduction** possible by removing unsupported models
- **Clear extension path** for adding models later as needed
- **Maintenance burden** is acceptable given the architectural constraints
129
release_notes/v0.37.0.md
Normal file
@@ -0,0 +1,129 @@
# Diffusers v0.37.0 Release Notes

*Release based on 191 commits since v0.36.0*

---

## Highlights

- **Modular Pipelines overhaul**: Major investment in the modular pipeline system with explicit workflow support, improved loaders, documentation, and modular implementations for Wan, Flux2, Z-Image, Qwen, and Mellon pipelines.
- **New pipelines and models**: Cosmos Predict2.5, LTX 2.0 Video, LongCat-Image, Fibo Edit, Z-Image Omni Base, and more.
- **Distributed inference improvements**: Unified Sequence Parallel attention, Ulysses Anything Attention, and context parallel support in native flash attention.
- **Python 3.8 dropped**: Sunset Python 3.8 and cleaned up explicit `typing` exports.

---

## New Pipelines and Models

- **Cosmos Predict2.5**: Base inference pipeline, scheduler, and checkpoint conversion; 14B model support (#12852, #12863)
- **Cosmos Transfer2.5**: General transfer pipelines for segmentation, depth, blur, and edge (#13066)
- **LTX 2.0 Video Pipelines**: New video generation pipelines (#12915), distilled checkpoint support (#12934), single-file loading (#12983), LoRA support (#12933), long multi-prompt (#12614)
- **LongCat-Image**: New pipeline with offloading/quantization support and regional compile acceleration (#12828, #12963, #12699, #13019, #13021)
- **Fibo Edit Pipeline**: New editing pipeline (#12930)
- **Z-Image Omni Base**: New implementation (#12857)
- **Z-Image Turbo ControlNet**: ControlNet support for Z-Image Turbo (#12792)
- **Z-Image Inpaint Pipeline**: Inpainting support (#13006)
- **Z-Image ControlNet CFG**: CFG support for Z-Image ControlNet (#13080)
- **Chroma Inpaint Pipeline**: New inpainting pipeline for Chroma (#12848)
- **Flux2 Klein**: New model variant (#12982)
- **Qwen Image Edit 2511**: New editing support (#12839)
- **Qwen Image Layered Support** (#12853)

## Modular Pipelines

- Explicit workflow support for modular pipelines (#13028)
- Modular implementations for: Wan (#13063), Flux2 (#12763), Z-Image (#12808), Qwen (#12872), Mellon (#12978, #12924, #13051)
- Improved loader support (#13025)
- Custom block tests (#12557)
- Auto-docstring generation and documentation refactors (#12958)
- Quick start guide (#13029)
- Guard `ModularPipeline.blocks` attribute (#13014)
- Better docstrings and template pipeline card (#13072, #12932)

## Core Improvements

- **Device-type device maps with offloading support** (#12811)
- **`disable_mmap` in pipeline `from_pretrained`** (#12854)
- **`apply_lora_scale` helper** to remove boilerplate (#12994)
- **MagCache support**: Caching mechanism for faster inference (#12744)
- **Mambo-G Guidance**: New guider implementation (#12862)
- **Laplace Scheduler for DDPM** (#11320)
- **Custom sigmas in UniPCMultistepScheduler** (#12109)
- **Control-LoRA support** (#10686)
- **Latent Perceptual Loss (LPL) for SDXL** (#11573)
- **MultiControlNet support for SD3 Inpainting** (#11251)
- Remove 8-bit device restriction (#12972)
- Graceful error for unsupported attn-backend / context-parallel combos (#12832)
- Handle progress bar and logging in distributed environments (#12806)
- Remove unneeded autoencoder methods from `AutoencoderMixin` subclasses (#12873)
- Remove k-diffusion support (#13152)
- Flag Flax schedulers as deprecated (#13031)

## Distributed Inference

- **Unified Sequence Parallel attention** (#12693)
- **Ulysses Anything Attention** (#12996)
- **Context parallel in native flash attention** (#12829)
- NPU Ulysses attention support (#12919)
- Fix Wan 2.1 I2V context parallel (#12909)
- Fix Qwen-Image context parallel (#12970)

## LoRA

- Z-Image LoRA training (#13056)
- Fix non-diffusers LoRA key handling for Flux2 (#13119)
- Fix LoRA loading for Flux2 Klein with adaptive block enumeration (#13030)
- Fix wrong LTX2 LoRA mixin (#13144)

## Bug Fixes

- Fix QwenImageEditPlus on NPU (#13017)
- Fix MT5Tokenizer → use `T5Tokenizer` for Transformers v5.0+ compatibility (#12877)
- Fix Wan/WanI2V patchification (#13038)
- Fix LTX-2 inference with `num_videos_per_prompt > 1` and CFG (#13121)
- Fix Flux2 img2img prediction (#12855)
- Fix QwenImage `txt_seq_lens` handling (#12702)
- Fix `prefix_token_len` bug (#12845)
- Fix ftfy imports in Wan and SkyReels-V2 (#12314, #13113)
- Fix `is_fsdp` determination (#12960)
- Fix GLM-Image `get_image_features` API (#13052)
- Fix Wan 2.2 when either transformer isn't present (#13055)
- Fix guider issue (#13147)
- Fix torchao quantizer for new versions (#12901)
- Fix GGUF for unquantized types with unquantize kernels (#12498)
- Make Qwen hidden states contiguous for torchao (#13081)
- Make Flux hidden states contiguous (#13068)
- Fix Kandinsky 5 hardcoded CUDA autocast (#12814)
- Fix `aiter` availability check (#13059)
- Fix attention mask check for unsupported backends (#12892)
- Allow `prompt` and `prior_token_ids` simultaneously in `GlmImagePipeline` (#13092)
- GLM-Image batch support (#13007)
- Cosmos 2.5 Video2World frame extraction fix (#13018)
- ResNet: only use contiguous in training mode (#12977)

## Testing and CI

- Refactor model tests (#12822)
- Refactor Wan model tests (#13082)
- Accept `recompile_limit` from user in tests (#13150)
- CodeQL workflow for security analysis (#12917)
- Upgrade GitHub Actions for Node 24 compatibility (#12865, #12866)
- Fix `setuptools` / `pkg_resources` CI bugs (#13129, #13132)
- CUDA 12.9 upgrade (#13045)
- FSDP option for Flux2 (#12860)

## Documentation

- Custom code AutoModel guide (#13099)
- Remote inference docs (#12372)
- Improved distributed inference docs (#12810, #12827, #12971)
- Improved caching docs (#12684)
- Numerous scheduler docstring improvements (#12798, #12871, #12928, #12931, #12936, #12992, #13010, #13020, #13023, #13024, #13027, #13044, #13083, #13085, #13122, #13127, #13130)
- Various typo and syntax fixes

## Breaking Changes

- **Python 3.8 support removed** (#12524)
- **k-diffusion removed** (#13152)
- **Flax schedulers flagged as deprecated** (#13031)
- ControlNet implementations outside the controlnet module removed (#12152)
183
scripts/compare_test_coverage.py
Normal file
@@ -0,0 +1,183 @@
#!/usr/bin/env python3
"""
Compare test coverage between main and model-test-refactor branches
for the Flux transformer tests.

Usage:
    python scripts/compare_test_coverage.py
"""

import subprocess


TEST_FILE = "tests/models/transformers/test_models_transformer_flux.py"
BRANCHES = ["main", "model-test-refactor"]


def run_command(cmd, capture=True):
    """Run a shell command and return output."""
    result = subprocess.run(cmd, shell=True, capture_output=capture, text=True)
    return result.stdout, result.stderr, result.returncode


def get_current_branch():
    """Get the current git branch name."""
    stdout, _, _ = run_command("git branch --show-current")
    return stdout.strip()


def stash_changes():
    """Stash any uncommitted changes."""
    run_command("git stash")


def pop_stash():
    """Pop stashed changes."""
    run_command("git stash pop")


def checkout_branch(branch):
    """Checkout a git branch."""
    _, stderr, code = run_command(f"git checkout {branch}")
    if code != 0:
        print(f"Failed to checkout {branch}: {stderr}")
        return False
    return True


def collect_tests(test_file):
    """Collect tests from a test file and return test info."""
    cmd = f"python -m pytest {test_file} --collect-only -q 2>/dev/null"
    stdout, _, _ = run_command(cmd)

    tests = []
    for line in stdout.strip().split("\n"):
        if "::" in line and not line.startswith("="):
            tests.append(line.strip())

    return tests


def run_tests_verbose(test_file):
    """Run tests and capture pass/skip/fail status."""
    cmd = f"python -m pytest {test_file} -v --tb=no 2>&1"
    stdout, _, _ = run_command(cmd)

    results = {"passed": [], "skipped": [], "failed": [], "errors": []}

    for line in stdout.split("\n"):
        if " PASSED" in line:
            test_name = line.split(" PASSED")[0].strip()
            results["passed"].append(test_name)
        elif " SKIPPED" in line:
            test_name = line.split(" SKIPPED")[0].strip()
            reason = line.split("[")[-1].rstrip("]") if "[" in line else ""
            results["skipped"].append((test_name, reason))
        elif " FAILED" in line:
            test_name = line.split(" FAILED")[0].strip()
            results["failed"].append(test_name)
        elif " ERROR" in line:
            test_name = line.split(" ERROR")[0].strip()
            results["errors"].append(test_name)

    return results


def compare_results(main_results, pr_results):
    """Compare test results between branches."""
    print("\n" + "=" * 70)
    print("COVERAGE COMPARISON REPORT")
    print("=" * 70)

    print("\n## Test Counts")
    print(f"{'Category':<20} {'main':<15} {'PR':<15} {'Diff':<10}")
    print("-" * 60)

    for category in ["passed", "skipped", "failed", "errors"]:
        main_count = len(main_results[category])
        pr_count = len(pr_results[category])
        diff = pr_count - main_count
        diff_str = f"+{diff}" if diff > 0 else str(diff)
        print(f"{category:<20} {main_count:<15} {pr_count:<15} {diff_str:<10}")

    main_tests = set(main_results["passed"] + [t[0] for t in main_results["skipped"]])
    pr_tests = set(pr_results["passed"] + [t[0] for t in pr_results["skipped"]])

    missing_in_pr = main_tests - pr_tests
    new_in_pr = pr_tests - main_tests

    if missing_in_pr:
        print("\n## Tests in main but MISSING in PR:")
        for test in sorted(missing_in_pr):
            print(f"  - {test}")

    if new_in_pr:
        print("\n## NEW tests in PR (not in main):")
        for test in sorted(new_in_pr):
            print(f"  + {test}")

    print("\n## Skipped Tests Comparison")
    main_skipped = {t[0]: t[1] for t in main_results["skipped"]}
    pr_skipped = {t[0]: t[1] for t in pr_results["skipped"]}

    newly_skipped = set(pr_skipped.keys()) - set(main_skipped.keys())
    no_longer_skipped = set(main_skipped.keys()) - set(pr_skipped.keys())

    if newly_skipped:
        print("\nNewly skipped in PR:")
        for test in sorted(newly_skipped):
            print(f"  - {test}: {pr_skipped.get(test, 'unknown reason')}")

    if no_longer_skipped:
        print("\nNo longer skipped in PR (now running):")
        for test in sorted(no_longer_skipped):
            print(f"  + {test}")

    if not newly_skipped and not no_longer_skipped:
        print("\nNo changes in skipped tests.")

    print("\n" + "=" * 70)


def main():
    original_branch = get_current_branch()
    print(f"Current branch: {original_branch}")

    results = {}

    print("Stashing uncommitted changes...")
    stash_changes()

    try:
        for branch in BRANCHES:
            print(f"\n--- Analyzing branch: {branch} ---")

            if not checkout_branch(branch):
                print(f"Skipping {branch}")
                continue

            print(f"Collecting and running tests from {TEST_FILE}...")
            results[branch] = run_tests_verbose(TEST_FILE)

            print(f"  Passed: {len(results[branch]['passed'])}")
            print(f"  Skipped: {len(results[branch]['skipped'])}")
            print(f"  Failed: {len(results[branch]['failed'])}")

        checkout_branch(original_branch)

        if "main" in results and "model-test-refactor" in results:
            compare_results(results["main"], results["model-test-refactor"])
        else:
            print("Could not compare - missing results from one or both branches")

    finally:
        print("\nRestoring stashed changes...")
        pop_stash()

        checkout_branch(original_branch)


if __name__ == "__main__":
    main()
@@ -107,38 +107,6 @@ class ConfigMixin:
    has_compatibles = False

    _deprecated_kwargs = []
    _auto_class = None

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoModel"):
        """
        Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
        trust_remote_code=True)`.

        When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
        to this class's module and class name.

        Args:
            auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
                The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
                Currently only `"AutoModel"` is supported.

        Example:

        ```python
        from diffusers import ModelMixin, ConfigMixin


        class MyCustomModel(ModelMixin, ConfigMixin): ...


        MyCustomModel.register_for_auto_class("AutoModel")
        ```
        """
        if auto_class != "AutoModel":
            raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")

        cls._auto_class = auto_class

    def register_to_config(self, **kwargs):
        if self.config_name is None:
@@ -653,12 +621,6 @@ class ConfigMixin:
        # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
        _ = config_dict.pop("_pre_quantization_dtype", None)

        if getattr(self, "_auto_class", None) is not None:
            module = self.__class__.__module__.split(".")[-1]
            auto_map = config_dict.get("auto_map", {})
            auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
            config_dict["auto_map"] = auto_map

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: str | os.PathLike):

@@ -1707,8 +1707,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
            _blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
        )

        self._pretrained_model_name_or_path = pretrained_model_name_or_path

    @property
    def default_call_parameters(self) -> dict[str, Any]:
        """
@@ -2325,16 +2323,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                elif "default" in value:
                    # check if the default is specified
                    component_load_kwargs[key] = value["default"]
            # Only pass trust_remote_code to components from the same repo as the pipeline.
            # When a user passes trust_remote_code=True, they intend to trust code from the
            # pipeline's repo, not from external repos referenced in modular_model_index.json.
            if (
                "trust_remote_code" in component_load_kwargs
                and self._pretrained_model_name_or_path is not None
                and spec.pretrained_model_name_or_path != self._pretrained_model_name_or_path
            ):
                component_load_kwargs.pop("trust_remote_code")

            try:
                components_to_register[name] = spec.load(**component_load_kwargs)
            except Exception:

@@ -31,18 +31,14 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        trained_betas (`np.ndarray` or `List[float]`, *optional*):
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        trained_betas: np.ndarray | list[float] | None = None,
    ):
    def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None):
        # set `betas`, `alphas`, `timesteps`
        self.set_timesteps(num_train_timesteps)

@@ -60,29 +56,21 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
        self._begin_index = None

    @property
    def step_index(self) -> int | None:
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.

        Returns:
            `int` or `None`:
                The index counter for current timestep.
        """
        return self._step_index

    @property
    def begin_index(self) -> int | None:
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.

        Returns:
            `int` or `None`:
                The index for the first timestep.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0) -> None:
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

@@ -181,7 +169,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int` or `torch.Tensor`):
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

@@ -240,30 +228,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
        """
        return sample

    def _get_prev_sample(
        self,
        sample: torch.Tensor,
        timestep_index: int,
        prev_timestep_index: int,
        ets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predicts the previous sample based on the current sample, timestep indices, and running model outputs.

        Args:
            sample (`torch.Tensor`):
                The current sample.
            timestep_index (`int`):
                Index of the current timestep in the schedule.
            prev_timestep_index (`int`):
                Index of the previous timestep in the schedule.
            ets (`torch.Tensor`):
                The running sequence of model outputs.

        Returns:
            `torch.Tensor`:
                The predicted previous sample.
        """
    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        alpha = self.alphas[timestep_index]
        sigma = self.betas[timestep_index]

@@ -275,5 +240,5 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):

        return prev_sample

    def __len__(self) -> int:
    def __len__(self):
        return self.config.num_train_timesteps

14
test_automodel_meta.py
Normal file
@@ -0,0 +1,14 @@
import torch
from diffusers import AutoModel

repo = "meituan-longcat/LongCat-Image"
subfolder = "transformer"

config = AutoModel.load_config(repo, subfolder=subfolder)

with torch.device("meta"):
    model = AutoModel.from_config(config)
print("model.config:")
for k, v in dict(model.config).items():
    if not k.startswith("_"):
        print(f"  {k}: {v}")
11
test_dataclass_config.py
Normal file
@@ -0,0 +1,11 @@
import dataclasses
from diffusers import AutoModel, LongCatImageTransformer2DModel

config_dict = AutoModel.load_config(
    "meituan-longcat/LongCat-Image",
    subfolder="transformer",
)
# import DiT based on _class_name
typed_config = LongCatImageTransformer2DModel._get_dataclass_from_config(config_dict)
for f in dataclasses.fields(typed_config):
    print(f"{f.name}: {f.type}")
29
test_pretrained_config.py
Normal file
@@ -0,0 +1,29 @@
import dataclasses

import torch

from diffusers import FluxTransformer2DModel
from diffusers.models import AutoModel

repo = "black-forest-labs/FLUX.2-dev"
subfolder = "transformer"

print("=== From load_config (no model instantiation) ===")
config_dict = FluxTransformer2DModel.load_config(repo, subfolder=subfolder)
tc = FluxTransformer2DModel._get_dataclass_from_config(config_dict)
print(f"Type: {type(tc).__name__}")
for k, v in dataclasses.asdict(tc).items():
    print(f"  {k}: {v}")

print()
print("=== From AutoModel.from_config on meta device ===")
with torch.device("meta"):
    model = AutoModel.from_config(repo, subfolder=subfolder)
print("model.config:")
for k, v in dict(model.config).items():
    if not k.startswith("_"):
        print(f"  {k}: {v}")

print()
print("=== Comparison ===")
dc_dict = dataclasses.asdict(tc)
config = {k: v for k, v in dict(model.config).items() if not k.startswith("_")}
print(f"Match: {dc_dict == config}")
@@ -7,9 +7,7 @@ from unittest.mock import MagicMock, patch
import torch
from transformers import CLIPTextModel, LongformerModel

from diffusers import ConfigMixin
from diffusers.models import AutoModel, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin


class TestAutoModel(unittest.TestCase):
@@ -145,51 +143,3 @@ class TestAutoModelFromConfig(unittest.TestCase):
    def test_from_config_raises_on_none(self):
        with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
            AutoModel.from_config(None)


class TestRegisterForAutoClass(unittest.TestCase):
    def test_register_for_auto_class_sets_attribute(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        DummyModel.register_for_auto_class("AutoModel")
        self.assertEqual(DummyModel._auto_class, "AutoModel")

    def test_register_for_auto_class_rejects_unsupported(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
            DummyModel.register_for_auto_class("AutoPipeline")

    def test_auto_map_in_saved_config(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        DummyModel.register_for_auto_class("AutoModel")
        model = DummyModel()

        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_config(tmpdir)
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "r") as f:
                config = json.load(f)

        self.assertIn("auto_map", config)
        self.assertIn("AutoModel", config["auto_map"])
        module_name = DummyModel.__module__.split(".")[-1]
        self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")

    def test_no_auto_map_without_register(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        model = DummyModel()

        with tempfile.TemporaryDirectory() as tmpdir:
            model.save_config(tmpdir)
            config_path = os.path.join(tmpdir, "config.json")
            with open(config_path, "r") as f:
                config = json.load(f)

        self.assertNotIn("auto_map", config)