Compare commits

...

41 Commits

Author SHA1 Message Date
Dhruv Nair
1b46a32f2a update 2025-03-10 03:23:56 +01:00
Dhruv Nair
cf4694e19e Merge branch 'add-quanto' of https://github.com/huggingface/diffusers into add-quanto 2025-03-07 17:27:30 +01:00
Dhruv Nair
deebc22ebd update 2025-03-07 17:27:19 +01:00
Dhruv Nair
d5ab9cadc0 Update src/diffusers/quantizers/quanto/utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-07 21:53:05 +05:30
Dhruv Nair
6cad1d537a update 2025-03-07 17:22:33 +01:00
Dhruv Nair
bb7fb66b4d update 2025-03-07 17:15:20 +01:00
Dhruv Nair
8163687e08 update 2025-03-07 04:21:46 +01:00
Dhruv Nair
8afff1bb0d Merge branch 'main' into add-quanto 2025-03-07 04:15:32 +01:00
Dhruv Nair
830b7345b7 update 2025-03-07 04:09:04 +01:00
Dhruv Nair
4516f2238b update 2025-03-07 03:51:28 +01:00
Dhruv Nair
156db084d2 Merge branch 'main' into add-quanto 2025-03-03 21:29:14 +05:30
Dhruv Nair
963559f69b update 2025-02-25 13:14:59 +01:00
Dhruv Nair
dbaef7c3a4 update 2025-02-25 13:02:26 +01:00
Dhruv Nair
f512c2893f update 2025-02-25 11:52:48 +01:00
Dhruv Nair
4eabed7f97 update 2025-02-25 05:12:25 +01:00
Dhruv Nair
0736f87d1f update 2025-02-20 09:03:50 +01:00
Dhruv Nair
6cf9a78108 update 2025-02-20 08:19:00 +01:00
Dhruv Nair
c29684f44f Merge branch 'main' into add-quanto 2025-02-20 05:27:36 +01:00
Dhruv Nair
c4b6e24fe5 update 2025-02-20 05:23:49 +01:00
Dhruv Nair
79901e4dec update 2025-02-18 19:19:30 +01:00
Dhruv Nair
9a72fefd4b Merge branch 'main' into add-quanto 2025-02-13 17:31:41 +01:00
Dhruv Nair
d355e6aa9b update 2025-02-13 17:31:21 +01:00
Dhruv Nair
c80d4d4a72 update 2025-02-12 18:56:21 +01:00
Dhruv Nair
2c7f30325d update 2025-02-11 15:39:25 +01:00
Dhruv Nair
b136d239e4 update 2025-02-11 12:36:14 +01:00
Dhruv Nair
9e5a3d0766 Merge branch 'add-quanto' of https://github.com/huggingface/diffusers into add-quanto 2025-02-11 12:34:55 +01:00
Dhruv Nair
559f12470a Merge https://github.com/huggingface/diffusers into add-quanto 2025-02-11 12:34:30 +01:00
Dhruv Nair
e090177766 update 2025-02-11 12:34:13 +01:00
Dhruv Nair
7b841dc52d Update docs/source/en/quantization/quanto.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-02-11 11:28:10 +05:30
Dhruv Nair
4ae86916b2 update 2025-02-10 09:46:40 +01:00
Dhruv Nair
e96686e9c9 update 2025-02-10 09:34:37 +01:00
Dhruv Nair
7472f18b9c update 2025-02-10 09:15:09 +01:00
Dhruv Nair
f734c096e7 update 2025-02-10 09:00:56 +01:00
Dhruv Nair
5cff237f75 update 2025-02-10 08:49:48 +01:00
Dhruv Nair
f67d97c0ac update 2025-02-10 08:22:27 +01:00
Dhruv Nair
f4c14c222d update 2025-02-10 08:20:14 +01:00
Dhruv Nair
f52050a39f update 2025-02-08 10:52:12 +01:00
Dhruv Nair
39e20e2405 update 2025-02-08 10:51:57 +01:00
Dhruv Nair
aa8cdaf056 update 2025-02-05 18:18:30 +01:00
Dhruv Nair
ba5bba74f2 updaet 2025-02-05 14:29:11 +01:00
Dhruv Nair
ff50418472 update 2025-02-05 14:28:57 +01:00
21 changed files with 997 additions and 3 deletions

View File

@@ -418,6 +418,8 @@ jobs:
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
- backend: "optimum_quanto"
test_location: "quanto"
runs-on:
group: aws-g6e-xlarge-plus
container:

View File

@@ -173,6 +173,8 @@
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16

View File

@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
## GGUFQuantizationConfig
[[autodoc]] GGUFQuantizationConfig
## QuantoConfig
[[autodoc]] QuantoConfig
## TorchAoConfig
[[autodoc]] TorchAoConfig

View File

@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
- [Quanto](./quanto.md)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.

View File

@@ -0,0 +1,148 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Quanto
[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:
- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)
To use the Quanto backend, first install `optimum-quanto>=0.2.6` and `accelerate`:
```shell
pip install optimum-quanto accelerate
```
Now you can quantize a model by passing a `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library can quantize `nn.Conv2d` and `nn.LayerNorm` modules, Diffusers currently only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```
## Skipping Quantization on specific modules
It is possible to skip quantization for certain modules using the `modules_to_not_convert` argument in `QuantoConfig`. Ensure that the module names passed to this argument match the keys of those modules in the model's `state_dict`.
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
```
## Using `from_single_file` with the Quanto Backend
`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```
## Saving Quantized models
Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.
The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized
with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
# save quantized model to reuse
transformer.save_pretrained("<your quantized model save path>")
# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```
## Using `torch.compile` with Quanto
Currently the Quanto backend supports `torch.compile` for the following quantization types:
- `int8` weights
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image = pipe("A cat holding a sign that says hello").images[0]
image.save("flux-quanto-compile.png")
```
## Supported Quantization Types
### Weights
- float8
- int8
- int4
- int2
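All of these weight dtypes are selected through the `weights_dtype` argument of `QuantoConfig`. As a minimal sketch, reusing the FLUX.1-dev transformer from the examples above (the choice of `int4` here is only illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

# Any of the supported values works here: "float8", "int8", "int4", or "int2".
quantization_config = QuantoConfig(weights_dtype="int4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```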

View File

@@ -128,6 +128,10 @@ _deps = [
"GitPython<3.1.19",
"scipy",
"onnx",
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -235,6 +239,11 @@ extras["test"] = deps_list(
)
extras["torch"] = deps_list("torch", "accelerate")
extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:

View File

@@ -2,6 +2,15 @@ __version__ = "0.33.0.dev0"
from typing import TYPE_CHECKING
from diffusers.quantizers import quantization_config
from diffusers.utils import dummy_gguf_objects
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_optimum_quanto_version,
is_torchao_available,
)
from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
@@ -11,6 +20,7 @@ from .utils import (
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
@@ -32,7 +42,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -54,6 +64,55 @@ _import_structure = {
],
}
try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects
_import_structure["utils.dummy_bitsandbytes_objects"] = [
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects
_import_structure["utils.dummy_gguf_objects"] = [
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects
_import_structure["utils.dummy_torchao_objects"] = [
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects
_import_structure["utils.dummy_optimum_quanto_objects"] = [
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -598,7 +657,38 @@ else:
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_bitsandbytes_objects import *
else:
from .quantizers.quantization_config import BitsAndBytesConfig
try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_gguf_objects import *
else:
from .quantizers.quantization_config import GGUFQuantizationConfig
try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torchao_objects import *
else:
from .quantizers.quantization_config import TorchAoConfig
try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_optimum_quanto_objects import *
else:
from .quantizers.quantization_config import QuantoConfig
try:
if not is_onnx_available():

View File

@@ -35,6 +35,10 @@ deps = {
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"optimum_quanto": "optimum_quanto>=0.2.6",
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",

View File

@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
):
param = param.to(torch.float32)
set_module_kwargs["dtype"] = torch.float32
# For quantizers that have saved weights using torch.float8_e4m3fn
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
pass
else:
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

View File

@@ -26,8 +26,10 @@ from .quantization_config import (
GGUFQuantizationConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
TorchAoConfig,
)
from .quanto import QuantoQuantizer
from .torchao import TorchAoHfQuantizer
@@ -35,6 +37,7 @@ AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
}
@@ -42,6 +45,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
}
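For reference, the two mappings above are what let a serialized config with `quant_method: "quanto"` resolve back to `QuantoConfig` and `QuantoQuantizer`. A rough sketch, assuming the mappings are importable from `diffusers.quantizers.auto` (the module path is inferred from this diff and may differ):

```python
from diffusers.quantizers.auto import (
    AUTO_QUANTIZATION_CONFIG_MAPPING,
    AUTO_QUANTIZER_MAPPING,
)

# Look up the config and quantizer classes registered for the "quanto" method.
config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["quanto"]
quantizer_cls = AUTO_QUANTIZER_MAPPING["quanto"]

config = config_cls(weights_dtype="int8")
quantizer = quantizer_cls(config)
print(type(config).__name__, type(quantizer).__name__)  # QuantoConfig QuantoQuantizer
```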

View File

@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
if is_torchao_available():
@@ -686,3 +687,38 @@ class TorchAoConfig(QuantizationConfigMixin):
return (
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
)
@dataclass
class QuantoConfig(QuantizationConfigMixin):
"""
This is a wrapper class for all of the attributes and features that can be used with a model that has been
quantized with Quanto.
Args:
weights_dtype (`str`, *optional*, defaults to `"int8"`):
The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2")
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
"""
def __init__(
self,
weights_dtype: str = "int8",
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.QUANTO
self.weights_dtype = weights_dtype
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
def post_init(self):
r"""
Safety checker to verify that the arguments are correct.
"""
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")

View File

@@ -0,0 +1 @@
from .quanto_quantizer import QuantoQuantizer

View File

@@ -0,0 +1,177 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union
from diffusers.utils.import_utils import is_optimum_quanto_version
from ...utils import (
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
is_optimum_quanto_available,
is_torch_available,
logging,
)
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate.utils import CustomDtype, set_module_tensor_to_device
if is_optimum_quanto_available():
from .utils import _replace_with_quanto_layers
logger = logging.get_logger(__name__)
class QuantoQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for Optimum Quanto
"""
use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["quanto", "accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
)
if not is_optimum_quanto_version(">=", "0.2.6"):
raise ImportError(
"Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
"Please upgrade your installation with `pip install --upgrade optimum-quanto"
)
if not is_accelerate_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
)
device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
raise ValueError(
"`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
)
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
# Quanto imports diffusers internally. This is here to prevent circular imports
from optimum.quanto import QModuleMixin, QTensor
from optimum.quanto.tensor.packed import PackedTensor
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
return True
elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
return not module.frozen
return False
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
*args,
**kwargs,
):
"""
Create the quantized parameter by calling .freeze() after setting it to the module.
"""
dtype = kwargs.get("dtype", torch.float32)
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
setattr(module, tensor_name, param_value)
else:
set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
module.freeze()
module.weight.requires_grad = False
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if is_accelerate_version(">=", "0.27.0"):
mapping = {
"int8": torch.int8,
"float8": CustomDtype.FP8,
"int4": CustomDtype.INT4,
"int2": CustomDtype.INT2,
}
target_dtype = mapping[self.quantization_config.weights_dtype]
return target_dtype
def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
if torch_dtype is None:
logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
torch_dtype = torch.float32
return torch_dtype
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
# Quanto imports diffusers internally. This is here to prevent circular imports
from optimum.quanto import QModuleMixin
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, QModuleMixin):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
self.modules_to_not_convert.extend(keep_in_fp32_modules)
model = _replace_with_quanto_layers(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model, **kwargs):
return model
@property
def is_trainable(self):
return True
@property
def is_serializable(self):
return True
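One practical consequence of `validate_environment` above: passing a `device_map` with more than one entry raises a `ValueError`. A minimal sketch mirroring the `test_device_map_error` case in the test file further down (the tiny checkpoint is the one used in that test and is only illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

try:
    FluxTransformer2DModel.from_pretrained(
        "hf-internal-testing/tiny-flux-pipe",
        subfolder="transformer",
        quantization_config=QuantoConfig(weights_dtype="int8"),
        torch_dtype=torch.bfloat16,
        device_map={0: "8GB", "cpu": "16GB"},  # multiple entries are rejected
    )
except ValueError as err:
    print(err)
```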

View File

@@ -0,0 +1,60 @@
import torch.nn as nn
from ...utils import is_accelerate_available, logging
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import init_empty_weights
def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
# Quanto imports diffusers internally. These are placed here to avoid circular imports
from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8
def _get_weight_type(dtype: str):
return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]
def _replace_layers(model, quantization_config, modules_to_not_convert):
has_children = list(model.children())
if not has_children:
return model
for name, module in model.named_children():
_replace_layers(module, quantization_config, modules_to_not_convert)
if name in modules_to_not_convert:
continue
if isinstance(module, nn.Linear):
with init_empty_weights():
qlinear = QLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
dtype=module.weight.dtype,
weights=_get_weight_type(quantization_config.weights_dtype),
)
model._modules[name] = qlinear
model._modules[name].source_cls = type(module)
model._modules[name].requires_grad_(False)
return model
model = _replace_layers(model, quantization_config, modules_to_not_convert)
has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())
if not has_been_replaced:
logger.warning(
f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
" Please check your model architecture, or submit an issue on Github if you think this is a bug."
" https://github.com/huggingface/diffusers/issues/new"
)
# We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
# to match when trying to load weights with load_model_dict_into_meta
if pre_quantized:
freeze(model)
return model
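To illustrate what `_replace_with_quanto_layers` does, here is a rough sketch on a toy module. It assumes the helper lives in `src/diffusers/quantizers/quanto/utils.py` (per the commit list above) and that `optimum-quanto` and `accelerate` are installed; it is an internal helper rather than a public API:

```python
import torch.nn as nn
from optimum.quanto import QLinear

from diffusers import QuantoConfig
from diffusers.quantizers.quanto.utils import _replace_with_quanto_layers  # internal helper

toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
config = QuantoConfig(weights_dtype="int8")

# Child "1" is skipped via modules_to_not_convert; child "0" becomes a QLinear shell
# with meta weights that are filled in later when a state dict is loaded.
toy = _replace_with_quanto_layers(
    toy, quantization_config=config, modules_to_not_convert=["1"], pre_quantized=False
)
assert isinstance(toy[0], QLinear)
assert not isinstance(toy[1], QLinear)
```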

View File

@@ -79,6 +79,8 @@ from .import_utils import (
is_matplotlib_available,
is_note_seq_available,
is_onnx_available,
is_optimum_quanto_available,
is_optimum_quanto_version,
is_peft_available,
is_peft_version,
is_safetensors_available,

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class BitsAndBytesConfig(metaclass=DummyObject):
_backends = ["bitsandbytes"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["bitsandbytes"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["bitsandbytes"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["bitsandbytes"])

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class GGUFQuantizationConfig(metaclass=DummyObject):
_backends = ["gguf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["gguf"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["gguf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["gguf"])

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class QuantoConfig(metaclass=DummyObject):
_backends = ["optimum_quanto"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["optimum_quanto"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["optimum_quanto"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["optimum_quanto"])

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class TorchAoConfig(metaclass=DummyObject):
_backends = ["torchao"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchao"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torchao"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torchao"])

View File

@@ -365,6 +365,15 @@ if _is_torchao_available:
_is_torchao_available = False
_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
if _is_optimum_quanto_available:
try:
_optimum_quanto_version = importlib_metadata.version("optimum_quanto")
logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
except importlib_metadata.PackageNotFoundError:
_is_optimum_quanto_available = False
def is_torch_available():
return _torch_available
@@ -493,6 +502,10 @@ def is_torchao_available():
return _is_torchao_available
def is_optimum_quanto_available():
return _is_optimum_quanto_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -636,6 +649,11 @@ TORCHAO_IMPORT_ERROR = """
torchao`
"""
QUANTO_IMPORT_ERROR = """
{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
install optimum-quanto`
"""
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -663,6 +681,7 @@ BACKENDS_MAPPING = OrderedDict(
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
]
)
@@ -864,6 +883,21 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current optimum-quanto version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _is_optimum_quanto_available:
return False
return compare_versions(parse(_optimum_quanto_version), operation, version)
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
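A small sketch of how the availability and version helpers above are intended to be used together to gate Quanto-specific code paths (both are re-exported from `diffusers.utils`, as shown in the earlier hunk):

```python
from diffusers.utils import is_optimum_quanto_available, is_optimum_quanto_version

if is_optimum_quanto_available() and is_optimum_quanto_version(">=", "0.2.6"):
    from diffusers import QuantoConfig

    quantization_config = QuantoConfig(weights_dtype="float8")
else:
    quantization_config = None  # fall back to loading the model unquantized
```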

View File

@@ -0,0 +1,346 @@
import gc
import tempfile
import unittest
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_big_gpu_with_torch_cuda,
torch_device,
)
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_torch_available():
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""
def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)
def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
class QuantoBaseTesterMixin:
model_id = None
pipeline_model_id = None
model_cls = None
torch_dtype = torch.bfloat16
# the expected reduction in peak memory used, compared to an unquantized model, expressed as a fraction
expected_memory_reduction = 0.0
keep_in_fp32_module = ""
modules_to_not_convert = ""
_test_torch_compile = False
def setUp(self):
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()
def tearDown(self):
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()
def get_dummy_init_kwargs(self):
return {"weights_dtype": "float8"}
def get_dummy_model_init_kwargs(self):
return {
"pretrained_model_name_or_path": self.model_id,
"torch_dtype": self.torch_dtype,
"quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
}
def test_quanto_layers(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
assert isinstance(module, QLinear)
def test_quanto_memory_usage(self):
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
inputs = self.get_dummy_inputs()
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
model.to(torch_device)
with torch.no_grad():
model(**inputs)
max_memory = torch.cuda.max_memory_allocated() / 1024**3
assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction
def test_keep_modules_in_fp32(self):
r"""
A simple test to check that the modules under `_keep_in_fp32_modules` are kept in fp32.
Also ensures that inference works.
"""
_keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
model.to("cuda")
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
assert module.weight.dtype == torch.float32
self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
def test_modules_to_not_convert(self):
init_kwargs = self.get_dummy_model_init_kwargs()
quantization_config_kwargs = self.get_dummy_init_kwargs()
quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
quantization_config = QuantoConfig(**quantization_config_kwargs)
init_kwargs.update({"quantization_config": quantization_config})
model = self.model_cls.from_pretrained(**init_kwargs)
model.to("cuda")
for name, module in model.named_modules():
if name in self.modules_to_not_convert:
assert not isinstance(module, QLinear)
def test_dtype_assignment(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
with self.assertRaises(ValueError):
# Tries with a `dtype`
model.to(torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
model.to(device="cuda:0", dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a cast
model.float()
with self.assertRaises(ValueError):
# Tries with a cast
model.half()
# This should work
model.to("cuda")
def test_serialization(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
inputs = self.get_dummy_inputs()
model.to(torch_device)
with torch.no_grad():
model_output = model(**inputs)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
saved_model = self.model_cls.from_pretrained(
tmp_dir,
torch_dtype=torch.bfloat16,
)
saved_model.to(torch_device)
with torch.no_grad():
saved_model_output = saved_model(**inputs)
assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
def test_torch_compile(self):
if not self._test_torch_compile:
return
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
model.to(torch_device)
with torch.no_grad():
model_output = model(**self.get_dummy_inputs()).sample
compiled_model.to(torch_device)
with torch.no_grad():
compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
model_output = model_output.detach().float().cpu().numpy()
compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
assert max_diff < 1e-3
def test_device_map_error(self):
with self.assertRaises(ValueError):
_ = self.model_cls.from_pretrained(
**self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
)
class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
model_id = "hf-internal-testing/tiny-flux-transformer"
model_cls = FluxTransformer2DModel
pipeline_cls = FluxPipeline
torch_dtype = torch.bfloat16
keep_in_fp32_module = "proj_out"
modules_to_not_convert = ["proj_out"]
_test_torch_compile = False
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_projections": torch.randn(
(1, 768),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
"img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
}
def get_dummy_training_inputs(self, device=None, seed: int = 0):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
height = width = 4
sequence_length = 48
embedding_dim = 32
torch.manual_seed(seed)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
device, dtype=torch.bfloat16
)
torch.manual_seed(seed)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
torch.manual_seed(seed)
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"txt_ids": text_ids,
"img_ids": image_ids,
"timestep": timestep,
}
def test_model_cpu_offload(self):
init_kwargs = self.get_dummy_init_kwargs()
transformer = self.model_cls.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
quantization_config=QuantoConfig(**init_kwargs),
subfolder="transformer",
torch_dtype=torch.bfloat16,
)
pipe = self.pipeline_cls.from_pretrained(
"hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload(device=torch_device)
_ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
def test_training(self):
quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
quantized_model = self.model_cls.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
).to(torch_device)
for param in quantized_model.parameters():
# freeze the model as only adapter layers will be trained
param.requires_grad = False
if param.ndim == 1:
param.data = param.data.to(torch.float32)
for _, module in quantized_model.named_modules():
if isinstance(module, Attention):
module.to_q = LoRALayer(module.to_q, rank=4)
module.to_k = LoRALayer(module.to_k, rank=4)
module.to_v = LoRALayer(module.to_v, rank=4)
with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
inputs = self.get_dummy_training_inputs(torch_device)
output = quantized_model(**inputs)[0]
output.norm().backward()
for module in quantized_model.modules():
if isinstance(module, LoRALayer):
self.assertTrue(module.adapter[1].weight.grad is not None)
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.3
def get_dummy_init_kwargs(self):
return {"weights_dtype": "float8"}
class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.3
_test_torch_compile = True
def get_dummy_init_kwargs(self):
return {"weights_dtype": "int8"}
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.55
def get_dummy_init_kwargs(self):
return {"weights_dtype": "int4"}
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.65
def get_dummy_init_kwargs(self):
return {"weights_dtype": "int2"}