# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import pytest
import torch

from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
    is_bitsandbytes_available,
    is_gguf_available,
    is_nvidia_modelopt_available,
    is_optimum_quanto_available,
    is_torchao_available,
    is_torchao_version,
)

from ...testing_utils import (
    backend_empty_cache,
    is_bitsandbytes,
    is_gguf,
    is_modelopt,
    is_quantization,
    is_quanto,
    is_torch_compile,
    is_torchao,
    require_accelerate,
    require_accelerator,
    require_bitsandbytes_version_greater,
    require_gguf_version_greater_or_equal,
    require_modelopt_version_greater_or_equal,
    require_quanto,
    require_torchao_version_greater_or_equal,
    torch_device,
)


if is_nvidia_modelopt_available():
    import modelopt.torch.quantization as mtq

if is_bitsandbytes_available():
    import bitsandbytes as bnb

if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_gguf_available():
    pass

if is_torchao_available():
    if is_torchao_version(">=", "0.9.0"):
        pass


class LoRALayer(torch.nn.Module):
    """Wraps a linear layer with a LoRA-like adapter. Used for testing purposes only.

    Taken from
    https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
    """

    def __init__(self, module: torch.nn.Module, rank: int):
        super().__init__()
        self.module = module
        self.adapter = torch.nn.Sequential(
            torch.nn.Linear(module.in_features, rank, bias=False),
            torch.nn.Linear(rank, module.out_features, bias=False),
        )
        small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
        torch.nn.init.normal_(self.adapter[0].weight, std=small_std)
        torch.nn.init.zeros_(self.adapter[1].weight)
        self.adapter.to(module.weight.device)

    def forward(self, input, *args, **kwargs):
        return self.module(input, *args, **kwargs) + self.adapter(input)


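# Example (a hedged sketch, not part of the test API): LoRALayer is meant to wrap an existing
# linear submodule so the training tests below can check adapter gradients, e.g.
#
#     attention.to_q = LoRALayer(attention.to_q, rank=4)
#
# where ``attention`` is any module exposing a ``to_q`` linear layer.

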
@require_accelerator
class QuantizationTesterMixin:
    """
    Base mixin class providing common test implementations for quantization testing.

    Backend-specific mixins should:
    1. Implement _create_quantized_model(config_kwargs)
    2. Implement _verify_if_layer_quantized(name, module, config_kwargs)
    3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.)
    4. Use @pytest.mark.parametrize to create tests that call the common test methods below

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})

    Expected methods in test classes:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _create_quantized_model(self, config_kwargs, **extra_kwargs):
        """
        Create a quantized model with the given config kwargs.

        Args:
            config_kwargs: Quantization config parameters
            **extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder)
        """
        raise NotImplementedError("Subclass must implement _create_quantized_model")

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")

    def _is_module_quantized(self, module):
        """
        Check if a module is quantized. Returns True if quantized, False otherwise.
        Default implementation tries _verify_if_layer_quantized and catches exceptions.
        Subclasses can override for more efficient checking.
        """
        try:
            self._verify_if_layer_quantized("", module, {})
            return True
        except (AssertionError, AttributeError):
            return False

    def _load_unquantized_model(self):
        kwargs = getattr(self, "pretrained_model_kwargs", {})
        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _test_quantization_num_parameters(self, config_kwargs):
        model = self._load_unquantized_model()
        num_params = model.num_parameters()

        model_quantized = self._create_quantized_model(config_kwargs)
        num_params_quantized = model_quantized.num_parameters()

        assert num_params == num_params_quantized, (
            f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
        )

    def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
        model = self._load_unquantized_model()
        mem = model.get_memory_footprint()

        model_quantized = self._create_quantized_model(config_kwargs)
        mem_quantized = model_quantized.get_memory_footprint()

        ratio = mem / mem_quantized
        assert ratio >= expected_memory_reduction, (
            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
        )

    def _test_quantization_inference(self, config_kwargs):
        model_quantized = self._create_quantized_model(config_kwargs)

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model_quantized(**inputs, return_dict=False)[0]

        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"

    def _test_quantization_dtype_assignment(self, config_kwargs):
        model = self._create_quantized_model(config_kwargs)

        with pytest.raises(ValueError):
            model.to(torch.float16)

        with pytest.raises(ValueError):
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with pytest.raises(ValueError):
            model.float()

        with pytest.raises(ValueError):
            model.half()

        model.to(torch_device)

    def _test_quantization_lora_inference(self, config_kwargs):
        try:
            from peft import LoraConfig
        except ImportError:
            pytest.skip("peft is not available")

        from diffusers.loaders.peft import PeftAdapterMixin

        if not issubclass(self.model_class, PeftAdapterMixin):
            pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})")

        model = self._create_quantized_model(config_kwargs)

        lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
        )
        model.add_adapter(lora_config)

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]

        assert output is not None, "Model output is None with LoRA"
        assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"

    def _test_quantization_serialization(self, config_kwargs, tmp_path):
        model = self._create_quantized_model(config_kwargs)

        model.save_pretrained(str(tmp_path), safe_serialization=True)

        model_loaded = self.model_class.from_pretrained(str(tmp_path))

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model_loaded(**inputs, return_dict=False)[0]
        assert not torch.isnan(output).any(), "Loaded model output contains NaN"

    def _test_quantized_layers(self, config_kwargs):
        model_fp = self._load_unquantized_model()
        num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear))

        model_quantized = self._create_quantized_model(config_kwargs)

        num_fp32_modules = 0
        if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
            for name, module in model_quantized.named_modules():
                if isinstance(module, torch.nn.Linear):
                    if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
                        num_fp32_modules += 1

        expected_quantized_layers = num_linear_layers - num_fp32_modules

        num_quantized_layers = 0
        for name, module in model_quantized.named_modules():
            if isinstance(module, torch.nn.Linear):
                if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules:
                    if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules):
                        continue
                self._verify_if_layer_quantized(name, module, config_kwargs)
                num_quantized_layers += 1

        assert num_quantized_layers > 0, (
            f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
        )
        assert num_quantized_layers == expected_quantized_layers, (
            f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
        )

    def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
        """
        Test that modules specified in modules_to_not_convert are not quantized.

        Args:
            config_kwargs: Base quantization config kwargs
            modules_to_not_convert: List of module names to exclude from quantization
        """
        # Create config with modules_to_not_convert
        config_kwargs_with_exclusion = config_kwargs.copy()
        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert

        model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)

        # Find a module that should NOT be quantized
        found_excluded = False
        for name, module in model_with_exclusion.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Check if this module is in the exclusion list
                if any(excluded in name for excluded in modules_to_not_convert):
                    found_excluded = True
                    # This module should NOT be quantized
                    assert not self._is_module_quantized(module), (
                        f"Module {name} should not be quantized but was found to be quantized"
                    )

        assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"

        # Find a module that SHOULD be quantized (not in exclusion list)
        found_quantized = False
        for name, module in model_with_exclusion.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Check if this module is NOT in the exclusion list
                if not any(excluded in name for excluded in modules_to_not_convert):
                    if self._is_module_quantized(module):
                        found_quantized = True
                        break

        assert found_quantized, "No quantized layers found outside of excluded modules"

        # Compare memory footprint with fully quantized model
        model_fully_quantized = self._create_quantized_model(config_kwargs)

        mem_with_exclusion = model_with_exclusion.get_memory_footprint()
        mem_fully_quantized = model_fully_quantized.get_memory_footprint()

        assert mem_with_exclusion > mem_fully_quantized, (
            f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
        )

    def _test_quantization_device_map(self, config_kwargs):
        """
        Test that quantized models work correctly with device_map="auto".

        Args:
            config_kwargs: Base quantization config kwargs
        """
        model = self._create_quantized_model(config_kwargs, device_map="auto")

        # Verify device map is set
        assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
        assert model.hf_device_map is not None, "hf_device_map should not be None"

        # Verify inference works
        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"

    def _test_dequantize(self, config_kwargs):
        """
        Test that dequantize() converts quantized model back to standard linear layers.

        Args:
            config_kwargs: Quantization config parameters
        """
        model = self._create_quantized_model(config_kwargs)

        # Verify model has dequantize method
        if not hasattr(model, "dequantize"):
            pytest.skip("Model does not have dequantize method")

        # Dequantize the model
        model.dequantize()

        # Verify no modules are quantized after dequantization
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

        # Verify inference still works after dequantization
        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None after dequantization"
        assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"

    def _test_quantization_training(self, config_kwargs):
        """
        Test that quantized models can be used for training with LoRA-like adapters.

        This test:
        1. Freezes all model parameters
        2. Casts small parameters (e.g., layernorm) to fp32 for stability
        3. Adds LoRA adapters to attention layers
        4. Runs forward and backward passes
        5. Verifies gradients are computed correctly

        Args:
            config_kwargs: Quantization config parameters
        """
        model = self._create_quantized_model(config_kwargs)

        # Step 1: freeze all parameters
        for param in model.parameters():
            param.requires_grad = False
            if param.ndim == 1:
                # cast small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        # Step 2: add adapters to attention layers
        adapter_count = 0
        for _, module in model.named_modules():
            if "Attention" in repr(type(module)):
                if hasattr(module, "to_k"):
                    module.to_k = LoRALayer(module.to_k, rank=4)
                    adapter_count += 1
                if hasattr(module, "to_q"):
                    module.to_q = LoRALayer(module.to_q, rank=4)
                    adapter_count += 1
                if hasattr(module, "to_v"):
                    module.to_v = LoRALayer(module.to_v, rank=4)
                    adapter_count += 1

        if adapter_count == 0:
            pytest.skip("No attention layers found in model for adapter training test")

        # Step 3: run forward and backward pass
        inputs = self.get_dummy_inputs()

        with torch.amp.autocast(torch_device, dtype=torch.float16):
            out = model(**inputs, return_dict=False)[0]
            out.norm().backward()

        # Step 4: verify gradients are computed
        for module in model.modules():
            if isinstance(module, LoRALayer):
                assert module.adapter[1].weight.grad is not None, "LoRA adapter gradient is None"
                assert module.adapter[1].weight.grad.norm().item() > 0, "LoRA adapter gradient norm is zero"


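# A minimal sketch of the two hooks a backend-specific config mixin is expected to provide
# (the concrete mixins below are the real implementations). ``MyBackendConfig`` and
# ``MyQuantLinear`` are hypothetical placeholders, not real diffusers APIs:
#
#     class MyBackendConfigMixin:
#         MY_CONFIGS = {"default": {}}
#
#         def _create_quantized_model(self, config_kwargs, **extra_kwargs):
#             config = MyBackendConfig(**config_kwargs)
#             kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
#             kwargs["quantization_config"] = config
#             kwargs.update(extra_kwargs)
#             return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
#
#         def _verify_if_layer_quantized(self, name, module, config_kwargs):
#             assert isinstance(module, MyQuantLinear), f"Layer {name} is not quantized"

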
@is_quantization
@is_bitsandbytes
@require_accelerator
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
class BitsAndBytesConfigMixin:
    """
    Base mixin providing BitsAndBytes quantization config and model creation.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
    """

    BNB_CONFIGS = {
        "4bit_nf4": {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.float16,
        },
        "4bit_fp4": {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "fp4",
            "bnb_4bit_compute_dtype": torch.float16,
        },
        "8bit": {
            "load_in_8bit": True,
        },
    }

    BNB_EXPECTED_MEMORY_REDUCTIONS = {
        "4bit_nf4": 3.0,
        "4bit_fp4": 3.0,
        "8bit": 1.5,
    }

    def _create_quantized_model(self, config_kwargs, **extra_kwargs):
        config = BitsAndBytesConfig(**config_kwargs)
        kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
        kwargs["quantization_config"] = config
        kwargs.update(extra_kwargs)
        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
        assert module.weight.__class__ == expected_weight_class, (
            f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
        )


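# For reference, the "4bit_nf4" entry above is expanded by _create_quantized_model into
# (a sketch of the resulting call, using only kwargs that appear in BNB_CONFIGS):
#
#     BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)

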
@is_bitsandbytes
@require_accelerator
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
    """
    Mixin class for testing BitsAndBytes quantization on models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Optional class attributes:
    - BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test

    Pytest mark: bitsandbytes
    Use `pytest -m "not bitsandbytes"` to skip these tests
    """

    @pytest.mark.parametrize(
        "config_name",
        list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
        ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
    )
    def test_bnb_quantization_num_parameters(self, config_name):
        self._test_quantization_num_parameters(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    @pytest.mark.parametrize(
        "config_name",
        list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
        ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
    )
    def test_bnb_quantization_memory_footprint(self, config_name):
        expected = BitsAndBytesConfigMixin.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
        self._test_quantization_memory_footprint(
            BitsAndBytesConfigMixin.BNB_CONFIGS[config_name], expected_memory_reduction=expected
        )

    @pytest.mark.parametrize(
        "config_name",
        list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
        ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
    )
    def test_bnb_quantization_inference(self, config_name):
        self._test_quantization_inference(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
    def test_bnb_quantization_dtype_assignment(self, config_name):
        self._test_quantization_dtype_assignment(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
    def test_bnb_quantization_lora_inference(self, config_name):
        self._test_quantization_lora_inference(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
    def test_bnb_quantization_serialization(self, config_name, tmp_path):
        self._test_quantization_serialization(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name], tmp_path)

    @pytest.mark.parametrize(
        "config_name",
        list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
        ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
    )
    def test_bnb_quantized_layers(self, config_name):
        self._test_quantized_layers(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    @pytest.mark.parametrize(
        "config_name",
        list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
        ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
    )
    def test_bnb_quantization_config_serialization(self, config_name):
        model = self._create_quantized_model(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

        assert "quantization_config" in model.config, "Missing quantization_config"
        _ = model.config["quantization_config"].to_dict()
        _ = model.config["quantization_config"].to_diff_dict()
        _ = model.config["quantization_config"].to_json_string()

    def test_bnb_original_dtype(self):
        config_name = list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys())[0]
        config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]

        model = self._create_quantized_model(config_kwargs)

        assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype"
        assert model.config["_pre_quantization_dtype"] in [
            torch.float16,
            torch.float32,
            torch.bfloat16,
        ], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"

    def test_bnb_keep_modules_in_fp32(self):
        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
            pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")

        config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]

        original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
        self.model_class._keep_in_fp32_modules = ["proj_out"]

        try:
            model = self._create_quantized_model(config_kwargs)

            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Linear):
                    if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
                        assert module.weight.dtype == torch.float32, (
                            f"Module {name} should be FP32 but is {module.weight.dtype}"
                        )
                    else:
                        assert module.weight.dtype == torch.uint8, (
                            f"Module {name} should be uint8 but is {module.weight.dtype}"
                        )

            with torch.no_grad():
                inputs = self.get_dummy_inputs()
                _ = model(**inputs)
        finally:
            if original_fp32_modules is not None:
                self.model_class._keep_in_fp32_modules = original_fp32_modules

    def test_bnb_modules_to_not_convert(self):
        """Test that modules_to_not_convert parameter works correctly."""
        modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
        if modules_to_exclude is None:
            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(
            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
        )

    @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
    def test_bnb_device_map(self, config_name):
        """Test that device_map='auto' works correctly with quantization."""
        self._test_quantization_device_map(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    def test_bnb_dequantize(self):
        """Test that dequantize() works correctly."""
        self._test_dequantize(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"])

    def test_bnb_training(self):
        """Test that quantized models can be used for training with adapters."""
        self._test_quantization_training(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"])


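# A hedged sketch of how a model test class is expected to use this mixin; the model class,
# checkpoint id, and dummy inputs below are hypothetical placeholders, not real repositories:
#
#     class MyTransformerBnBTests(BitsAndBytesTesterMixin):
#         model_class = MyTransformer2DModel
#         pretrained_model_name_or_path = "org/my-tiny-checkpoint"
#         pretrained_model_kwargs = {"subfolder": "transformer"}
#
#         def get_dummy_inputs(self):
#             return {"hidden_states": torch.randn(1, 4, 32, 32, device=torch_device)}

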
@is_quantization
@is_quanto
@require_quanto
@require_accelerate
@require_accelerator
class QuantoConfigMixin:
    """
    Base mixin providing Quanto quantization config and model creation.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
    """

    QUANTO_WEIGHT_TYPES = {
        "float8": {"weights_dtype": "float8"},
        "int8": {"weights_dtype": "int8"},
        "int4": {"weights_dtype": "int4"},
        "int2": {"weights_dtype": "int2"},
    }

    QUANTO_EXPECTED_MEMORY_REDUCTIONS = {
        "float8": 1.5,
        "int8": 1.5,
        "int4": 3.0,
        "int2": 7.0,
    }

    def _create_quantized_model(self, config_kwargs, **extra_kwargs):
        config = QuantoConfig(**config_kwargs)
        kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
        kwargs["quantization_config"] = config
        kwargs.update(extra_kwargs)
        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}"


@is_quanto
@require_quanto
@require_accelerate
@require_accelerator
class QuantoTesterMixin(QuantoConfigMixin, QuantizationTesterMixin):
    """
    Mixin class for testing Quanto quantization on models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Optional class attributes:
    - QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype

    Pytest mark: quanto
    Use `pytest -m "not quanto"` to skip these tests
    """

    @pytest.mark.parametrize(
        "weight_type_name",
        list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()),
        ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()),
    )
    def test_quanto_quantization_num_parameters(self, weight_type_name):
        self._test_quantization_num_parameters(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name])

    @pytest.mark.parametrize(
        "weight_type_name",
        list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()),
        ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()),
    )
    def test_quanto_quantization_memory_footprint(self, weight_type_name):
        expected = QuantoConfigMixin.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2)
        self._test_quantization_memory_footprint(
            QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected
        )

    @pytest.mark.parametrize(
        "weight_type_name",
        list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()),
        ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()),
    )
    def test_quanto_quantization_inference(self, weight_type_name):
        self._test_quantization_inference(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name])

    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
    def test_quanto_quantized_layers(self, weight_type_name):
        self._test_quantized_layers(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name])

    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
    def test_quanto_quantization_lora_inference(self, weight_type_name):
        self._test_quantization_lora_inference(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name])

    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
    def test_quanto_quantization_serialization(self, weight_type_name, tmp_path):
        self._test_quantization_serialization(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name], tmp_path)

    def test_quanto_modules_to_not_convert(self):
        """Test that modules_to_not_convert parameter works correctly."""
        modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
        if modules_to_exclude is None:
            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(
            QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude
        )

    def test_quanto_device_map(self):
        """Test that device_map='auto' works correctly with quantization."""
        self._test_quantization_device_map(QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"])

    def test_quanto_dequantize(self):
        """Test that dequantize() works correctly."""
        self._test_dequantize(QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"])


@is_quantization
@is_torchao
@require_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigMixin:
    """
    Base mixin providing TorchAO quantization config and model creation.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
    """

    TORCHAO_QUANT_TYPES = {
        "int4wo": {"quant_type": "int4_weight_only"},
        "int8wo": {"quant_type": "int8_weight_only"},
        "int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"},
    }

    TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
        "int4wo": 3.0,
        "int8wo": 1.5,
        "int8dq": 1.5,
    }

    def _create_quantized_model(self, config_kwargs, **extra_kwargs):
        config = TorchAoConfig(**config_kwargs)
        kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
        kwargs["quantization_config"] = config
        kwargs.update(extra_kwargs)
        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        # TorchAO quantizes weights in place, so quantized layers remain torch.nn.Linear instances.
        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"


@is_torchao
@require_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
    """
    Mixin class for testing TorchAO quantization on models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Optional class attributes:
    - TORCHAO_QUANT_TYPES: Dict of quantization type strings to test

    Pytest mark: torchao
    Use `pytest -m "not torchao"` to skip these tests
    """

    @pytest.mark.parametrize(
        "quant_type",
        list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()),
        ids=list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()),
    )
    def test_torchao_quantization_num_parameters(self, quant_type):
        self._test_quantization_num_parameters(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

    @pytest.mark.parametrize(
        "quant_type",
        list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()),
        ids=list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()),
    )
    def test_torchao_quantization_memory_footprint(self, quant_type):
        expected = TorchAoConfigMixin.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2)
        self._test_quantization_memory_footprint(
            TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected
        )

    @pytest.mark.parametrize(
        "quant_type",
        list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()),
        ids=list(TorchAoConfigMixin.TORCHAO_QUANT_TYPES.keys()),
    )
    def test_torchao_quantization_inference(self, quant_type):
        self._test_quantization_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
    def test_torchao_quantized_layers(self, quant_type):
        self._test_quantized_layers(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
    def test_torchao_quantization_lora_inference(self, quant_type):
        self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
    def test_torchao_quantization_serialization(self, quant_type, tmp_path):
        self._test_quantization_serialization(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], tmp_path)

    def test_torchao_modules_to_not_convert(self):
        """Test that modules_to_not_convert parameter works correctly."""
        modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
        if modules_to_exclude is None:
            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(
            TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
        )

    def test_torchao_device_map(self):
        """Test that device_map='auto' works correctly with quantization."""
        self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])

    def test_torchao_dequantize(self):
        """Test that dequantize() works correctly."""
        self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])

    def test_torchao_training(self):
        """Test that quantized models can be used for training with adapters."""
        self._test_quantization_training(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])


@is_quantization
@is_gguf
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFConfigMixin:
    """
    Base mixin providing GGUF quantization config and model creation.

    Expected class attributes:
    - model_class: The model class to test
    - gguf_filename: URL or path to the GGUF file
    """

    gguf_filename = None

    def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
        if config_kwargs is None:
            config_kwargs = {"compute_dtype": torch.bfloat16}

        config = GGUFQuantizationConfig(**config_kwargs)
        kwargs = {
            "quantization_config": config,
            "torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16),
        }
        kwargs.update(extra_kwargs)
        return self.model_class.from_single_file(self.gguf_filename, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs=None):
        from diffusers.quantizers.gguf.utils import GGUFParameter

        assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter"
        assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type"
        assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8"


@is_gguf
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin):
    """
    Mixin class for testing GGUF quantization on models.

    Expected class attributes:
    - model_class: The model class to test
    - gguf_filename: URL or path to the GGUF file

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: gguf
    Use `pytest -m "not gguf"` to skip these tests
    """

    def test_gguf_quantization_inference(self):
        self._test_quantization_inference({"compute_dtype": torch.bfloat16})

    def test_gguf_keep_modules_in_fp32(self):
        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
            pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")

        _keep_in_fp32_modules = self.model_class._keep_in_fp32_modules
        self.model_class._keep_in_fp32_modules = ["proj_out"]

        try:
            model = self._create_quantized_model()

            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Linear):
                    if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
                        assert module.weight.dtype == torch.float32, f"Module {name} should be FP32"
        finally:
            self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules

    def test_gguf_quantization_dtype_assignment(self):
        self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16})

    def test_gguf_quantization_lora_inference(self):
        self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16})

    def test_gguf_dequantize(self):
        """Test that dequantize() works correctly."""
        self._test_dequantize({"compute_dtype": torch.bfloat16})

    def test_gguf_quantized_layers(self):
        self._test_quantized_layers({"compute_dtype": torch.bfloat16})


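# A hedged sketch of a GGUF test class; unlike the other backends it loads from a single GGUF
# file rather than a Hub folder. The model class and file URL below are hypothetical placeholders:
#
#     class MyTransformerGGUFTests(GGUFTesterMixin):
#         model_class = MyTransformer2DModel
#         gguf_filename = "https://huggingface.co/org/repo/blob/main/model-Q4_0.gguf"
#
#         def get_dummy_inputs(self):
#             return {"hidden_states": torch.randn(1, 4, 32, 32, device=torch_device)}

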
@is_quantization
@is_modelopt
@require_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptConfigMixin:
    """
    Base mixin providing NVIDIA ModelOpt quantization config and model creation.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained
    """

    MODELOPT_CONFIGS = {
        "fp8": {"quant_type": "FP8"},
        "int8": {"quant_type": "INT8"},
        "int4": {"quant_type": "INT4"},
    }

    MODELOPT_EXPECTED_MEMORY_REDUCTIONS = {
        "fp8": 1.5,
        "int8": 1.5,
        "int4": 3.0,
    }

    def _create_quantized_model(self, config_kwargs, **extra_kwargs):
        config = NVIDIAModelOptConfig(**config_kwargs)
        kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
        kwargs["quantization_config"] = config
        kwargs.update(extra_kwargs)
        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)"


@is_modelopt
@require_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptTesterMixin(ModelOptConfigMixin, QuantizationTesterMixin):
    """
    Mixin class for testing NVIDIA ModelOpt quantization on models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"})

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Optional class attributes:
    - MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test

    Pytest mark: modelopt
    Use `pytest -m "not modelopt"` to skip these tests
    """

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_quantization_num_parameters(self, config_name):
        self._test_quantization_num_parameters(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize(
        "config_name",
        list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()),
        ids=list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()),
    )
    def test_modelopt_quantization_memory_footprint(self, config_name):
        expected = ModelOptConfigMixin.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2)
        self._test_quantization_memory_footprint(
            ModelOptConfigMixin.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected
        )

    @pytest.mark.parametrize(
        "config_name",
        list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()),
        ids=list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()),
    )
    def test_modelopt_quantization_inference(self, config_name):
        self._test_quantization_inference(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_quantization_dtype_assignment(self, config_name):
        self._test_quantization_dtype_assignment(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_quantization_lora_inference(self, config_name):
        self._test_quantization_lora_inference(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_quantization_serialization(self, config_name, tmp_path):
        self._test_quantization_serialization(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name], tmp_path)

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_quantized_layers(self, config_name):
        self._test_quantized_layers(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])

    def test_modelopt_modules_to_not_convert(self):
        """Test that modules_to_not_convert parameter works correctly."""
        modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None)
        if modules_to_exclude is None:
            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"], modules_to_exclude)

    def test_modelopt_device_map(self):
        """Test that device_map='auto' works correctly with quantization."""
        self._test_quantization_device_map(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"])

    def test_modelopt_dequantize(self):
        """Test that dequantize() works correctly."""
        self._test_dequantize(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"])


@is_torch_compile
class QuantizationCompileTesterMixin:
    """
    Base mixin class providing common test implementations for torch.compile with quantized models.

    Backend-specific compile mixins should:
    1. Inherit from their respective config mixin (e.g., BitsAndBytesConfigMixin)
    2. Inherit from this mixin
    3. Define the config to use for compile tests

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained

    Expected methods in test classes:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def _test_torch_compile(self, config_kwargs):
        """
        Test that torch.compile works correctly with a quantized model.

        Args:
            config_kwargs: Quantization config parameters
        """
        model = self._create_quantized_model(config_kwargs)
        model.to(torch_device)
        model.eval()

        # Compile the model with fullgraph=True to ensure no graph breaks
        model = torch.compile(model, fullgraph=True)

        # Run inference with error_on_recompile to detect recompilation issues
        with torch.no_grad(), torch._dynamo.config.patch(error_on_recompile=True):
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"

    def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False):
        """
        Test that torch.compile works correctly with a quantized model and group offloading.

        Args:
            config_kwargs: Quantization config parameters
            use_stream: Whether to use CUDA streams for offloading
        """
        torch._dynamo.config.cache_size_limit = 1000

        model = self._create_quantized_model(config_kwargs)
        model.eval()

        if not hasattr(model, "enable_group_offload"):
            pytest.skip("Model does not support group offloading")

        group_offload_kwargs = {
            "onload_device": torch.device(torch_device),
            "offload_device": torch.device("cpu"),
            "offload_type": "leaf_level",
            "use_stream": use_stream,
        }
        model.enable_group_offload(**group_offload_kwargs)
        model = torch.compile(model)

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"


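# Following the three steps listed in the docstring above, a backend compile mixin is a thin
# combination of its config mixin and QuantizationCompileTesterMixin (the concrete mixins below
# are the real implementations; ``MyBackendConfigMixin`` is the hypothetical sketch from earlier):
#
#     class MyBackendCompileTesterMixin(MyBackendConfigMixin, QuantizationCompileTesterMixin):
#         def test_mybackend_torch_compile(self):
#             self._test_torch_compile(MyBackendConfigMixin.MY_CONFIGS["default"])

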
@is_bitsandbytes
@require_accelerator
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
class BitsAndBytesCompileTesterMixin(BitsAndBytesConfigMixin, QuantizationCompileTesterMixin):
    """
    Mixin class for testing torch.compile with BitsAndBytes quantized models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: bitsandbytes
    Use `pytest -m "not bitsandbytes"` to skip these tests
    """

    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
    def test_bnb_torch_compile(self, config_name):
        self._test_torch_compile(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"])
    def test_bnb_torch_compile_with_group_offload(self, config_name):
        self._test_torch_compile_with_group_offload(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name])


@is_quanto
@require_quanto
@require_accelerate
@require_accelerator
class QuantoCompileTesterMixin(QuantoConfigMixin, QuantizationCompileTesterMixin):
    """
    Mixin class for testing torch.compile with Quanto quantized models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: quanto
    Use `pytest -m "not quanto"` to skip these tests
    """

    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
    def test_quanto_torch_compile(self, weight_type_name):
        self._test_torch_compile(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name])

    @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"])
    def test_quanto_torch_compile_with_group_offload(self, weight_type_name):
        self._test_torch_compile_with_group_offload(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name])


@is_torchao
@require_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTesterMixin(TorchAoConfigMixin, QuantizationCompileTesterMixin):
    """
    Mixin class for testing torch.compile with TorchAO quantized models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: torchao
    Use `pytest -m "not torchao"` to skip these tests
    """

    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
    def test_torchao_torch_compile(self, quant_type):
        self._test_torch_compile(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

    @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
    def test_torchao_torch_compile_with_group_offload(self, quant_type):
        self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])


@is_gguf
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFCompileTesterMixin(GGUFConfigMixin, QuantizationCompileTesterMixin):
    """
    Mixin class for testing torch.compile with GGUF quantized models.

    Expected class attributes:
    - model_class: The model class to test
    - gguf_filename: URL or path to the GGUF file

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: gguf
    Use `pytest -m "not gguf"` to skip these tests
    """

    def test_gguf_torch_compile(self):
        self._test_torch_compile({"compute_dtype": torch.bfloat16})

    def test_gguf_torch_compile_with_group_offload(self):
        self._test_torch_compile_with_group_offload({"compute_dtype": torch.bfloat16})


@is_modelopt
@require_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptCompileTesterMixin(ModelOptConfigMixin, QuantizationCompileTesterMixin):
    """
    Mixin class for testing torch.compile with NVIDIA ModelOpt quantized models.

    Expected class attributes:
    - model_class: The model class to test
    - pretrained_model_name_or_path: Hub repository ID for the pretrained model
    - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained

    Expected methods to be implemented by subclasses:
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: modelopt
    Use `pytest -m "not modelopt"` to skip these tests
    """

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_torch_compile(self, config_name):
        self._test_torch_compile(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])

    @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"])
    def test_modelopt_torch_compile_with_group_offload(self, config_name):
        self._test_torch_compile_with_group_offload(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name])
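

# A hedged end-of-file sketch of a compile test class for a single model, complementing the
# quantization test sketch shown after BitsAndBytesTesterMixin; every name below (model class,
# checkpoint id, inputs) is a hypothetical placeholder:
#
#     class MyTransformerBnBCompileTests(BitsAndBytesCompileTesterMixin):
#         model_class = MyTransformer2DModel
#         pretrained_model_name_or_path = "org/my-tiny-checkpoint"
#         pretrained_model_kwargs = {"subfolder": "transformer"}
#
#         def get_dummy_inputs(self):
#             return {"hidden_states": torch.randn(1, 4, 32, 32, device=torch_device)}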