Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-14 16:34:27 +08:00)

Compare commits: modular-do...add-quanto (41 commits)
Commits in this comparison (SHA1):
1b46a32f2a, cf4694e19e, deebc22ebd, d5ab9cadc0, 6cad1d537a, bb7fb66b4d, 8163687e08, 8afff1bb0d, 830b7345b7, 4516f2238b, 156db084d2, 963559f69b, dbaef7c3a4, f512c2893f, 4eabed7f97, 0736f87d1f, 6cf9a78108, c29684f44f, c4b6e24fe5, 79901e4dec, 9a72fefd4b, d355e6aa9b, c80d4d4a72, 2c7f30325d, b136d239e4, 9e5a3d0766, 559f12470a, e090177766, 7b841dc52d, 4ae86916b2, e96686e9c9, 7472f18b9c, f734c096e7, 5cff237f75, f67d97c0ac, f4c14c222d, f52050a39f, 39e20e2405, aa8cdaf056, ba5bba74f2, ff50418472
2  .github/workflows/nightly_tests.yml  vendored
@@ -418,6 +418,8 @@ jobs:
             test_location: "gguf"
           - backend: "torchao"
             test_location: "torchao"
+          - backend: "optimum_quanto"
+            test_location: "quanto"
     runs-on:
       group: aws-g6e-xlarge-plus
     container:
@@ -173,6 +173,8 @@
     title: gguf
   - local: quantization/torchao
     title: torchao
+  - local: quantization/quanto
+    title: quanto
   title: Quantization Methods
 - sections:
   - local: optimization/fp16
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 ## GGUFQuantizationConfig

 [[autodoc]] GGUFQuantizationConfig

+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
 ## TorchAoConfig

 [[autodoc]] TorchAoConfig
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
 - [BitsandBytes](./bitsandbytes)
 - [TorchAO](./torchao)
 - [GGUF](./gguf)
+- [Quanto](./quanto)

 [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
148  docs/source/en/quantization/quanto.md  Normal file
@@ -0,0 +1,148 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

To use the Quanto backend, first install `optimum-quanto>=0.2.6` and `accelerate`:

```shell
pip install optimum-quanto accelerate
```

Now you can quantize a model by passing a `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library allows quantizing `nn.Conv2d` and `nn.LayerNorm` modules, Diffusers currently only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Skipping quantization on specific modules

It is possible to skip quantization on certain modules using the `modules_to_not_convert` argument in `QuantoConfig`. Make sure the modules passed to this argument match the keys of the modules in the model's `state_dict`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```

## Using `from_single_file` with the Quanto backend

`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```

## Saving quantized models

Diffusers supports serializing Quanto models with the `~ModelMixin.save_pretrained` method.

The serialization and loading requirements differ between models quantized directly with the Quanto library and models quantized with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# save the quantized model for reuse
transformer.save_pretrained("<your quantized model save path>")

# reload the quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```

## Using `torch.compile` with Quanto

Currently the Quanto backend supports `torch.compile` for the following quantization types:

- `int8` weights

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]
image.save("flux-quanto-compile.png")
```

## Supported quantization types

### Weights

- float8
- int8
- int4
- int2
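For reference, a minimal sketch (not part of the diff) of how the supported `weights_dtype` values listed above are selected; lower-bit dtypes trade accuracy for a smaller memory footprint. It reuses the FLUX checkpoint from the examples and assumes `get_memory_footprint()` (also used in the tests further down) is available on the loaded model.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

# Accepted values are "float8", "int8", "int4" and "int2"; pick one per load.
quantization_config = QuantoConfig(weights_dtype="int4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
print(f"int4 footprint: {transformer.get_memory_footprint() / 1024**3:.2f} GiB")
```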
9  setup.py
@@ -128,6 +128,10 @@ _deps = [
     "GitPython<3.1.19",
     "scipy",
     "onnx",
+    "optimum_quanto>=0.2.6",
+    "gguf>=0.10.0",
+    "torchao>=0.7.0",
+    "bitsandbytes>=0.43.3",
     "regex!=2019.12.17",
     "requests",
     "tensorboard",
@@ -235,6 +239,11 @@ extras["test"] = deps_list(
 )
 extras["torch"] = deps_list("torch", "accelerate")

+extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
+extras["gguf"] = deps_list("gguf", "accelerate")
+extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
+extras["torchao"] = deps_list("torchao", "accelerate")
+
 if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
 else:
@@ -2,6 +2,15 @@ __version__ = "0.33.0.dev0"

 from typing import TYPE_CHECKING
+
+from diffusers.quantizers import quantization_config
+from diffusers.utils import dummy_gguf_objects
+from diffusers.utils.import_utils import (
+    is_bitsandbytes_available,
+    is_gguf_available,
+    is_optimum_quanto_version,
+    is_torchao_available,
+)

 from .utils import (
     DIFFUSERS_SLOW_IMPORT,
     OptionalDependencyNotAvailable,
@@ -11,6 +20,7 @@ from .utils import (
     is_librosa_available,
     is_note_seq_available,
     is_onnx_available,
+    is_optimum_quanto_available,
     is_scipy_available,
     is_sentencepiece_available,
     is_torch_available,
@@ -32,7 +42,7 @@ _import_structure = {
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
+    "quantizers.quantization_config": [],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -54,6 +64,55 @@ _import_structure = {
     ],
 }

+try:
+    if not is_bitsandbytes_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_bitsandbytes_objects
+
+    _import_structure["utils.dummy_bitsandbytes_objects"] = [
+        name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
+
+try:
+    if not is_gguf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_gguf_objects
+
+    _import_structure["utils.dummy_gguf_objects"] = [
+        name for name in dir(dummy_gguf_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
+
+try:
+    if not is_torchao_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_torchao_objects
+
+    _import_structure["utils.dummy_torchao_objects"] = [
+        name for name in dir(dummy_torchao_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
+
+try:
+    if not is_optimum_quanto_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_optimum_quanto_objects
+
+    _import_structure["utils.dummy_optimum_quanto_objects"] = [
+        name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("QuantoConfig")
+
+
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -598,7 +657,38 @@ else:


 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
+
+    try:
+        if not is_bitsandbytes_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_bitsandbytes_objects import *
+    else:
+        from .quantizers.quantization_config import BitsAndBytesConfig
+
+    try:
+        if not is_gguf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_gguf_objects import *
+    else:
+        from .quantizers.quantization_config import GGUFQuantizationConfig
+
+    try:
+        if not is_torchao_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_torchao_objects import *
+    else:
+        from .quantizers.quantization_config import TorchAoConfig
+
+    try:
+        if not is_optimum_quanto_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_optimum_quanto_objects import *
+    else:
+        from .quantizers.quantization_config import QuantoConfig
+
     try:
         if not is_onnx_available():
@@ -35,6 +35,10 @@ deps = {
     "GitPython": "GitPython<3.1.19",
     "scipy": "scipy",
     "onnx": "onnx",
+    "optimum_quanto": "optimum_quanto>=0.2.6",
+    "gguf": "gguf>=0.10.0",
+    "torchao": "torchao>=0.7.0",
+    "bitsandbytes": "bitsandbytes>=0.43.3",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "tensorboard": "tensorboard",
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
         ):
             param = param.to(torch.float32)
             set_module_kwargs["dtype"] = torch.float32
+        # For quantizers that save weights using torch.float8_e4m3fn
+        elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+            pass
         else:
             param = param.to(dtype)
             set_module_kwargs["dtype"] = dtype
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
         elif is_quantized and (
             hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
         else:
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

@@ -26,8 +26,10 @@ from .quantization_config import (
     GGUFQuantizationConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
+    QuantoConfig,
     TorchAoConfig,
 )
+from .quanto import QuantoQuantizer
 from .torchao import TorchAoHfQuantizer


@@ -35,6 +37,7 @@ AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
     "gguf": GGUFQuantizer,
+    "quanto": QuantoQuantizer,
     "torchao": TorchAoHfQuantizer,
 }

@@ -42,6 +45,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gguf": GGUFQuantizationConfig,
+    "quanto": QuantoConfig,
     "torchao": TorchAoConfig,
 }

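A minimal sketch (not part of the diff) of how the registries above are consulted: the string key `"quanto"` resolves both the user-facing config class and the quantizer that implements it. The import path `diffusers.quantizers.auto` matches the hunk above; treat the snippet as illustrative only.

```python
from diffusers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

# Look up the classes registered under the "quanto" key.
config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["quanto"]  # QuantoConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING["quanto"]          # QuantoQuantizer

quantization_config = config_cls(weights_dtype="float8")
quantizer = quantizer_cls(quantization_config)
quantizer.validate_environment()  # raises ImportError if optimum-quanto/accelerate are missing
```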
@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
     GGUF = "gguf"
     TORCHAO = "torchao"
+    QUANTO = "quanto"


 if is_torchao_available():
@@ -686,3 +687,38 @@ class TorchAoConfig(QuantizationConfigMixin):
         return (
             f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
         )
+
+
+@dataclass
+class QuantoConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class for all the attributes and features you can tune for a model that has been loaded with
+    `quanto` quantization.
+
+    Args:
+        weights_dtype (`str`, *optional*, defaults to `"int8"`):
+            The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2").
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require some modules to
+            be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+    """
+
+    def __init__(
+        self,
+        weights_dtype: str = "int8",
+        modules_to_not_convert: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.QUANTO
+        self.weights_dtype = weights_dtype
+        self.modules_to_not_convert = modules_to_not_convert
+
+        self.post_init()
+
+    def post_init(self):
+        r"""
+        Safety checker that the arguments are correct.
+        """
+        accepted_weights = ["float8", "int8", "int4", "int2"]
+        if self.weights_dtype not in accepted_weights:
+            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
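A brief sketch (not part of the diff) of the validation that `post_init` adds: constructing a `QuantoConfig` with a `weights_dtype` outside the accepted list raises immediately.

```python
from diffusers import QuantoConfig

QuantoConfig(weights_dtype="int8")  # accepted: "float8", "int8", "int4", "int2"

try:
    QuantoConfig(weights_dtype="fp16")  # not an accepted value
except ValueError as err:
    print(err)
```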
1  src/diffusers/quantizers/quanto/__init__.py  Normal file
@@ -0,0 +1 @@
from .quanto_quantizer import QuantoQuantizer
177  src/diffusers/quantizers/quanto/quanto_quantizer.py  Normal file
@@ -0,0 +1,177 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from diffusers.utils.import_utils import is_optimum_quanto_version

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_accelerate_version,
    is_optimum_quanto_available,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import CustomDtype, set_module_tensor_to_device

if is_optimum_quanto_available():
    from .utils import _replace_with_quanto_layers

logger = logging.get_logger(__name__)


class QuantoQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for Optimum Quanto
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["quanto", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not is_optimum_quanto_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
            )
        if not is_optimum_quanto_version(">=", "0.2.6"):
            raise ImportError(
                "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
                "Please upgrade your installation with `pip install --upgrade optimum-quanto`"
            )

        if not is_accelerate_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
            )

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            raise ValueError(
                "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
            )

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # Quanto imports diffusers internally. This is here to prevent circular imports
        from optimum.quanto import QModuleMixin, QTensor
        from optimum.quanto.tensor.packed import PackedTensor

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
            return True
        elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
            return not module.frozen

        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .freeze() after setting it on the module.
        """

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            setattr(module, tensor_name, param_value)
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            module.freeze()
            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if is_accelerate_version(">=", "0.27.0"):
            mapping = {
                "int8": torch.int8,
                "float8": CustomDtype.FP8,
                "int4": CustomDtype.INT4,
                "int2": CustomDtype.INT2,
            }
            target_dtype = mapping[self.quantization_config.weights_dtype]

        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        # Quanto imports diffusers internally. This is here to prevent circular imports
        from optimum.quanto import QModuleMixin

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, QModuleMixin):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        model = _replace_with_quanto_layers(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_trainable(self):
        return True

    @property
    def is_serializable(self):
        return True
60  src/diffusers/quantizers/quanto/utils.py  Normal file
@@ -0,0 +1,60 @@
import torch.nn as nn

from ...utils import is_accelerate_available, logging


logger = logging.get_logger(__name__)

if is_accelerate_available():
    from accelerate import init_empty_weights


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
    # Quanto imports diffusers internally. These are placed here to avoid circular imports
    from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8

    def _get_weight_type(dtype: str):
        return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]

    def _replace_layers(model, quantization_config, modules_to_not_convert):
        has_children = list(model.children())
        if not has_children:
            return model

        for name, module in model.named_children():
            _replace_layers(module, quantization_config, modules_to_not_convert)

            if name in modules_to_not_convert:
                continue

            if isinstance(module, nn.Linear):
                with init_empty_weights():
                    qlinear = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=_get_weight_type(quantization_config.weights_dtype),
                    )
                    model._modules[name] = qlinear
                    model._modules[name].source_cls = type(module)
                    model._modules[name].requires_grad_(False)

        return model

    model = _replace_layers(model, quantization_config, modules_to_not_convert)
    has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())

    if not has_been_replaced:
        logger.warning(
            f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
            " Please check your model architecture, or submit an issue on Github if you think this is a bug."
            " https://github.com/huggingface/diffusers/issues/new"
        )

    # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
    # to match when trying to load weights with load_model_dict_into_meta
    if pre_quantized:
        freeze(model)

    return model
@@ -79,6 +79,8 @@ from .import_utils import (
     is_matplotlib_available,
     is_note_seq_available,
     is_onnx_available,
+    is_optimum_quanto_available,
+    is_optimum_quanto_version,
     is_peft_available,
     is_peft_version,
     is_safetensors_available,
17  src/diffusers/utils/dummy_bitsandbytes_objects.py  Normal file
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class BitsAndBytesConfig(metaclass=DummyObject):
    _backends = ["bitsandbytes"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["bitsandbytes"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["bitsandbytes"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["bitsandbytes"])
17  src/diffusers/utils/dummy_gguf_objects.py  Normal file
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class GGUFQuantizationConfig(metaclass=DummyObject):
    _backends = ["gguf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["gguf"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["gguf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["gguf"])
17  src/diffusers/utils/dummy_optimum_quanto_objects.py  Normal file
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class QuantoConfig(metaclass=DummyObject):
    _backends = ["optimum_quanto"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["optimum_quanto"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["optimum_quanto"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["optimum_quanto"])
17  src/diffusers/utils/dummy_torchao_objects.py  Normal file
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class TorchAoConfig(metaclass=DummyObject):
    _backends = ["torchao"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchao"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torchao"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torchao"])
@@ -365,6 +365,15 @@ if _is_torchao_available:
     _is_torchao_available = False


+_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
+if _is_optimum_quanto_available:
+    try:
+        _optimum_quanto_version = importlib_metadata.version("optimum_quanto")
+        logger.debug(f"Successfully imported optimum-quanto version {_optimum_quanto_version}")
+    except importlib_metadata.PackageNotFoundError:
+        _is_optimum_quanto_available = False
+
+
 def is_torch_available():
     return _torch_available

@@ -493,6 +502,10 @@ def is_torchao_available():
     return _is_torchao_available


+def is_optimum_quanto_available():
+    return _is_optimum_quanto_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -636,6 +649,11 @@ TORCHAO_IMPORT_ERROR = """
 torchao`
 """

+QUANTO_IMPORT_ERROR = """
+{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
+install optimum-quanto`
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -663,6 +681,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
         ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
         ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
+        ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
     ]
 )

@@ -864,6 +883,21 @@ def is_k_diffusion_version(operation: str, version: str):
     return compare_versions(parse(_k_diffusion_version), operation, version)


+def is_optimum_quanto_version(operation: str, version: str):
+    """
+    Compares the current optimum-quanto version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _is_optimum_quanto_available:
+        return False
+    return compare_versions(parse(_optimum_quanto_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
346  tests/quantization/quanto/test_quanto.py  Normal file
@@ -0,0 +1,346 @@
import gc
import tempfile
import unittest

from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
from diffusers.utils.testing_utils import (
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_gpu_with_torch_cuda,
    torch_device,
)


if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)


@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
class QuantoBaseTesterMixin:
    model_id = None
    pipeline_model_id = None
    model_cls = None
    torch_dtype = torch.bfloat16
    # the expected reduction in peak memory used compared to an unquantized model, expressed as a percentage
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False

    def setUp(self):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        gc.collect()

    def tearDown(self):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
        }

    def test_quanto_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert isinstance(module, QLinear)

    def test_quanto_memory_usage(self):
        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
        unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        model.to(torch_device)
        with torch.no_grad():
            model(**inputs)
        max_memory = torch.cuda.max_memory_allocated() / 1024**3
        assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction

    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check that the modules under `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
        """
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to("cuda")

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()

        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = QuantoConfig(**quantization_config_kwargs)

        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to("cuda")

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not isinstance(module, QLinear)

    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
            model.to(device="cuda:0", dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.float()

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.half()

        # This should work
        model.to("cuda")

    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)

    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3

    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
            )


class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
    model_id = "hf-internal-testing/tiny-flux-transformer"
    model_cls = FluxTransformer2DModel
    pipeline_cls = FluxPipeline
    torch_dtype = torch.bfloat16
    keep_in_fp32_module = "proj_out"
    modules_to_not_convert = ["proj_out"]
    _test_torch_compile = False

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }

    def get_dummy_training_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }

    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            quantization_config=QuantoConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

    def test_training(self):
        quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the model as only adapter layers will be trained
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_training_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)


class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.3

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.3
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int8"}


class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int4"}


class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int2"}