Mirror of https://github.com/huggingface/diffusers.git, synced 2026-02-11 05:15:31 +08:00

Compare commits: modular-su...fix-module

6 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ed734a0e63 | |
| | d676b03490 | |
| | e117274aa5 | |
| | 5bf248ddd8 | |
| | bedc67c75f | |
| | a1804cfa80 | |
@@ -29,8 +29,31 @@ text_encoder = AutoModel.from_pretrained(
)
```

## Custom models

[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.

A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.

```
custom/custom-transformer-model/
├── config.json
├── my_model.py
└── diffusion_pytorch_model.safetensors
```

The `config.json` includes the `auto_map` field pointing to the custom class.

```json
{
  "auto_map": {
    "AutoModel": "my_model.MyCustomModel"
  }
}
```
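
The Python module named in `auto_map` (here `my_model.py`) defines the custom class itself. Below is a minimal sketch of what such a module could look like; the class name matches the `auto_map` example above, but the layer and its size are hypothetical placeholders, not code from a real repository.

```py
# my_model.py -- hypothetical minimal custom model (illustrative sketch only)
import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class MyCustomModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        # toy architecture; a real custom model defines its full network here
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)
```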

Then load it with `trust_remote_code=True`.

```py
import torch
from diffusers import AutoModel
@@ -40,7 +63,39 @@ transformer = AutoModel.from_pretrained(
)
```

For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.

```
transformer/
├── config.json          # auto_map: "model.WorldModel"
├── model.py
├── attn.py
├── nn.py
├── cache.py
├── quantize.py
├── __init__.py
└── diffusion_pytorch_model.safetensors
```

```py
import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```

If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
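
For instance, group offloading can be applied to a custom model loaded this way, just like any built-in Diffusers model. The snippet below is a minimal sketch that assumes the custom class subclasses [`ModelMixin`] and that the `enable_group_offload` API from recent Diffusers releases is available.

```py
import torch
from diffusers import AutoModel

transformer = AutoModel.from_pretrained(
    "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16
)

# assumes the loaded custom model subclasses ModelMixin, which exposes enable_group_offload
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)
```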

> [!WARNING]
> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
>
> ```py
> transformer = AutoModel.from_pretrained(
>     "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
> )
> ```

> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

```diff
@@ -18,7 +18,6 @@ import re
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Union

 import ftfy
 import torch
 from transformers import AutoTokenizer, UMT5EncoderModel

```

```diff
@@ -18,7 +18,6 @@ import re
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import ftfy
 import PIL
 import torch
 from transformers import AutoTokenizer, UMT5EncoderModel

```

```diff
@@ -19,7 +19,6 @@ import re
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Union

 import ftfy
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, UMT5EncoderModel

```
```diff
@@ -21,11 +21,8 @@ import torch
 from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
 from diffusers.utils.import_utils import (
     is_bitsandbytes_available,
-    is_gguf_available,
     is_nvidia_modelopt_available,
     is_optimum_quanto_available,
-    is_torchao_available,
-    is_torchao_version,
 )

 from ...testing_utils import (
@@ -59,13 +56,6 @@ if is_bitsandbytes_available():
 if is_optimum_quanto_available():
     from optimum.quanto import QLinear

-if is_gguf_available():
-    pass
-
-if is_torchao_available():
-    if is_torchao_version(">=", "0.9.0"):
-        pass
-

 class LoRALayer(torch.nn.Module):
     """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
```
```diff
@@ -132,14 +122,14 @@ class QuantizationTesterMixin:
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
         raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")

-    def _is_module_quantized(self, module):
+    def _is_module_quantized(self, module, quant_config_kwargs=None):
         """
         Check if a module is quantized. Returns True if quantized, False otherwise.
         Default implementation tries _verify_if_layer_quantized and catches exceptions.
         Subclasses can override for more efficient checking.
         """
         try:
-            self._verify_if_layer_quantized("", module, {})
+            self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
             return True
         except (AssertionError, AttributeError):
             return False
@@ -279,7 +269,9 @@ class QuantizationTesterMixin:
             f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
         )

-    def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
+    def _test_quantization_modules_to_not_convert(
+        self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
+    ):
         """
         Test that modules specified in modules_to_not_convert are not quantized.
@@ -289,7 +281,7 @@ class QuantizationTesterMixin:
         """
         # Create config with modules_to_not_convert
         config_kwargs_with_exclusion = config_kwargs.copy()
-        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
+        config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert

         model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
@@ -301,7 +293,7 @@ class QuantizationTesterMixin:
                 if any(excluded in name for excluded in modules_to_not_convert):
                     found_excluded = True
                     # This module should NOT be quantized
-                    assert not self._is_module_quantized(module), (
+                    assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
                         f"Module {name} should not be quantized but was found to be quantized"
                     )
@@ -313,7 +305,7 @@ class QuantizationTesterMixin:
             if isinstance(module, torch.nn.Linear):
                 # Check if this module is NOT in the exclusion list
                 if not any(excluded in name for excluded in modules_to_not_convert):
-                    if self._is_module_quantized(module):
+                    if self._is_module_quantized(module, config_kwargs_with_exclusion):
                         found_quantized = True
                         break
```
```diff
@@ -618,7 +610,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
             pytest.skip("modules_to_not_convert_for_test not defined for this model")

         self._test_quantization_modules_to_not_convert(
-            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
+            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
         )

     @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
@@ -817,7 +809,14 @@ class TorchAoConfigMixin:
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

+    def _verify_if_layer_quantized(self, name, module, config_kwargs):
+        from torchao.dtypes import AffineQuantizedTensor
+        from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
+        # Check if the weight is actually quantized
+        weight = module.weight
+        is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
+        assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
+

 # int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
```
```diff
@@ -913,9 +912,39 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
         if modules_to_exclude is None:
             pytest.skip("modules_to_not_convert_for_test not defined for this model")

-        self._test_quantization_modules_to_not_convert(
-            TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
-        )
+        # Custom implementation for torchao that skips memory footprint check
+        # because get_memory_footprint() doesn't accurately reflect torchao quantization
+        config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
+        config_kwargs_with_exclusion = config_kwargs.copy()
+        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
+
+        model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
+
+        # Find a module that should NOT be quantized
+        found_excluded = False
+        for name, module in model_with_exclusion.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                # Check if this module is in the exclusion list
+                if any(excluded in name for excluded in modules_to_exclude):
+                    found_excluded = True
+                    # This module should NOT be quantized
+                    assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
+                        f"Module {name} should not be quantized but was found to be quantized"
+                    )
+
+        assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
+
+        # Find a module that SHOULD be quantized (not in exclusion list)
+        found_quantized = False
+        for name, module in model_with_exclusion.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                # Check if this module is NOT in the exclusion list
+                if not any(excluded in name for excluded in modules_to_exclude):
+                    if self._is_module_quantized(module, config_kwargs_with_exclusion):
+                        found_quantized = True
+                        break
+
+        assert found_quantized, "No quantized layers found outside of excluded modules"

     def test_torchao_device_map(self):
         """Test that device_map='auto' works correctly with quantization."""
```
```diff
@@ -318,6 +318,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
     """BitsAndBytes quantization tests for Flux Transformer."""

+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+

 class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
     """Quanto quantization tests for Flux Transformer."""
@@ -330,10 +334,18 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
     def pretrained_model_kwargs(self):
         return {}

+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+

 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
     """TorchAO quantization tests for Flux Transformer."""

+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+

 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     @property
@@ -402,6 +414,10 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
 class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
     """ModelOpt quantization tests for Flux Transformer."""

+    @property
+    def modules_to_not_convert_for_test(self):
+        return ["norm_out.linear"]
+

 class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
     """ModelOpt + compile tests for Flux Transformer."""
```