mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-11 19:21:44 +08:00
Compare commits
2 Commits
tests-load
...
zimage-lor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dae34b1ec8 | ||
|
|
897aed72fa |
@@ -2538,8 +2538,12 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
alpha_tensor = state_dict.pop(alpha_key, None)
|
||||
if alpha_tensor is None:
|
||||
return 1.0, 1.0
|
||||
scale = (
|
||||
alpha_tensor.item() / rank
|
||||
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
|
||||
@@ -36,7 +36,7 @@ from typing import Any, Callable
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoConfig", "1.0.0", deprecation_message)
|
||||
self.quant_method = QuantizationMethod.QUANTO
|
||||
self.weights_dtype = weights_dtype
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from diffusers.utils.import_utils import is_optimum_quanto_version
|
||||
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
get_module_from_name,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
|
||||
|
||||
if not is_optimum_quanto_available():
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -33,33 +32,6 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -388,39 +360,6 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
Reference in New Issue
Block a user