Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: custom-mod...nunchaku (13 commits)
Commits (SHA1):

3295c6aba5
8e07445540
4af534bde9
df58c8017e
2a827ec19f
d35e77ece0
5d08150a2e
9e0caa7afc
269813fcc5
ac1aa8bbec
f4262b8877
7022169c13
8e1ea006f0
src/diffusers/__init__.py

@@ -13,6 +13,7 @@ from .utils import (
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_nunchaku_available,
    is_onnx_available,
    is_opencv_available,
    is_optimum_quanto_available,

@@ -99,6 +100,18 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["quantizers.quantization_config"].append("TorchAoConfig")

try:
    if not is_torch_available() and not is_accelerate_available() and not is_nunchaku_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_nunchaku_objects

    _import_structure["utils.dummy_nunchaku_objects"] = [
        name for name in dir(dummy_nunchaku_objects) if not name.startswith("_")
    ]
else:
    _import_structure["quantizers.quantization_config"].append("NunchakuConfig")

try:
    if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
        raise OptionalDependencyNotAvailable()
@@ -791,6 +804,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    else:
        from .quantizers.quantization_config import QuantoConfig

    try:
        if not is_nunchaku_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_nunchaku_objects import *
    else:
        from .quantizers.quantization_config import NunchakuConfig

    try:
        if not is_onnx_available():
            raise OptionalDependencyNotAvailable()
src/diffusers/loaders/single_file_model.py

@@ -23,6 +23,7 @@ from typing_extensions import Self

from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..quantizers.quantization_config import NunchakuConfig
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (

@@ -42,6 +43,7 @@ from .single_file_utils import (
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_nunchaku_flux_to_diffusers,
    convert_sana_transformer_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,

@@ -190,6 +192,23 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
    return mapping_kwargs


def _maybe_determine_modules_to_not_convert(quantization_config, state_dict):
    if quantization_config is None:
        return None
    else:
        is_nunchaku = quantization_config.quant_method == "nunchaku"
        if not is_nunchaku:
            return None
        else:
            no_qweight = set()
            for key in state_dict:
                if key.endswith(".weight"):
                    # module name is everything except the last piece after "."
                    module_name = ".".join(key.split(".")[:-1])
                    no_qweight.add(module_name)
            return sorted(no_qweight)


class FromOriginalModelMixin:
    """
    Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
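For a nunchaku checkpoint, the helper above treats any key that still ends in ".weight" (as opposed to ".qweight") as belonging to a module that should stay unquantized, and returns the parent module names. A minimal sketch with made-up keys, reimplementing only the core loop for illustration:

# Hypothetical keys; only the ".weight" / ".qweight" suffix matters here.
state_dict = {
    "transformer_blocks.0.attn.to_qkv.qweight": None,  # quantized entry, ignored
    "transformer_blocks.0.norm1.linear.weight": None,  # plain weight, collected
    "norm_out.linear.weight": None,                    # plain weight, collected
}
no_qweight = sorted(
    {".".join(key.split(".")[:-1]) for key in state_dict if key.endswith(".weight")}
)
print(no_qweight)  # ['norm_out.linear', 'transformer_blocks.0.norm1.linear']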
@@ -404,8 +423,14 @@ class FromOriginalModelMixin:
        model = cls.from_config(diffusers_model_config)

        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)

        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
        model_state_dict = model.state_dict()
        # TODO: Only flux nunchaku checkpoint for now. Unify with how checkpoint mappers are done.
        # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
        if quantization_config is not None and quantization_config.quant_method == "nunchaku":
            diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers(
                checkpoint, model_state_dict=model_state_dict
            )
        elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
            diffusers_format_checkpoint = checkpoint_mapping_fn(
                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
            )

@@ -416,6 +441,27 @@ class FromOriginalModelMixin:
            raise SingleFileComponentError(
                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
            )

        # This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
        # We can move it to a separate function as well.
        if quantization_config is not None:
            original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
            determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
                quantization_config, checkpoint
            )
            if determined_modules_to_not_convert:
                determined_modules_to_not_convert.extend(original_modules_to_not_convert)
                determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
                logger.debug(
                    f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
                )
                modified_quant_config = quantization_config.to_dict()
                modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
                # TODO: figure out a better way.
                modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
                setattr(hf_quantizer, "quantization_config", modified_quant_config)
                logger.debug("TODO")

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")

@@ -443,6 +489,12 @@ class FromOriginalModelMixin:
            unexpected_keys = [
                param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
            ]
            for k in unexpected_keys:
                if "single_transformer_blocks.0" in k:
                    print(f"Unexpected {k=}")
            for k in empty_state_dict:
                if "single_transformer_blocks.0" in k:
                    print(f"model {k=}")
            device_map = {"": param_device}
            load_model_dict_into_meta(
                model,
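The block above only extends `modules_to_not_convert` when the helper actually found unquantized modules, and duplicates are dropped before the quantizer config is rebuilt. A tiny sketch of that merge, with hypothetical module names:

# Hypothetical inputs for the merge step shown above.
original_modules_to_not_convert = ["proj_out"]                       # from the user's config
determined_modules_to_not_convert = ["norm_out.linear", "proj_out"]  # derived from the checkpoint keys

determined_modules_to_not_convert.extend(original_modules_to_not_convert)
determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
print(sorted(determined_modules_to_not_convert))  # ['norm_out.linear', 'proj_out']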
src/diffusers/loaders/single_file_utils.py

@@ -2189,6 +2189,105 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
    return converted_state_dict


# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
    from .single_file_utils_nunchaku import _unpack_qkv_state_dict

    _SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
    _SMOOTH_RE = re.compile(r"\.smooth(\.|$)")

    new_state_dict = {}
    model_state_dict = kwargs["model_state_dict"]

    ckpt_keys = list(checkpoint.keys())
    for k in ckpt_keys:
        if "qweight" in k:
            # only the shape information of this tensor is needed
            v = checkpoint[k]
            # if the tensor has qweight, but does not have a low-rank branch, we need to add some artificial tensors
            for t in ["lora_up", "lora_down"]:
                new_k = k.replace(".qweight", f".{t}")
                if new_k not in ckpt_keys:
                    oc, ic = v.shape
                    ic = ic * 2  # v is packed into INT8, so we need to double the size
                    checkpoint[k.replace(".qweight", f".{t}")] = torch.zeros(
                        (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
                    )

    for k, v in checkpoint.items():
        new_k = k  # start with the original key, then apply independent replacements

        if k.startswith("single_transformer_blocks."):
            # attention / qkv / norms
            new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
            new_k = new_k.replace(".out_proj.", ".proj_out.")
            new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
            new_k = new_k.replace(".norm_q.", ".attn.norm_q.")

            # mlp heads
            new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.")
            new_k = new_k.replace(".mlp_fc2.", ".proj_out.")

            # smooth params (use regex to avoid substring collisions)
            new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
            new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)

            # lora -> proj
            new_k = new_k.replace(".lora_down", ".proj_down")
            new_k = new_k.replace(".lora_up", ".proj_up")

        elif k.startswith("transformer_blocks."):
            # feed-forward (context & base)
            new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
            new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
            new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
            new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")

            # attention projections
            new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
            new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
            new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
            new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")

            # norms
            new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
            new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
            new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
            new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")

            # smooth params
            new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
            new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)

            # lora -> proj
            new_k = new_k.replace(".lora_down", ".proj_down")
            new_k = new_k.replace(".lora_up", ".proj_up")

        new_state_dict[new_k] = v

    new_state_dict = _unpack_qkv_state_dict(new_state_dict)

    # some remnant keys need to be patched
    new_sd_keys = list(new_state_dict.keys())
    for k in new_sd_keys:
        if "qweight" in k:
            no_qweight_k = ".".join(k.split(".qweight")[:-1])
            for unexpected_k in ["wzeros"]:
                unexpected_k = no_qweight_k + f".{unexpected_k}"
                if unexpected_k in new_sd_keys:
                    _ = new_state_dict.pop(unexpected_k)
    for k in model_state_dict:
        if k not in new_state_dict:
            # CPU device for now
            new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")

    for k in new_state_dict:
        if "single_transformer_blocks.0" in k and k.endswith(".weight"):
            print(f"{k=}")

    return new_state_dict


def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())
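The two smooth-factor regexes above are applied in a fixed order: the `_orig` variant runs first and rewrites ".smooth_orig" to ".smooth_factor_orig", after which the plain ".smooth" rule can no longer match that key; both patterns anchor on a following "." or end of string so they do not touch names that merely contain "smooth". A small sketch of that ordering on a made-up key:

import re

# Hypothetical key; the suffix handling mirrors the regex order used above.
_SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
_SMOOTH_RE = re.compile(r"\.smooth(\.|$)")

key = "transformer_blocks.0.qkv_proj.smooth_orig"
key = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", key)
key = _SMOOTH_RE.sub(r".smooth_factor\1", key)
print(key)  # transformer_blocks.0.qkv_proj.smooth_factor_orig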
src/diffusers/loaders/single_file_utils_nunchaku.py (new file, 104 lines)

@@ -0,0 +1,104 @@
import re

import torch


_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
_ALLOWED_SUFFIXES_NUNCHAKU = {
    "bias",
    "proj_down",
    "proj_up",
    "qweight",
    "smooth_factor",
    "smooth_factor_orig",
    "wscales",
}

_QKV_NUNCHAKU_REGEX = re.compile(
    rf"^(?P<prefix>.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P<suffix>.+)$"
)


def _pick_split_dim(t: torch.Tensor, suffix: str) -> int:
    """
    Choose which dimension to split by 3. Heuristics:
    - 1D -> dim 0
    - 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]),
      otherwise prefer dim=0 (common layout [3*out_features, *]).
    - If the preferred dim isn't divisible by 3, try the other; else error.
    """
    shape = list(t.shape)
    if len(shape) == 0:
        raise ValueError("Cannot split a scalar into Q/K/V.")

    if len(shape) == 1:
        dim = 0
        if shape[dim] % 3 == 0:
            return dim
        raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.")

    # len(shape) >= 2
    preferred = 1 if suffix == "qweight" else 0
    other = 0 if preferred == 1 else 1

    if shape[preferred] % 3 == 0:
        return preferred
    if shape[other] % 3 == 0:
        return other

    # Fall back: any dim divisible by 3
    for d, s in enumerate(shape):
        if s % 3 == 0:
            return d

    raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.")


def _split_qkv(t: torch.Tensor, dim: int):
    return torch.tensor_split(t, 3, dim=dim)


def _unpack_qkv_state_dict(
    state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU
):
    """
    Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries:
    '...to_q.bias', '...to_k.bias', '...to_v.bias' / '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'.
    Returns a new dict; entries are popped from the input as they are processed.

    Only keys with a suffix in `allowed_suffixes` are processed. Keys whose tensors are not divisible by 3 raise a
    ValueError.
    """
    anchors = tuple(anchors)
    allowed_suffixes = set(allowed_suffixes)

    new_sd: dict = {}
    sd_keys = list(state_dict.keys())
    for k in sd_keys:
        m = _QKV_NUNCHAKU_REGEX.match(k)
        v = state_dict.pop(k)
        if m:
            suffix = m.group("suffix")
            if suffix not in allowed_suffixes:
                # keep as-is if it's not one of the targeted suffixes
                new_sd[k] = v
                continue

            prefix = m.group("prefix")  # everything before .to_qkv/.add_qkv_proj
            # Decide split axis
            split_dim = _pick_split_dim(v, suffix)
            q, k_, vv = _split_qkv(v, dim=split_dim)

            # Build new keys
            base_q = f"{prefix}.to_q.{suffix}"
            base_k = f"{prefix}.to_k.{suffix}"
            base_v = f"{prefix}.to_v.{suffix}"

            # Write into result dict
            new_sd[base_q] = q
            new_sd[base_k] = k_
            new_sd[base_v] = vv
        else:
            # not a fused qkv key
            new_sd[k] = v

    return new_sd
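A quick sketch of what `_unpack_qkv_state_dict` does to a single fused entry, assuming this branch is importable; the key and tensor shape are made up:

import torch

from diffusers.loaders.single_file_utils_nunchaku import _unpack_qkv_state_dict

# Hypothetical fused bias of size 3 * 64, anchored on ".to_qkv.".
state_dict = {"transformer_blocks.0.attn.to_qkv.bias": torch.zeros(192)}
unpacked = _unpack_qkv_state_dict(state_dict)
print(sorted(unpacked))
# ['transformer_blocks.0.attn.to_k.bias',
#  'transformer_blocks.0.attn.to_q.bias',
#  'transformer_blocks.0.attn.to_v.bias']
print(unpacked["transformer_blocks.0.attn.to_q.bias"].shape)  # torch.Size([64])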
src/diffusers/models/model_loading_utils.py

@@ -297,6 +297,13 @@ def load_model_dict_into_meta(
            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
        elif param_device == "cpu" and state_dict_index is not None:
            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
        # The check below might seem counter-intuitive: we check whether the param or its module is quantized and,
        # if so, proceed to create a quantized param. This follows from the way pre-trained models are loaded:
        # they are initialized on the "meta" device, where the quantization layers are injected first. Hence, for a
        # model that is either pre-quantized or supplemented with a `quantization_config` during `from_pretrained`,
        # we expect `check_if_quantized_param` to return True. Then, depending on the quantization backend being
        # used, the actual quantization step runs under `create_quantized_param`.
        elif is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
        ):
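In short: because the model skeleton is materialized on the "meta" device with the quantized layer classes already swapped in, the loader asks the quantizer whether each incoming parameter targets one of those layers before deciding how to materialize it. A schematic of that ordering, with placeholder names rather than the real diffusers signatures:

# Schematic only; every name here is a placeholder for the corresponding step above.
def load_param(model, name, value, quantizer, is_quantized, set_plain_tensor):
    """Mirror of the branch order above: quantized params are routed to the quantizer."""
    if is_quantized and quantizer.check_if_quantized_param(model, value, name, state_dict={}):
        # backend-specific packing / placement handled by the quantizer
        quantizer.create_quantized_param(model, value, name, target_device="cuda")
    else:
        # plain parameter: just copy the tensor onto the model
        set_plain_tensor(model, name, value)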
src/diffusers/quantizers/auto.py

@@ -21,9 +21,11 @@ from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .nunchaku import NunchakuQuantizer
from .quantization_config import (
    BitsAndBytesConfig,
    GGUFQuantizationConfig,
    NunchakuConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,

@@ -39,6 +41,7 @@ AUTO_QUANTIZER_MAPPING = {
    "gguf": GGUFQuantizer,
    "quanto": QuantoQuantizer,
    "torchao": TorchAoHfQuantizer,
    "nunchaku": NunchakuQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {

@@ -47,12 +50,13 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "gguf": GGUFQuantizationConfig,
    "quanto": QuantoConfig,
    "torchao": TorchAoConfig,
    "nunchaku": NunchakuConfig,
}


class DiffusersAutoQuantizer:
    """
    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
    `DiffusersQuantizer` given the `QuantizationConfig`.
    """
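With both mappings extended, the string "nunchaku" now resolves to the new quantizer and config classes. A hedged sketch of that dispatch, assuming this branch is installed and run on a machine with a GPU that NunchakuConfig's hardware check accepts:

from diffusers.quantizers import DiffusersAutoQuantizer
from diffusers.quantizers.quantization_config import NunchakuConfig

config = NunchakuConfig()  # quant_method resolves to "nunchaku"; requires a supported CUDA GPU
quantizer = DiffusersAutoQuantizer.from_config(config)
print(type(quantizer).__name__)  # NunchakuQuantizer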
src/diffusers/quantizers/gguf/gguf_quantizer.py

@@ -90,7 +90,7 @@ class GGUFQuantizer(DiffusersQuantizer):
    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: Union["GGUFParameter", "torch.Tensor"],
        param_value: Union["torch.Tensor"],
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
src/diffusers/quantizers/nunchaku/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
from .nunchaku_quantizer import NunchakuQuantizer
src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py (new file, 174 lines)

@@ -0,0 +1,174 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union

from diffusers.utils.import_utils import is_nunchaku_version

from ...utils import get_module_from_name, is_accelerate_available, is_nunchaku_available, is_torch_available, logging
from ...utils.torch_utils import is_fp8_available
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_nunchaku_available():
    from .utils import replace_with_nunchaku_linear

logger = logging.get_logger(__name__)


KEY_MAP = {
    "lora_down": "proj_down",
    "lora_up": "proj_up",
    "smooth_orig": "smooth_factor_orig",
    "smooth": "smooth_factor",
}


class NunchakuQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["nunchaku", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")

        if not is_nunchaku_available():
            raise ImportError(
                "Loading a nunchaku quantized model requires the nunchaku library (follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)"
            )
        if not is_nunchaku_version(">=", "0.3.1"):
            raise ImportError(
                "Loading a nunchaku quantized model requires `nunchaku>=0.3.1`. "
                "Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
            )

        if not is_accelerate_available():
            raise ImportError(
                "Loading a nunchaku quantized model requires the accelerate library (`pip install accelerate`)"
            )

        # TODO: check
        # device_map = kwargs.get("device_map", None)
        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        #     raise ValueError(
        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
        #     )

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from nunchaku.models.linear import SVDQW4A4Linear

        module, _ = get_module_from_name(model, param_name)
        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
            return True

        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create a quantized parameter.
        """
        from nunchaku.models.linear import SVDQW4A4Linear

        module, tensor_name = get_module_from_name(model, param_name)
        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

        if isinstance(module, SVDQW4A4Linear):
            module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(target_device)

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        precision = self.quantization_config.precision
        expected_target_dtypes = [torch.int8]
        if is_fp8_available():
            expected_target_dtypes.append(torch.float8_e4m3fn)
        if target_dtype not in expected_target_dtypes:
            new_target_dtype = self.dtype_map[precision]

            logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization")
            return new_target_dtype
        else:
            raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            # We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
            logger.info(
                "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                "requirements of `nunchaku` to enable model loading in 4-bit. "
                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                " torch_dtype=torch.bfloat16 to remove this warning.",
                torch_dtype,
            )
            torch_dtype = torch.bfloat16
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]
        self.modules_to_not_convert.extend(keep_in_fp32_modules)
        # Purge `None`.
        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
        # in case of diffusion transformer models. For language models and others alike, `lm_head`
        # and tied modules are usually kept in FP32.
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
            self.modules_to_not_convert.extend(keys_on_cpu)

        model = replace_with_nunchaku_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    @property
    def is_serializable(self):
        return False

    @property
    def is_trainable(self):
        return False
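Putting the pieces together, loading a prequantized nunchaku Flux checkpoint through the single-file path would look roughly like the sketch below. This is a hedged sketch against this branch only: the checkpoint path is a placeholder, and the exact argument surface of `from_single_file` with a `NunchakuConfig` is whatever this PR lands on, not a released API.

import torch
from diffusers import FluxTransformer2DModel
from diffusers.quantizers.quantization_config import NunchakuConfig

# Placeholder path to a nunchaku-format single-file checkpoint.
ckpt_path = "path/to/svdq-int4-flux.1-dev.safetensors"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=NunchakuConfig(),  # requires a Turing/Ampere/Ada/Blackwell GPU
    torch_dtype=torch.bfloat16,            # nunchaku expects bf16 for the non-quantized layers
)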
src/diffusers/quantizers/nunchaku/utils.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import torch.nn as nn

from ...utils import is_accelerate_available, is_nunchaku_available, logging


if is_accelerate_available():
    from accelerate import init_empty_weights


logger = logging.get_logger(__name__)


def _replace_with_nunchaku_linear(
    model,
    svdq_linear_cls,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    model._modules[name] = svdq_linear_cls(
                        in_features,
                        out_features,
                        rank=quantization_config.rank,
                        bias=module.bias is not None,
                        torch_dtype=module.weight.dtype,
                    )
                    has_been_replaced = True
                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_nunchaku_linear(
                module,
                svdq_linear_cls,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    if is_nunchaku_available():
        from nunchaku.models.linear import SVDQW4A4Linear

        model, _ = _replace_with_nunchaku_linear(
            model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
        )

        has_been_replaced = any(
            isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
        )
        if not has_been_replaced:
            logger.warning(
                "You are loading your model in the SVDQuant method but no linear modules were found in your model."
                " Please double check your model architecture, or submit an issue on github if you think this is"
                " a bug."
            )

    return model
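The exclusion check above matches on dotted prefixes, so an entry like "ff" in `modules_to_not_convert` skips "ff.net.0.proj" but not "ff_context.net.0.proj". A standalone sketch of just that string test, with made-up key names:

modules_to_not_convert = ["ff", "norm_out.linear"]

def is_excluded(current_key_name_str):
    # Same test as in `_replace_with_nunchaku_linear` above.
    return any(
        (key + "." in current_key_name_str) or (key == current_key_name_str)
        for key in modules_to_not_convert
    )

print(is_excluded("ff.net.0.proj"))          # True  ("ff." appears as a prefix component)
print(is_excluded("ff_context.net.0.proj"))  # False ("ff." does not appear)
print(is_excluded("norm_out.linear"))        # True  (exact match)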
src/diffusers/quantizers/quantization_config.py

@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
    GGUF = "gguf"
    TORCHAO = "torchao"
    QUANTO = "quanto"
    NUNCHAKU = "nunchaku"


if is_torchao_available():

@@ -724,3 +725,72 @@ class QuantoConfig(QuantizationConfigMixin):
        accepted_weights = ["float8", "int8", "int4", "int2"]
        if self.weights_dtype not in accepted_weights:
            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")


class NunchakuConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using `nunchaku`.

    Args:
        TODO
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
            modules left in their original precision (e.g. `norm` layers in Qwen-Image).
    """

    def __init__(
        self,
        method: str = "svdquant",
        weight_dtype: str = "int4",
        weight_scale_dtype: str = None,
        weight_group_size: int = 64,
        activation_dtype: str = "int4",
        activation_scale_dtype: str = None,
        activation_group_size: int = 64,
        rank: int = 32,
        modules_to_not_convert: Optional[List[str]] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.NUNCHAKU
        self.method = method
        self.weight_dtype = weight_dtype
        self.weight_scale_dtype = weight_scale_dtype
        self.weight_group_size = weight_group_size
        self.activation_dtype = activation_dtype
        self.activation_scale_dtype = activation_scale_dtype
        self.activation_group_size = activation_group_size
        self.rank = rank
        self.modules_to_not_convert = modules_to_not_convert

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct. Hardware checks were largely adapted from the official `nunchaku`
        codebase.
        """
        from ..utils.torch_utils import get_device

        device = get_device()
        if isinstance(device, str):
            device = torch.device(device)
        capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
        sm = f"{capability[0]}{capability[1]}"
        if sm == "120":  # you can only use the fp4 models
            if self.weight_dtype != "fp4_e2m1_all":
                raise ValueError('Please use "fp4" quantization for Blackwell GPUs.')
        elif sm in ["75", "80", "86", "89"]:
            if self.weight_dtype != "int4":
                raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.')
        else:
            raise ValueError(
                f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
                "Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
            )

        # TODO: should there be a check for rank?

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
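The hardware check keys off the CUDA compute capability: SM 7.5/8.0/8.6/8.9 GPUs must use the int4 kernels, while SM 12.0 (Blackwell) must use the fp4 variant. A small sketch of how that capability string is derived, runnable on any CUDA machine:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    sm = f"{major}{minor}"
    # e.g. "89" on an RTX 4090 (Ada), "120" on a Blackwell part
    print(f"compute capability: {major}.{minor} -> sm {sm}")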
src/diffusers/utils/__init__.py

@@ -89,6 +89,7 @@ from .import_utils import (
    is_matplotlib_available,
    is_nltk_available,
    is_note_seq_available,
    is_nunchaku_available,
    is_onnx_available,
    is_opencv_available,
    is_optimum_quanto_available,
src/diffusers/utils/dummy_nunchaku_objects.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class NunchakuConfig(metaclass=DummyObject):
    _backends = ["nunchaku"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["nunchaku"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["nunchaku"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["nunchaku"])
@@ -217,6 +217,7 @@ _gguf_available, _gguf_version = _is_package_available("gguf")
|
||||
_torchao_available, _torchao_version = _is_package_available("torchao")
|
||||
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
|
||||
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
|
||||
_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku", get_dist_name=True)
|
||||
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
|
||||
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
|
||||
_nltk_available, _nltk_version = _is_package_available("nltk")
|
||||
@@ -363,6 +364,10 @@ def is_optimum_quanto_available():
|
||||
return _optimum_quanto_available
|
||||
|
||||
|
||||
def is_nunchaku_available():
|
||||
return _nunchaku_available
|
||||
|
||||
|
||||
def is_timm_available():
|
||||
return _timm_available
|
||||
|
||||
@@ -816,7 +821,7 @@ def is_k_diffusion_version(operation: str, version: str):
|
||||
|
||||
def is_optimum_quanto_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Accelerate version to a given reference with an operation.
|
||||
Compares the current quanto version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
@@ -829,6 +834,21 @@ def is_optimum_quanto_version(operation: str, version: str):
|
||||
return compare_versions(parse(_optimum_quanto_version), operation, version)
|
||||
|
||||
|
||||
def is_nunchaku_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current nunchaku version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _nunchaku_available:
|
||||
return False
|
||||
return compare_versions(parse(_nunchaku_version), operation, version)
|
||||
|
||||
|
||||
def is_xformers_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current xformers version to a given reference with an operation.
|
||||
|
||||
src/diffusers/utils/torch_utils.py

@@ -197,3 +197,7 @@ def device_synchronize(device_type: Optional[str] = None):
        device_type = get_device()
    device_mod = getattr(torch, device_type, torch.cuda)
    device_mod.synchronize()


def is_fp8_available():
    return getattr(torch, "float8_e4m3fn", None) is not None