Mirror of https://github.com/huggingface/diffusers.git
Compare commits: cached-lor...add-quanto (41 commits)
Commits:
1b46a32f2a, cf4694e19e, deebc22ebd, d5ab9cadc0, 6cad1d537a, bb7fb66b4d, 8163687e08, 8afff1bb0d,
830b7345b7, 4516f2238b, 156db084d2, 963559f69b, dbaef7c3a4, f512c2893f, 4eabed7f97, 0736f87d1f,
6cf9a78108, c29684f44f, c4b6e24fe5, 79901e4dec, 9a72fefd4b, d355e6aa9b, c80d4d4a72, 2c7f30325d,
b136d239e4, 9e5a3d0766, 559f12470a, e090177766, 7b841dc52d, 4ae86916b2, e96686e9c9, 7472f18b9c,
f734c096e7, 5cff237f75, f67d97c0ac, f4c14c222d, f52050a39f, 39e20e2405, aa8cdaf056, ba5bba74f2,
ff50418472
.github/workflows/nightly_tests.yml (vendored, 2 lines changed)
@@ -418,6 +418,8 @@ jobs:
            test_location: "gguf"
          - backend: "torchao"
            test_location: "torchao"
+          - backend: "optimum_quanto"
+            test_location: "quanto"
    runs-on:
      group: aws-g6e-xlarge-plus
    container:
docs/source/en/_toctree.yml
@@ -173,6 +173,8 @@
      title: gguf
    - local: quantization/torchao
      title: torchao
+    - local: quantization/quanto
+      title: quanto
    title: Quantization Methods
  - sections:
    - local: optimization/fp16
docs/source/en/api/quantization.md
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
## GGUFQuantizationConfig

[[autodoc]] GGUFQuantizationConfig

+## QuantoConfig
+
+[[autodoc]] QuantoConfig
+
## TorchAoConfig

[[autodoc]] TorchAoConfig
docs/source/en/quantization/overview.md
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
+- [Quanto](./quanto.md)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
docs/source/en/quantization/quanto.md (new file, 148 lines)
@@ -0,0 +1,148 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`:

```shell
pip install optimum-quanto accelerate
```

Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, Diffusers currently only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Skipping Quantization on specific modules

It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```

## Using `from_single_file` with the Quanto Backend

`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```

## Saving Quantized models

Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.

The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
# save quantized model to reuse
transformer.save_pretrained("<your quantized model save path>")

# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```

## Using `torch.compile` with Quanto

Currently the Quanto backend supports `torch.compile` for the following quantization types:

- `int8` weights

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")
```

## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2
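Any of the dtypes listed above is selected through `weights_dtype`. As a reference, a minimal sketch using `int4` (same Flux checkpoint as the examples above, shown purely for illustration):

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

# int4 weight-only quantization; "int2" and the other listed dtypes work the same way
quantization_config = QuantoConfig(weights_dtype="int4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```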
setup.py (9 lines changed)
@@ -128,6 +128,10 @@ _deps = [
    "GitPython<3.1.19",
    "scipy",
    "onnx",
+    "optimum_quanto>=0.2.6",
+    "gguf>=0.10.0",
+    "torchao>=0.7.0",
+    "bitsandbytes>=0.43.3",
    "regex!=2019.12.17",
    "requests",
    "tensorboard",

@@ -235,6 +239,11 @@ extras["test"] = deps_list(
)
extras["torch"] = deps_list("torch", "accelerate")

+extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
+extras["gguf"] = deps_list("gguf", "accelerate")
+extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
+extras["torchao"] = deps_list("torchao", "accelerate")
+
if os.name == "nt":  # windows
    extras["flax"] = []  # jax is not supported on windows
else:
src/diffusers/__init__.py
@@ -2,6 +2,15 @@ __version__ = "0.33.0.dev0"

from typing import TYPE_CHECKING

+from diffusers.quantizers import quantization_config
+from diffusers.utils import dummy_gguf_objects
+from diffusers.utils.import_utils import (
+    is_bitsandbytes_available,
+    is_gguf_available,
+    is_optimum_quanto_version,
+    is_torchao_available,
+)
+
from .utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,

@@ -11,6 +20,7 @@ from .utils import (
    is_librosa_available,
    is_note_seq_available,
    is_onnx_available,
+    is_optimum_quanto_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_torch_available,

@@ -32,7 +42,7 @@ _import_structure = {
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
    "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
+    "quantizers.quantization_config": [],
    "schedulers": [],
    "utils": [
        "OptionalDependencyNotAvailable",

@@ -54,6 +64,55 @@ _import_structure = {
    ],
}

+try:
+    if not is_bitsandbytes_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_bitsandbytes_objects
+
+    _import_structure["utils.dummy_bitsandbytes_objects"] = [
+        name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
+
+try:
+    if not is_gguf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_gguf_objects
+
+    _import_structure["utils.dummy_gguf_objects"] = [
+        name for name in dir(dummy_gguf_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
+
+try:
+    if not is_torchao_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_torchao_objects
+
+    _import_structure["utils.dummy_torchao_objects"] = [
+        name for name in dir(dummy_torchao_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
+
+try:
+    if not is_optimum_quanto_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_optimum_quanto_objects
+
+    _import_structure["utils.dummy_optimum_quanto_objects"] = [
+        name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("QuantoConfig")
+
+
try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()

@@ -598,7 +657,38 @@ else:

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
+
+    try:
+        if not is_bitsandbytes_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_bitsandbytes_objects import *
+    else:
+        from .quantizers.quantization_config import BitsAndBytesConfig
+
+    try:
+        if not is_gguf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_gguf_objects import *
+    else:
+        from .quantizers.quantization_config import GGUFQuantizationConfig
+
+    try:
+        if not is_torchao_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_torchao_objects import *
+    else:
+        from .quantizers.quantization_config import TorchAoConfig
+
+    try:
+        if not is_optimum_quanto_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_optimum_quanto_objects import *
+    else:
+        from .quantizers.quantization_config import QuantoConfig

    try:
        if not is_onnx_available():
src/diffusers/dependency_versions_table.py
@@ -35,6 +35,10 @@ deps = {
    "GitPython": "GitPython<3.1.19",
    "scipy": "scipy",
    "onnx": "onnx",
+    "optimum_quanto": "optimum_quanto>=0.2.6",
+    "gguf": "gguf>=0.10.0",
+    "torchao": "torchao>=0.7.0",
+    "bitsandbytes": "bitsandbytes>=0.43.3",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
src/diffusers/models/model_loading_utils.py
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
        ):
            param = param.to(torch.float32)
            set_module_kwargs["dtype"] = torch.float32
+        # For quantizers that have saved weights using torch.float8_e4m3fn
+        elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
+            pass
        else:
            param = param.to(dtype)
            set_module_kwargs["dtype"] = dtype

@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
        elif is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
        ):
-            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(
+                model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
+            )
        else:
            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
src/diffusers/quantizers/auto.py
@@ -26,8 +26,10 @@ from .quantization_config import (
    GGUFQuantizationConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
+    QuantoConfig,
    TorchAoConfig,
)
+from .quanto import QuantoQuantizer
from .torchao import TorchAoHfQuantizer

@@ -35,6 +37,7 @@ AUTO_QUANTIZER_MAPPING = {
    "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
    "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
    "gguf": GGUFQuantizer,
+    "quanto": QuantoQuantizer,
    "torchao": TorchAoHfQuantizer,
}

@@ -42,6 +45,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "gguf": GGUFQuantizationConfig,
+    "quanto": QuantoConfig,
    "torchao": TorchAoConfig,
}
src/diffusers/quantizers/quantization_config.py
@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GGUF = "gguf"
    TORCHAO = "torchao"
+    QUANTO = "quanto"


if is_torchao_available():

@@ -686,3 +687,38 @@ class TorchAoConfig(QuantizationConfigMixin):
        return (
            f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
        )


+@dataclass
+class QuantoConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class for all the attributes and features that you can configure for a model loaded using
+    `quanto`.
+
+    Args:
+        weights_dtype (`str`, *optional*, defaults to `"int8"`):
+            The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2")
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+            modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+    """
+
+    def __init__(
+        self,
+        weights_dtype: str = "int8",
+        modules_to_not_convert: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.QUANTO
+        self.weights_dtype = weights_dtype
+        self.modules_to_not_convert = modules_to_not_convert
+
+        self.post_init()
+
+    def post_init(self):
+        r"""
+        Safety checker that arguments are correct
+        """
+        accepted_weights = ["float8", "int8", "int4", "int2"]
+        if self.weights_dtype not in accepted_weights:
+            raise ValueError(f"Only weights in {accepted_weights} are supported but found {self.weights_dtype}")
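For orientation, a brief sketch of how this config behaves on its own, assuming optimum-quanto is installed so the real `QuantoConfig` (rather than the dummy object) is imported from `diffusers`:

```python
from diffusers import QuantoConfig

# A valid dtype constructs normally and records the quantization method
config = QuantoConfig(weights_dtype="int4", modules_to_not_convert=["proj_out"])
print(config.quant_method, config.weights_dtype, config.modules_to_not_convert)

# An unsupported dtype is rejected by post_init with a ValueError
try:
    QuantoConfig(weights_dtype="int16")
except ValueError as err:
    print(err)
```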
src/diffusers/quantizers/quanto/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .quanto_quantizer import QuantoQuantizer
src/diffusers/quantizers/quanto/quanto_quantizer.py (new file, 177 lines)
@@ -0,0 +1,177 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from diffusers.utils.import_utils import is_optimum_quanto_version

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_accelerate_version,
    is_optimum_quanto_available,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import CustomDtype, set_module_tensor_to_device

if is_optimum_quanto_available():
    from .utils import _replace_with_quanto_layers

logger = logging.get_logger(__name__)


class QuantoQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for Optimum Quanto
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["quanto", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not is_optimum_quanto_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires the optimum-quanto library (`pip install optimum-quanto`)"
            )
        if not is_optimum_quanto_version(">=", "0.2.6"):
            raise ImportError(
                "Loading an optimum-quanto quantized model requires `optimum-quanto>=0.2.6`. "
                "Please upgrade your installation with `pip install --upgrade optimum-quanto`"
            )

        if not is_accelerate_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires the accelerate library (`pip install accelerate`)"
            )

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            raise ValueError(
                "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
            )

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # Quanto imports diffusers internally. This is here to prevent circular imports
        from optimum.quanto import QModuleMixin, QTensor
        from optimum.quanto.tensor.packed import PackedTensor

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and any(isinstance(module, t) for t in [QTensor, PackedTensor]):
            return True
        elif isinstance(module, QModuleMixin) and "weight" in tensor_name:
            return not module.frozen

        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .freeze() after setting it to the module.
        """

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            setattr(module, tensor_name, param_value)
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            module.freeze()
            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if is_accelerate_version(">=", "0.27.0"):
            mapping = {
                "int8": torch.int8,
                "float8": CustomDtype.FP8,
                "int4": CustomDtype.INT4,
                "int2": CustomDtype.INT2,
            }
            target_dtype = mapping[self.quantization_config.weights_dtype]

        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        # Quanto imports diffusers internally. This is here to prevent circular imports
        from optimum.quanto import QModuleMixin

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, QModuleMixin):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        model = _replace_with_quanto_layers(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_trainable(self):
        return True

    @property
    def is_serializable(self):
        return True
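As background on the `.freeze()` call above, here is a minimal sketch of the corresponding flow in optimum-quanto itself, standalone and outside Diffusers; the toy model is purely illustrative:

```python
import torch
import torch.nn as nn
from optimum.quanto import freeze, qint8, quantize

# Toy stand-in for a transformer block (illustrative only)
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

# Swap nn.Linear layers for Quanto's quantized equivalents
quantize(model, weights=qint8)

# freeze() converts the weights to their quantized representation in place,
# which is the per-module step the quantizer triggers via module.freeze()
freeze(model)

with torch.no_grad():
    print(model(torch.randn(1, 64)).shape)
```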
src/diffusers/quantizers/quanto/utils.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import torch.nn as nn

from ...utils import is_accelerate_available, logging


logger = logging.get_logger(__name__)

if is_accelerate_available():
    from accelerate import init_empty_weights


def _replace_with_quanto_layers(model, quantization_config, modules_to_not_convert: list, pre_quantized=False):
    # Quanto imports diffusers internally. These are placed here to avoid circular imports
    from optimum.quanto import QLinear, freeze, qfloat8, qint2, qint4, qint8

    def _get_weight_type(dtype: str):
        return {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}[dtype]

    def _replace_layers(model, quantization_config, modules_to_not_convert):
        has_children = list(model.children())
        if not has_children:
            return model

        for name, module in model.named_children():
            _replace_layers(module, quantization_config, modules_to_not_convert)

            if name in modules_to_not_convert:
                continue

            if isinstance(module, nn.Linear):
                with init_empty_weights():
                    qlinear = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=_get_weight_type(quantization_config.weights_dtype),
                    )
                    model._modules[name] = qlinear
                    model._modules[name].source_cls = type(module)
                    model._modules[name].requires_grad_(False)

        return model

    model = _replace_layers(model, quantization_config, modules_to_not_convert)
    has_been_replaced = any(isinstance(replaced_module, QLinear) for _, replaced_module in model.named_modules())

    if not has_been_replaced:
        logger.warning(
            f"{model.__class__.__name__} does not appear to have any `nn.Linear` modules. Quantization will not be applied."
            " Please check your model architecture, or submit an issue on Github if you think this is a bug."
            " https://github.com/huggingface/diffusers/issues/new"
        )

    # We need to freeze the pre_quantized model in order for the loaded state_dict and model state dict
    # to match when trying to load weights with load_model_dict_into_meta
    if pre_quantized:
        freeze(model)

    return model
src/diffusers/utils/__init__.py
@@ -79,6 +79,8 @@ from .import_utils import (
    is_matplotlib_available,
    is_note_seq_available,
    is_onnx_available,
+    is_optimum_quanto_available,
+    is_optimum_quanto_version,
    is_peft_available,
    is_peft_version,
    is_safetensors_available,
src/diffusers/utils/dummy_bitsandbytes_objects.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class BitsAndBytesConfig(metaclass=DummyObject):
    _backends = ["bitsandbytes"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["bitsandbytes"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["bitsandbytes"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["bitsandbytes"])
src/diffusers/utils/dummy_gguf_objects.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class GGUFQuantizationConfig(metaclass=DummyObject):
    _backends = ["gguf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["gguf"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["gguf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["gguf"])
src/diffusers/utils/dummy_optimum_quanto_objects.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class QuantoConfig(metaclass=DummyObject):
    _backends = ["optimum_quanto"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["optimum_quanto"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["optimum_quanto"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["optimum_quanto"])
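For context, these autogenerated dummies make a missing backend fail lazily at instantiation time rather than at import time. A hypothetical session in an environment without optimum-quanto installed:

```python
# Hypothetical: optimum-quanto is NOT installed, so this import resolves to the
# DummyObject-backed QuantoConfig defined above
from diffusers import QuantoConfig

try:
    QuantoConfig(weights_dtype="int8")
except ImportError as err:
    # requires_backends raises with an installation hint for the missing backend
    print(err)
```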
src/diffusers/utils/dummy_torchao_objects.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class TorchAoConfig(metaclass=DummyObject):
    _backends = ["torchao"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torchao"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torchao"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torchao"])
src/diffusers/utils/import_utils.py
@@ -365,6 +365,15 @@ if _is_torchao_available:
    _is_torchao_available = False


+_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
+if _is_optimum_quanto_available:
+    try:
+        _optimum_quanto_version = importlib_metadata.version("optimum_quanto")
+        logger.debug(f"Successfully imported optimum-quanto version {_optimum_quanto_version}")
+    except importlib_metadata.PackageNotFoundError:
+        _is_optimum_quanto_available = False
+
+
def is_torch_available():
    return _torch_available

@@ -493,6 +502,10 @@ def is_torchao_available():
    return _is_torchao_available


+def is_optimum_quanto_available():
+    return _is_optimum_quanto_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

@@ -636,6 +649,11 @@ TORCHAO_IMPORT_ERROR = """
torchao`
"""

+QUANTO_IMPORT_ERROR = """
+{0} requires the optimum-quanto library but it was not found in your environment. You can install it with pip: `pip
+install optimum-quanto`
+"""
+
BACKENDS_MAPPING = OrderedDict(
    [
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),

@@ -663,6 +681,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
        ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
        ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
+        ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
    ]
)

@@ -864,6 +883,21 @@ def is_k_diffusion_version(operation: str, version: str):
    return compare_versions(parse(_k_diffusion_version), operation, version)


+def is_optimum_quanto_version(operation: str, version: str):
+    """
+    Compares the current optimum-quanto version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _is_optimum_quanto_available:
+        return False
+    return compare_versions(parse(_optimum_quanto_version), operation, version)
+
+
def get_objects_from_module(module):
    """
    Returns a dict of object names and values in a module, while skipping private/internal objects
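A brief usage sketch of the two new helpers, mirroring the check performed in `QuantoQuantizer.validate_environment`:

```python
from diffusers.utils.import_utils import (
    is_optimum_quanto_available,
    is_optimum_quanto_version,
)

if is_optimum_quanto_available() and is_optimum_quanto_version(">=", "0.2.6"):
    print("Quanto backend is usable")
else:
    print("Install or upgrade with: pip install -U optimum-quanto")
```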
tests/quantization/quanto/test_quanto.py (new file, 346 lines)
@@ -0,0 +1,346 @@
import gc
import tempfile
import unittest

from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
from diffusers.utils.testing_utils import (
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_gpu_with_torch_cuda,
    torch_device,
)


if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with a LoRA-like adapter - Used for testing purposes only

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            nn.init.zeros_(self.adapter[1].weight)
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            return self.module(input, *args, **kwargs) + self.adapter(input)


@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
class QuantoBaseTesterMixin:
    model_id = None
    pipeline_model_id = None
    model_cls = None
    torch_dtype = torch.bfloat16
    # the expected reduction in peak memory used compared to an unquantized model, expressed as a percentage
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False

    def setUp(self):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        gc.collect()

    def tearDown(self):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
        }

    def test_quanto_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert isinstance(module, QLinear)

    def test_quanto_memory_usage(self):
        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
        unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

        model.to(torch_device)
        with torch.no_grad():
            model(**inputs)
        max_memory = torch.cuda.max_memory_allocated() / 1024**3
        assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction

    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
        """
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to("cuda")

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()

        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = QuantoConfig(**quantization_config_kwargs)

        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to("cuda")

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not isinstance(module, QLinear)

    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
            model.to(device="cuda:0", dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.float()

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.half()

        # This should work
        model.to("cuda")

    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)

    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3

    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
            )


class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
    model_id = "hf-internal-testing/tiny-flux-transformer"
    model_cls = FluxTransformer2DModel
    pipeline_cls = FluxPipeline
    torch_dtype = torch.bfloat16
    keep_in_fp32_module = "proj_out"
    modules_to_not_convert = ["proj_out"]
    _test_torch_compile = False

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }

    def get_dummy_training_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }

    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            quantization_config=QuantoConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

    def test_training(self):
        quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the model as only adapter layers will be trained
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_training_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)


class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.3

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.3
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int8"}


class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int4"}


class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int2"}