Mirror of https://github.com/huggingface/diffusers.git

Compare commits: fix-zimage...fix-module (6 commits)
Commits in this comparison:

- d79b88ae8d
- faed0087d3
- ed734a0e63
- d676b03490
- e117274aa5
- a1804cfa80
@@ -2519,13 +2519,6 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):

    if has_default:
        state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}

    # Normalize ZImage-specific dot-separated module names to underscore form so they
    # match the diffusers model parameter names (context_refiner, noise_refiner).
    state_dict = {
        k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
        for k, v in state_dict.items()
    }

    converted_state_dict = {}
    all_keys = list(state_dict.keys())
    down_key = ".lora_down.weight"
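The normalization block above is a plain key-rename pass over the checkpoint dictionary. A minimal sketch of the same pattern with made-up keys and placeholder tensors (the module names below are illustrative, not taken from a real ZImage checkpoint):

```python
import torch

state_dict = {
    "context.refiner.0.attn.to_q.lora_down.weight": torch.zeros(4, 8),
    "noise.refiner.1.attn.to_k.lora_up.weight": torch.zeros(8, 4),
}

# Rename the dot-separated ZImage prefixes to the underscore form used by the diffusers model.
state_dict = {
    k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
    for k, v in state_dict.items()
}

print(sorted(state_dict))
# ['context_refiner.0.attn.to_q.lora_down.weight', 'noise_refiner.1.attn.to_k.lora_up.weight']
```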
@@ -2536,18 +2529,19 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):

    has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
    has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)

    def get_alpha_scales(down_weight, alpha_key):
        rank = down_weight.shape[0]
        alpha = state_dict.pop(alpha_key).item()
        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    if has_non_diffusers_lora_id:

        def get_alpha_scales(down_weight, alpha_key):
            rank = down_weight.shape[0]
            alpha = state_dict.pop(alpha_key).item()
            scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2
            return scale_down, scale_up

        for k in all_keys:
            if k.endswith(down_key):
                diffusers_down_key = k.replace(down_key, ".lora_A.weight")
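The `get_alpha_scales` helper folds the LoRA `alpha / rank` factor into the stored weights, splitting it between the down (lora_A) and up (lora_B) matrices so neither side absorbs the whole scale. A standalone numeric sketch of the same logic; here alpha is passed in directly instead of being popped from the state dict, purely for illustration:

```python
import torch


def get_alpha_scales(down_weight, alpha):
    rank = down_weight.shape[0]
    scale = alpha / rank  # LoRA forward pass multiplies by alpha / rank, so bake it into the weights
    scale_down, scale_up = scale, 1.0
    # Rebalance the factor so down and up each carry roughly half of it (in powers of two).
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up


down = torch.zeros(16, 64)                # rank-16 down projection
print(get_alpha_scales(down, alpha=4.0))  # (0.5, 0.5); the product equals alpha / rank = 0.25
```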
@@ -2560,69 +2554,13 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):

                converted_state_dict[diffusers_down_key] = down_weight * scale_down
                converted_state_dict[diffusers_up_key] = up_weight * scale_up

    # Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
    # Already in diffusers format (lora_A/lora_B), just pop
    elif has_diffusers_lora_id:
        for k in all_keys:
            if k.endswith(a_key):
                diffusers_up_key = k.replace(a_key, b_key)
                alpha_key = k.replace(a_key, ".alpha")

                down_weight = state_dict.pop(k)
                up_weight = state_dict.pop(diffusers_up_key)
                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
                converted_state_dict[k] = down_weight * scale_down
                converted_state_dict[diffusers_up_key] = up_weight * scale_up

    # Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
    # Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
    # lora weight names and also include redundant keys:
    # - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
    # - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
    # - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
    lora_dot_down_key = ".lora.down.weight"
    lora_dot_up_key = ".lora.up.weight"
    has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)

    if has_lora_dot_format:
        dot_keys = list(state_dict.keys())
        for k in dot_keys:
            if lora_dot_down_key not in k:
                continue
            if k not in state_dict:
                continue  # already popped by a prior iteration

            base = k[: -len(lora_dot_down_key)]

            # Skip combined "qkv" projection — individual to.q/k/v keys are also present.
            if base.endswith(".qkv"):
                if a_key in k or b_key in k:
                    converted_state_dict[k] = state_dict.pop(k)
                elif ".alpha" in k:
                    state_dict.pop(k)
                state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
                state_dict.pop(base + ".alpha", None)
                continue

            # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
            if re.search(r"\.out$", base) and ".to_out" not in base:
                state_dict.pop(k)
                state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
                continue

            # Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
            norm_k = re.sub(
                r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
                r".to_\1" + lora_dot_down_key,
                k,
            )
            norm_base = norm_k[: -len(lora_dot_down_key)]
            alpha_key = norm_base + ".alpha"

            diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
            diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")

            down_weight = state_dict.pop(k)
            up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
            scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
            converted_state_dict[diffusers_down] = down_weight * scale_down
            converted_state_dict[diffusers_up] = up_weight * scale_up

    if len(state_dict) > 0:
        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
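The dot-format branch above hinges on one regex that rewrites trainer-style ".to.q" / ".to.k" / ".to.v" suffixes into the "to_q" / "to_k" / "to_v" names diffusers expects, after which the ".lora.down.weight" suffix maps onto ".lora_A.weight". A small sketch with a hypothetical key showing both renames in isolation:

```python
import re

lora_dot_down_key = ".lora.down.weight"


def normalize_dot_key(k):
    # ".to.q" / ".to.k" / ".to.v" -> ".to_q" / ".to_k" / ".to_v", only when followed by the dot-format suffix.
    norm_k = re.sub(
        r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
        r".to_\1" + lora_dot_down_key,
        k,
    )
    # Dot-format LoRA suffix -> diffusers lora_A naming.
    return norm_k.replace(lora_dot_down_key, ".lora_A.weight")


print(normalize_dot_key("layers.0.attention.to.q.lora.down.weight"))
# layers.0.attention.to_q.lora_A.weight
```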
@@ -21,11 +21,8 @@ import torch

from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
    is_bitsandbytes_available,
    is_gguf_available,
    is_nvidia_modelopt_available,
    is_optimum_quanto_available,
    is_torchao_available,
    is_torchao_version,
)

from ...testing_utils import (
@@ -59,13 +56,6 @@ if is_bitsandbytes_available():

if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_gguf_available():
    pass

if is_torchao_available():
    if is_torchao_version(">=", "0.9.0"):
        pass


class LoRALayer(torch.nn.Module):
    """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
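Only the class header and the first docstring line of `LoRALayer` survive in this hunk. A minimal sketch of what a test-only wrapper of this shape typically looks like; the rank argument, the two-layer adapter, and the forward composition are assumptions for illustration, not the file's actual body:

```python
import torch


class LoRALayer(torch.nn.Module):
    """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only."""

    def __init__(self, module: torch.nn.Linear, rank: int = 4):
        super().__init__()
        self.module = module
        # Low-rank adapter: output = W @ x + B @ (A @ x)
        self.adapter = torch.nn.Sequential(
            torch.nn.Linear(module.in_features, rank, bias=False),
            torch.nn.Linear(rank, module.out_features, bias=False),
        )

    def forward(self, x):
        return self.module(x) + self.adapter(x)


layer = LoRALayer(torch.nn.Linear(32, 16), rank=4)
print(layer(torch.randn(2, 32)).shape)  # torch.Size([2, 16])
```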
@@ -132,14 +122,14 @@ class QuantizationTesterMixin:

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")

    def _is_module_quantized(self, module):
    def _is_module_quantized(self, module, quant_config_kwargs=None):
        """
        Check if a module is quantized. Returns True if quantized, False otherwise.
        Default implementation tries _verify_if_layer_quantized and catches exceptions.
        Subclasses can override for more efficient checking.
        """
        try:
            self._verify_if_layer_quantized("", module, {})
            self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
            return True
        except (AssertionError, AttributeError):
            return False
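`_is_module_quantized` is just a boolean wrapper around the subclass's assertion-based verifier: a raised `AssertionError` (or an `AttributeError` from probing a missing attribute) is interpreted as "not quantized". A generic sketch of the pattern; the `Int8Verifier` class and its int8-dtype check are illustrative stand-ins, not one of the mixin's real backends:

```python
import torch


class Int8Verifier:
    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        # Assertion-style check: raises if the layer does not look quantized.
        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear"
        assert module.weight.dtype == torch.int8, f"Layer {name} weight is not int8"

    def _is_module_quantized(self, module, quant_config_kwargs=None):
        # Boolean wrapper: any failed assertion or missing attribute means "not quantized".
        try:
            self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
            return True
        except (AssertionError, AttributeError):
            return False


print(Int8Verifier()._is_module_quantized(torch.nn.Linear(4, 4)))  # False: weight is float32
```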
@@ -273,7 +263,9 @@ class QuantizationTesterMixin:

            f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
        )

    def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
    def _test_quantization_modules_to_not_convert(
        self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
    ):
        """
        Test that modules specified in modules_to_not_convert are not quantized.
@@ -283,7 +275,7 @@ class QuantizationTesterMixin:

        """
        # Create config with modules_to_not_convert
        config_kwargs_with_exclusion = config_kwargs.copy()
        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
        config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert

        model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
@@ -295,7 +287,7 @@ class QuantizationTesterMixin:

                if any(excluded in name for excluded in modules_to_not_convert):
                    found_excluded = True
                    # This module should NOT be quantized
                    assert not self._is_module_quantized(module), (
                    assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
                        f"Module {name} should not be quantized but was found to be quantized"
                    )
@@ -307,7 +299,7 @@ class QuantizationTesterMixin:

            if isinstance(module, torch.nn.Linear):
                # Check if this module is NOT in the exclusion list
                if not any(excluded in name for excluded in modules_to_not_convert):
                    if self._is_module_quantized(module):
                    if self._is_module_quantized(module, config_kwargs_with_exclusion):
                        found_quantized = True
                        break
@@ -612,7 +604,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):

            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(
            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
            BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
        )

    @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
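The extra `"llm_int8_skip_modules"` argument reflects that bitsandbytes names its exclusion list differently from the other backends, which use `modules_to_not_convert`. A sketch of what the resulting config can look like when built by hand, assuming the test's 4-bit NF4 settings; the FLUX checkpoint id is only an example:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# bitsandbytes exposes the "do not quantize" list as `llm_int8_skip_modules`,
# even when loading in 4-bit NF4.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_skip_modules=["norm_out.linear"],
)

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
)
```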
@@ -826,7 +818,14 @@ class TorchAoConfigMixin:

        return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)

    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        from torchao.dtypes import AffineQuantizedTensor
        from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

        assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
        # Check if the weight is actually quantized
        weight = module.weight
        is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
        assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"


# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
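The torchao verifier above decides "quantized" by the weight's tensor subclass rather than by the module type. A sketch of the same check against a model quantized with `TorchAoConfig`; the checkpoint id and the exact module paths are assumptions about the FLUX transformer layout, and "int8wo" is the short quant-type name the test's `TORCHAO_QUANT_TYPES` uses:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.dtypes import AffineQuantizedTensor

# int8 weight-only quantization; the listed modules keep their original weights.
quant_config = TorchAoConfig("int8wo", modules_to_not_convert=["norm_out.linear"])

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
)

# Quantized Linear layers carry AffineQuantizedTensor weights; excluded layers keep plain tensors.
print(isinstance(model.transformer_blocks[0].attn.to_q.weight, AffineQuantizedTensor))  # expected: True
print(isinstance(model.norm_out.linear.weight, AffineQuantizedTensor))                  # expected: False
```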
@@ -922,9 +921,39 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):

        if modules_to_exclude is None:
            pytest.skip("modules_to_not_convert_for_test not defined for this model")

        self._test_quantization_modules_to_not_convert(
            TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
        )
        # Custom implementation for torchao that skips memory footprint check
        # because get_memory_footprint() doesn't accurately reflect torchao quantization
        config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
        config_kwargs_with_exclusion = config_kwargs.copy()
        config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude

        model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)

        # Find a module that should NOT be quantized
        found_excluded = False
        for name, module in model_with_exclusion.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Check if this module is in the exclusion list
                if any(excluded in name for excluded in modules_to_exclude):
                    found_excluded = True
                    # This module should NOT be quantized
                    assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
                        f"Module {name} should not be quantized but was found to be quantized"
                    )

        assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"

        # Find a module that SHOULD be quantized (not in exclusion list)
        found_quantized = False
        for name, module in model_with_exclusion.named_modules():
            if isinstance(module, torch.nn.Linear):
                # Check if this module is NOT in the exclusion list
                if not any(excluded in name for excluded in modules_to_exclude):
                    if self._is_module_quantized(module, config_kwargs_with_exclusion):
                        found_quantized = True
                        break

        assert found_quantized, "No quantized layers found outside of excluded modules"

    def test_torchao_device_map(self):
        """Test that device_map='auto' works correctly with quantization."""
@@ -318,6 +318,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):

class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for Flux Transformer."""

    @property
    def modules_to_not_convert_for_test(self):
        return ["norm_out.linear"]


class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
    """Quanto quantization tests for Flux Transformer."""
@@ -330,10 +334,18 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):

    def pretrained_model_kwargs(self):
        return {}

    @property
    def modules_to_not_convert_for_test(self):
        return ["norm_out.linear"]


class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for Flux Transformer."""

    @property
    def modules_to_not_convert_for_test(self):
        return ["norm_out.linear"]


class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
    @property
@@ -402,6 +414,10 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTes

class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
    """ModelOpt quantization tests for Flux Transformer."""

    @property
    def modules_to_not_convert_for_test(self):
        return ["norm_out.linear"]


class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
    """ModelOpt + compile tests for Flux Transformer."""