Compare commits

..

3 Commits

Author SHA1 Message Date
Sayak Paul
7d20b7261b Merge branch 'main' into modular-more-loading-tests 2026-02-11 08:20:31 +05:30
Sayak Paul
4d00980e25 [lora] fix non-diffusers lora key handling for flux2 (#13119)
fix non-diffusers lora key handling for flux2
2026-02-11 08:06:36 +05:30
sayakpaul
49b256efe5 add tests for robust model loading. 2026-02-10 22:05:01 +05:30
4 changed files with 131 additions and 76 deletions

View File

@@ -2321,6 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
has_lora_down_up = any("lora_down" in k or "lora_up" in k for k in original_state_dict.keys())
if has_lora_down_up:
temp_state_dict = {}
for k, v in original_state_dict.items():
new_key = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
temp_state_dict[new_key] = v
original_state_dict = temp_state_dict
num_double_layers = 0
num_single_layers = 0
for key in original_state_dict.keys():
@@ -2337,13 +2345,15 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
attn_prefix = f"single_transformer_blocks.{sl}.attn"
for lora_key in lora_keys:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear1.{lora_key}.weight"
)
linear1_key = f"{single_block_prefix}.linear1.{lora_key}.weight"
if linear1_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
linear1_key
)
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear2.{lora_key}.weight"
)
linear2_key = f"{single_block_prefix}.linear2.{lora_key}.weight"
if linear2_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(linear2_key)
for dl in range(num_double_layers):
transformer_block_prefix = f"transformer_blocks.{dl}"
@@ -2352,6 +2362,10 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for attn_type in attn_types:
attn_prefix = f"{transformer_block_prefix}.attn"
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
if qkv_key not in original_state_dict:
continue
fused_qkv_weight = original_state_dict.pop(qkv_key)
if lora_key == "lora_A":
@@ -2383,8 +2397,9 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_proj, diff_proj in proj_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
mlp_mappings = [
("img_mlp.0", "ff.linear_in"),
@@ -2395,8 +2410,27 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_mlp, diff_mlp in mlp_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
extra_mappings = {
"img_in": "x_embedder",
"txt_in": "context_embedder",
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"final_layer.linear": "proj_out",
"final_layer.adaLN_modulation.1": "norm_out.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
}
for org_key, diff_key in extra_mappings.items():
for lora_key in lora_keys:
original_key = f"{org_key}.{lora_key}.weight"
if original_key in original_state_dict:
converted_state_dict[f"{diff_key}.{lora_key}.weight"] = original_state_dict.pop(original_key)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

View File

@@ -21,8 +21,11 @@ import torch
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -56,6 +59,13 @@ if is_bitsandbytes_available():
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
class LoRALayer(torch.nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
@@ -122,14 +132,14 @@ class QuantizationTesterMixin:
def _verify_if_layer_quantized(self, name, module, config_kwargs):
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
def _is_module_quantized(self, module, quant_config_kwargs=None):
def _is_module_quantized(self, module):
"""
Check if a module is quantized. Returns True if quantized, False otherwise.
Default implementation tries _verify_if_layer_quantized and catches exceptions.
Subclasses can override for more efficient checking.
"""
try:
self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
self._verify_if_layer_quantized("", module, {})
return True
except (AssertionError, AttributeError):
return False
@@ -269,9 +279,7 @@ class QuantizationTesterMixin:
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
)
def _test_quantization_modules_to_not_convert(
self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
):
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
"""
Test that modules specified in modules_to_not_convert are not quantized.
@@ -281,7 +289,7 @@ class QuantizationTesterMixin:
"""
# Create config with modules_to_not_convert
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
@@ -293,7 +301,7 @@ class QuantizationTesterMixin:
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
assert not self._is_module_quantized(module), (
f"Module {name} should not be quantized but was found to be quantized"
)
@@ -305,7 +313,7 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_not_convert):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
if self._is_module_quantized(module):
found_quantized = True
break
@@ -610,7 +618,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
)
@pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
@@ -809,14 +817,7 @@ class TorchAoConfigMixin:
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
# Check if the weight is actually quantized
weight = module.weight
is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
@@ -912,39 +913,9 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
# Custom implementation for torchao that skips memory footprint check
# because get_memory_footprint() doesn't accurately reflect torchao quantization
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
# Find a module that should NOT be quantized
found_excluded = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is in the exclusion list
if any(excluded in name for excluded in modules_to_exclude):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
f"Module {name} should not be quantized but was found to be quantized"
)
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
# Find a module that SHOULD be quantized (not in exclusion list)
found_quantized = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_exclude):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
found_quantized = True
break
assert found_quantized, "No quantized layers found outside of excluded modules"
self._test_quantization_modules_to_not_convert(
TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
)
def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""

View File

@@ -318,10 +318,6 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""
@@ -334,18 +330,10 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def pretrained_model_kwargs(self):
return {}
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
@@ -414,10 +402,6 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTes
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
"""ModelOpt quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
"""ModelOpt + compile tests for Flux Transformer."""

View File

@@ -6,7 +6,7 @@ import pytest
import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
@@ -598,3 +598,69 @@ class TestModularModelCardContent:
content = generate_modular_model_card_content(blocks)
assert "5-block architecture" in content["model_description"]
class TestAutoModelLoadIdTagging:
def test_automodel_tags_load_id(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")
assert hasattr(model, "_diffusers_load_id"), "Model should have _diffusers_load_id attribute"
assert model._diffusers_load_id != "null", "_diffusers_load_id should not be 'null'"
# Verify load_id contains the expected fields
load_id = model._diffusers_load_id
assert "hf-internal-testing/tiny-stable-diffusion-xl-pipe" in load_id
assert "unet" in load_id
def test_automodel_update_components(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)
auto_model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet")
pipe.update_components(unet=auto_model)
assert pipe.unet is auto_model
assert "unet" in pipe._component_specs
spec = pipe._component_specs["unet"]
assert spec.pretrained_model_name_or_path == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
assert spec.subfolder == "unet"
class TestLoadComponentsSkipBehavior:
def test_load_components_skips_already_loaded(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)
original_unet = pipe.unet
pipe.load_components()
# Verify that the unet is the same object (not reloaded)
assert pipe.unet is original_unet, "load_components should skip already loaded components"
def test_load_components_selective_loading(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(names="unet", torch_dtype=torch.float32)
# Verify only requested component was loaded.
assert hasattr(pipe, "unet")
assert pipe.unet is not None
if "vae" in pipe._component_specs:
assert getattr(pipe, "vae", None) is None
def test_load_components_skips_invalid_pretrained_path(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe._component_specs["test_component"] = ComponentSpec(
name="test_component",
type_hint=torch.nn.Module,
pretrained_model_name_or_path=None,
default_creation_method="from_pretrained",
)
pipe.load_components(torch_dtype=torch.float32)
# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None