mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-12 05:45:23 +08:00
Compare commits
4 Commits
transforme
...
fix-module
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ed734a0e63 | ||
|
|
d676b03490 | ||
|
|
e117274aa5 | ||
|
|
a1804cfa80 |
@@ -21,11 +21,8 @@ import torch
|
||||
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
|
||||
from diffusers.utils.import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_gguf_available,
|
||||
is_nvidia_modelopt_available,
|
||||
is_optimum_quanto_available,
|
||||
is_torchao_available,
|
||||
is_torchao_version,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -59,13 +56,6 @@ if is_bitsandbytes_available():
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QLinear
|
||||
|
||||
if is_gguf_available():
|
||||
pass
|
||||
|
||||
if is_torchao_available():
|
||||
if is_torchao_version(">=", "0.9.0"):
|
||||
pass
|
||||
|
||||
|
||||
class LoRALayer(torch.nn.Module):
|
||||
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
|
||||
@@ -132,14 +122,14 @@ class QuantizationTesterMixin:
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
|
||||
|
||||
def _is_module_quantized(self, module):
|
||||
def _is_module_quantized(self, module, quant_config_kwargs=None):
|
||||
"""
|
||||
Check if a module is quantized. Returns True if quantized, False otherwise.
|
||||
Default implementation tries _verify_if_layer_quantized and catches exceptions.
|
||||
Subclasses can override for more efficient checking.
|
||||
"""
|
||||
try:
|
||||
self._verify_if_layer_quantized("", module, {})
|
||||
self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
|
||||
return True
|
||||
except (AssertionError, AttributeError):
|
||||
return False
|
||||
@@ -279,7 +269,9 @@ class QuantizationTesterMixin:
|
||||
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
|
||||
)
|
||||
|
||||
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
|
||||
def _test_quantization_modules_to_not_convert(
|
||||
self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
|
||||
):
|
||||
"""
|
||||
Test that modules specified in modules_to_not_convert are not quantized.
|
||||
|
||||
@@ -289,7 +281,7 @@ class QuantizationTesterMixin:
|
||||
"""
|
||||
# Create config with modules_to_not_convert
|
||||
config_kwargs_with_exclusion = config_kwargs.copy()
|
||||
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
|
||||
config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert
|
||||
|
||||
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
|
||||
|
||||
@@ -301,7 +293,7 @@ class QuantizationTesterMixin:
|
||||
if any(excluded in name for excluded in modules_to_not_convert):
|
||||
found_excluded = True
|
||||
# This module should NOT be quantized
|
||||
assert not self._is_module_quantized(module), (
|
||||
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
|
||||
f"Module {name} should not be quantized but was found to be quantized"
|
||||
)
|
||||
|
||||
@@ -313,7 +305,7 @@ class QuantizationTesterMixin:
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is NOT in the exclusion list
|
||||
if not any(excluded in name for excluded in modules_to_not_convert):
|
||||
if self._is_module_quantized(module):
|
||||
if self._is_module_quantized(module, config_kwargs_with_exclusion):
|
||||
found_quantized = True
|
||||
break
|
||||
|
||||
@@ -618,7 +610,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(
|
||||
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
|
||||
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
|
||||
@@ -817,7 +809,14 @@ class TorchAoConfigMixin:
|
||||
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
def _verify_if_layer_quantized(self, name, module, config_kwargs):
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
|
||||
|
||||
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
|
||||
# Check if the weight is actually quantized
|
||||
weight = module.weight
|
||||
is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
|
||||
assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
|
||||
|
||||
|
||||
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
|
||||
@@ -913,9 +912,39 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
|
||||
if modules_to_exclude is None:
|
||||
pytest.skip("modules_to_not_convert_for_test not defined for this model")
|
||||
|
||||
self._test_quantization_modules_to_not_convert(
|
||||
TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
|
||||
)
|
||||
# Custom implementation for torchao that skips memory footprint check
|
||||
# because get_memory_footprint() doesn't accurately reflect torchao quantization
|
||||
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
|
||||
config_kwargs_with_exclusion = config_kwargs.copy()
|
||||
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
|
||||
|
||||
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
|
||||
|
||||
# Find a module that should NOT be quantized
|
||||
found_excluded = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is in the exclusion list
|
||||
if any(excluded in name for excluded in modules_to_exclude):
|
||||
found_excluded = True
|
||||
# This module should NOT be quantized
|
||||
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
|
||||
f"Module {name} should not be quantized but was found to be quantized"
|
||||
)
|
||||
|
||||
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
|
||||
|
||||
# Find a module that SHOULD be quantized (not in exclusion list)
|
||||
found_quantized = False
|
||||
for name, module in model_with_exclusion.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Check if this module is NOT in the exclusion list
|
||||
if not any(excluded in name for excluded in modules_to_exclude):
|
||||
if self._is_module_quantized(module, config_kwargs_with_exclusion):
|
||||
found_quantized = True
|
||||
break
|
||||
|
||||
assert found_quantized, "No quantized layers found outside of excluded modules"
|
||||
|
||||
def test_torchao_device_map(self):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
|
||||
@@ -318,6 +318,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def modules_to_not_convert_for_test(self):
|
||||
return ["norm_out.linear"]
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
"""Quanto quantization tests for Flux Transformer."""
|
||||
@@ -330,10 +334,18 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
def pretrained_model_kwargs(self):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def modules_to_not_convert_for_test(self):
|
||||
return ["norm_out.linear"]
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def modules_to_not_convert_for_test(self):
|
||||
return ["norm_out.linear"]
|
||||
|
||||
|
||||
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
|
||||
@property
|
||||
@@ -402,6 +414,10 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTes
|
||||
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
|
||||
"""ModelOpt quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def modules_to_not_convert_for_test(self):
|
||||
return ["norm_out.linear"]
|
||||
|
||||
|
||||
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
|
||||
"""ModelOpt + compile tests for Flux Transformer."""
|
||||
|
||||
Reference in New Issue
Block a user