Compare commits

..

6 Commits

Author SHA1 Message Date
Sayak Paul
d79b88ae8d Merge branch 'main' into fix-modules-no-convert-torchao 2026-03-04 16:34:08 +05:30
Sayak Paul
faed0087d3 Merge branch 'main' into fix-modules-no-convert-torchao 2026-02-13 19:53:58 +05:30
Sayak Paul
ed734a0e63 Merge branch 'main' into fix-modules-no-convert-torchao 2026-02-10 15:49:51 +05:30
sayakpaul
d676b03490 fix torchao/. 2026-02-10 15:32:41 +05:30
sayakpaul
e117274aa5 fix bnb modules_to_convert. 2026-02-10 13:49:05 +05:30
sayakpaul
a1804cfa80 make modules_to_not_convert actually run. 2026-02-05 09:47:15 +05:30
3 changed files with 81 additions and 98 deletions

View File

@@ -2519,13 +2519,6 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
# Normalize ZImage-specific dot-separated module names to underscore form so they
# match the diffusers model parameter names (context_refiner, noise_refiner).
state_dict = {
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
for k, v in state_dict.items()
}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
@@ -2536,18 +2529,19 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
if has_non_diffusers_lora_id:
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
@@ -2560,69 +2554,13 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
# Already in diffusers format (lora_A/lora_B), just pop
elif has_diffusers_lora_id:
for k in all_keys:
if k.endswith(a_key):
diffusers_up_key = k.replace(a_key, b_key)
alpha_key = k.replace(a_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(diffusers_up_key)
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[k] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
# Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
# lora weight names and also include redundant keys:
# - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
# - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
# - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
lora_dot_down_key = ".lora.down.weight"
lora_dot_up_key = ".lora.up.weight"
has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)
if has_lora_dot_format:
dot_keys = list(state_dict.keys())
for k in dot_keys:
if lora_dot_down_key not in k:
continue
if k not in state_dict:
continue # already popped by a prior iteration
base = k[: -len(lora_dot_down_key)]
# Skip combined "qkv" projection — individual to.q/k/v keys are also present.
if base.endswith(".qkv"):
if a_key in k or b_key in k:
converted_state_dict[k] = state_dict.pop(k)
elif ".alpha" in k:
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
state_dict.pop(base + ".alpha", None)
continue
# Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
if re.search(r"\.out$", base) and ".to_out" not in base:
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
continue
# Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
norm_k = re.sub(
r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
r".to_\1" + lora_dot_down_key,
k,
)
norm_base = norm_k[: -len(lora_dot_down_key)]
alpha_key = norm_base + ".alpha"
diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down] = down_weight * scale_down
converted_state_dict[diffusers_up] = up_weight * scale_up
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

View File

@@ -21,11 +21,8 @@ import torch
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -59,13 +56,6 @@ if is_bitsandbytes_available():
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
class LoRALayer(torch.nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
@@ -132,14 +122,14 @@ class QuantizationTesterMixin:
def _verify_if_layer_quantized(self, name, module, config_kwargs):
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
def _is_module_quantized(self, module):
def _is_module_quantized(self, module, quant_config_kwargs=None):
"""
Check if a module is quantized. Returns True if quantized, False otherwise.
Default implementation tries _verify_if_layer_quantized and catches exceptions.
Subclasses can override for more efficient checking.
"""
try:
self._verify_if_layer_quantized("", module, {})
self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
return True
except (AssertionError, AttributeError):
return False
@@ -273,7 +263,9 @@ class QuantizationTesterMixin:
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
)
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
def _test_quantization_modules_to_not_convert(
self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
):
"""
Test that modules specified in modules_to_not_convert are not quantized.
@@ -283,7 +275,7 @@ class QuantizationTesterMixin:
"""
# Create config with modules_to_not_convert
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
@@ -295,7 +287,7 @@ class QuantizationTesterMixin:
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module), (
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
f"Module {name} should not be quantized but was found to be quantized"
)
@@ -307,7 +299,7 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_not_convert):
if self._is_module_quantized(module):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
found_quantized = True
break
@@ -612,7 +604,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
)
@pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
@@ -826,7 +818,14 @@ class TorchAoConfigMixin:
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
# Check if the weight is actually quantized
weight = module.weight
is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
@@ -922,9 +921,39 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
)
# Custom implementation for torchao that skips memory footprint check
# because get_memory_footprint() doesn't accurately reflect torchao quantization
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
# Find a module that should NOT be quantized
found_excluded = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is in the exclusion list
if any(excluded in name for excluded in modules_to_exclude):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
f"Module {name} should not be quantized but was found to be quantized"
)
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
# Find a module that SHOULD be quantized (not in exclusion list)
found_quantized = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_exclude):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
found_quantized = True
break
assert found_quantized, "No quantized layers found outside of excluded modules"
def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""

View File

@@ -318,6 +318,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""
@@ -330,10 +334,18 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def pretrained_model_kwargs(self):
return {}
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
@@ -402,6 +414,10 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTes
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
"""ModelOpt quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
"""ModelOpt + compile tests for Flux Transformer."""