Compare commits

...

13 Commits

Author SHA1 Message Date
sayakpaul
3295c6aba5 up 2025-08-26 19:38:43 +02:00
sayakpaul
8e07445540 up 2025-08-22 19:21:56 +05:30
sayakpaul
4af534bde9 up 2025-08-22 10:36:27 +05:30
sayakpaul
df58c8017e up 2025-08-21 17:08:29 +05:30
sayakpaul
2a827ec19f up 2025-08-21 14:50:57 +05:30
sayakpaul
d35e77ece0 up 2025-08-21 14:48:08 +05:30
sayakpaul
5d08150a2e up 2025-08-21 14:39:58 +05:30
sayakpaul
9e0caa7afc up 2025-08-21 14:02:54 +05:30
sayakpaul
269813fcc5 up 2025-08-21 13:40:32 +05:30
sayakpaul
ac1aa8bbec up 2025-08-21 13:05:12 +05:30
sayakpaul
f4262b8877 up 2025-08-21 13:01:19 +05:30
sayakpaul
7022169c13 up 2025-08-21 12:27:16 +05:30
sayakpaul
8e1ea006f0 start nunchaku. 2025-08-21 12:25:59 +05:30
15 changed files with 659 additions and 5 deletions

View File

@@ -13,6 +13,7 @@ from .utils import (
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_nunchaku_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -99,6 +100,18 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
try:
if not is_torch_available() and not is_accelerate_available() and not is_nunchaku_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_nunchaku_objects
_import_structure["utils.dummy_nunchaku_objects"] = [
name for name in dir(dummy_nunchaku_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("NunchakuConfig")
try:
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
@@ -791,6 +804,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .quantizers.quantization_config import QuantoConfig
try:
if not is_nunchaku_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_optimum_quanto_objects import *
else:
from .quantizers.quantization_config import NunchakuConfig
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()

View File

@@ -23,6 +23,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..quantizers.quantization_config import NunchakuConfig
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
@@ -42,6 +43,7 @@ from .single_file_utils import (
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_nunchaku_flux_to_diffusers,
convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
@@ -190,6 +192,23 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
return mapping_kwargs
def _maybe_determine_modules_to_not_convert(quantization_config, state_dict):
if quantization_config is None:
return None
else:
is_nunchaku = quantization_config.quant_method == "nunchaku"
if not is_nunchaku:
return None
else:
no_qweight = set()
for key in state_dict:
if key.endswith(".weight"):
# module name is everything except the last piece after "."
module_name = ".".join(key.split(".")[:-1])
no_qweight.add(module_name)
return sorted(no_qweight)
class FromOriginalModelMixin:
"""
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
@@ -404,8 +423,14 @@ class FromOriginalModelMixin:
model = cls.from_config(diffusers_model_config)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
model_state_dict = model.state_dict()
# TODO: Only flux nunchaku checkpoint for now. Unify with how checkpoint mappers are done.
# For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
if quantization_config is not None and quantization_config.quant_method == "nunchaku":
diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers(
checkpoint, model_state_dict=model_state_dict
)
elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
@@ -416,6 +441,27 @@ class FromOriginalModelMixin:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
# This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
# We can move it to a separate function as well.
if quantization_config is not None:
original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
quantization_config, checkpoint
)
if determined_modules_to_not_convert:
determined_modules_to_not_convert.extend(original_modules_to_not_convert)
determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
logger.debug(
f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
)
modified_quant_config = quantization_config.to_dict()
modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
# TODO: figure out a better way.
modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
setattr(hf_quantizer, "quantization_config", modified_quant_config)
logger.debug("TODO")
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -443,6 +489,12 @@ class FromOriginalModelMixin:
unexpected_keys = [
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
]
for k in unexpected_keys:
if "single_transformer_blocks.0" in k:
print(f"Unexpected {k=}")
for k in empty_state_dict:
if "single_transformer_blocks.0" in k:
print(f"model {k=}")
device_map = {"": param_device}
load_model_dict_into_meta(
model,

View File

@@ -2189,6 +2189,105 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
return converted_state_dict
# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
from .single_file_utils_nunchaku import _unpack_qkv_state_dict
_SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
_SMOOTH_RE = re.compile(r"\.smooth(\.|$)")
new_state_dict = {}
model_state_dict = kwargs["model_state_dict"]
ckpt_keys = list(checkpoint.keys())
for k in ckpt_keys:
if "qweight" in k:
# only the shape information of this tensor is needed
v = checkpoint[k]
# if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
for t in ["lora_up", "lora_down"]:
new_k = k.replace(".qweight", f".{t}")
if new_k not in ckpt_keys:
oc, ic = v.shape
ic = ic * 2 # v is packed into INT8, so we need to double the size
checkpoint[k.replace(".qweight", f".{t}")] = torch.zeros(
(0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
)
for k, v in checkpoint.items():
new_k = k # start with original, then apply independent replacements
if k.startswith("single_transformer_blocks."):
# attention / qkv / norms
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
new_k = new_k.replace(".out_proj.", ".proj_out.")
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
# mlp heads
new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.")
new_k = new_k.replace(".mlp_fc2.", ".proj_out.")
# smooth params (use regex to avoid substring collisions)
new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
# lora -> proj
new_k = new_k.replace(".lora_down", ".proj_down")
new_k = new_k.replace(".lora_up", ".proj_up")
elif k.startswith("transformer_blocks."):
# feed-forward (context & base)
new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")
# attention projections
new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")
# norms
new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
# smooth params
new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
# lora -> proj
new_k = new_k.replace(".lora_down", ".proj_down")
new_k = new_k.replace(".lora_up", ".proj_up")
new_state_dict[new_k] = v
new_state_dict = _unpack_qkv_state_dict(new_state_dict)
# some remnant keys need to be patched
new_sd_keys = list(new_state_dict.keys())
for k in new_sd_keys:
if "qweight" in k:
no_qweight_k = ".".join(k.split(".qweight")[:-1])
for unexpected_k in ["wzeros"]:
unexpected_k = no_qweight_k + f".{unexpected_k}"
if unexpected_k in new_sd_keys:
_ = new_state_dict.pop(unexpected_k)
for k in model_state_dict:
if k not in new_state_dict:
# CPU device for now
new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")
for k in new_state_dict:
if "single_transformer_blocks.0" in k and k.endswith(".weight"):
print(f"{k=}")
return new_state_dict
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())

View File

@@ -0,0 +1,104 @@
import re
import torch
_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
_ALLOWED_SUFFIXES_NUNCHAKU = {
"bias",
"proj_down",
"proj_up",
"qweight",
"smooth_factor",
"smooth_factor_orig",
"wscales",
}
_QKV_NUNCHAKU_REGEX = re.compile(
rf"^(?P<prefix>.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P<suffix>.+)$"
)
def _pick_split_dim(t: torch.Tensor, suffix: str) -> int:
"""
Choose which dimension to split by 3. Heuristics:
- 1D -> dim 0
- 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]),
otherwise prefer dim=0 (common layout [3*out_features, *]).
- If preferred dim isn't divisible by 3, try the other; else error.
"""
shape = list(t.shape)
if len(shape) == 0:
raise ValueError("Cannot split a scalar into Q/K/V.")
if len(shape) == 1:
dim = 0
if shape[dim] % 3 == 0:
return dim
raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.")
# len(shape) >= 2
preferred = 1 if suffix == "qweight" else 0
other = 0 if preferred == 1 else 1
if shape[preferred] % 3 == 0:
return preferred
if shape[other] % 3 == 0:
return other
# Fall back: any dim divisible by 3
for d, s in enumerate(shape):
if s % 3 == 0:
return d
raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.")
def _split_qkv(t: torch.Tensor, dim: int):
return torch.tensor_split(t, 3, dim=dim)
def _unpack_qkv_state_dict(
state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU
):
"""
Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries:
'...to_q.bias', '...to_k.bias', '...to_v.bias' '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'
Returns a NEW dict; original is not modified.
Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.:
"""
anchors = tuple(anchors)
allowed_suffixes = set(allowed_suffixes)
new_sd: dict = {}
sd_keys = list(state_dict.keys())
for k in sd_keys:
m = _QKV_NUNCHAKU_REGEX.match(k)
v = state_dict.pop(k)
if m:
suffix = m.group("suffix")
if suffix not in allowed_suffixes:
# keep as-is if it's not one of the targeted suffixes
new_sd[k] = v
continue
prefix = m.group("prefix") # everything before .to_qkv/.qkv_proj
# Decide split axis
split_dim = _pick_split_dim(v, suffix)
q, k_, vv = _split_qkv(v, dim=split_dim)
# Build new keys
base_q = f"{prefix}.to_q.{suffix}"
base_k = f"{prefix}.to_k.{suffix}"
base_v = f"{prefix}.to_v.{suffix}"
# Write into result dict
new_sd[base_q] = q
new_sd[base_k] = k_
new_sd[base_v] = vv
else:
# not a fused qkv key
new_sd[k] = v
return new_sd

View File

@@ -297,6 +297,13 @@ def load_model_dict_into_meta(
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
# This check below might be a bit counter-intuitive in nature. This is because we're checking if the param
# or its module is quantized and if so, we're proceeding with creating a quantized param. This is because
# of the way pre-trained models are loaded. They're initialized under "meta" device, where
# quantization layers are first injected. Hence, for a model that is either pre-quantized or supplemented
# with a `quantization_config` during `from_pretrained`, we expect `check_if_quantized_param` to return True.
# Then depending on the quantization backend being used, we run the actual quantization step under
# `create_quantized_param`.
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):

View File

@@ -21,9 +21,11 @@ from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .nunchaku import NunchakuQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
NunchakuConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@ AUTO_QUANTIZER_MAPPING = {
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"nunchaku": NunchakuQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,12 +50,13 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"nunchaku": NunchakuConfig,
}
class DiffusersAutoQuantizer:
"""
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
`DiffusersQuantizer` given the `QuantizationConfig`.
"""

View File

@@ -90,7 +90,7 @@ class GGUFQuantizer(DiffusersQuantizer):
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: Union["GGUFParameter", "torch.Tensor"],
param_value: Union["torch.Tensor"],
param_name: str,
state_dict: Dict[str, Any],
**kwargs,

View File

@@ -0,0 +1 @@
from .nunchaku_quantizer import NunchakuQuantizer

View File

@@ -0,0 +1,174 @@
from typing import TYPE_CHECKING, Any, Dict, List, Union
from diffusers.utils.import_utils import is_nunchaku_version
from ...utils import get_module_from_name, is_accelerate_available, is_nunchaku_available, is_torch_available, logging
from ...utils.torch_utils import is_fp8_available
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
if is_torch_available():
import torch
if is_nunchaku_available():
from .utils import replace_with_nunchaku_linear
logger = logging.get_logger(__name__)
KEY_MAP = {
"lora_down": "proj_down",
"lora_up": "proj_up",
"smooth_orig": "smooth_factor_orig",
"smooth": "smooth_factor",
}
class NunchakuQuantizer(DiffusersQuantizer):
r"""
Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
"""
use_keep_in_fp32_modules = True
requires_calibration = False
required_packages = ["nunchaku", "accelerate"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")
if not is_nunchaku_available():
raise ImportError(
"Loading an nunchaku quantized model requires nunchaku library (follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)"
)
if not is_nunchaku_version(">=", "0.3.1"):
raise ImportError(
"Loading an nunchaku quantized model requires `nunchaku>=1.0.0`. "
"Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
)
if not is_accelerate_available():
raise ImportError(
"Loading an nunchaku quantized model requires accelerate library (`pip install accelerate`)"
)
# TODO: check
# device_map = kwargs.get("device_map", None)
# if isinstance(device_map, dict) and len(device_map.keys()) > 1:
# raise ValueError(
# "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
# )
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
from nunchaku.models.linear import SVDQW4A4Linear
module, _ = get_module_from_name(model, param_name)
if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
return True
return False
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
*args,
**kwargs,
):
"""
Create a quantized parameter.
"""
from nunchaku.models.linear import SVDQW4A4Linear
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
if isinstance(module, SVDQW4A4Linear):
module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(target_device)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
precision = self.quantization_config.precision
expected_target_dtypes = [torch.int8]
if is_fp8_available():
expected_target_dtypes.append(torch.float8_e4m3fn)
if target_dtype not in expected_target_dtypes:
new_target_dtype = self.dtype_map[precision]
logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization")
return new_target_dtype
else:
raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
# We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
logger.info(
"Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
"requirements of `nunchaku` to enable model loading in 4-bit. "
"Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
" torch_dtype=torch.bfloat16 to remove this warning.",
torch_dtype,
)
torch_dtype = torch.bfloat16
return torch_dtype
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
self.modules_to_not_convert.extend(keep_in_fp32_modules)
# Purge `None`.
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
# Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
if isinstance(device_map, dict) and len(device_map.keys()) > 1:
keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
self.modules_to_not_convert.extend(keys_on_cpu)
model = replace_with_nunchaku_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model, **kwargs):
return model
@property
def is_serializable(self):
return False
@property
def is_trainable(self):
return False

View File

@@ -0,0 +1,80 @@
import torch.nn as nn
from ...utils import is_accelerate_available, is_nunchaku_available, logging
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
def _replace_with_nunchaku_linear(
model,
svdq_linear_cls,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
model._modules[name] = svdq_linear_cls(
in_features,
out_features,
rank=quantization_config.rank,
bias=module.bias is not None,
torch_dtype=module.weight.dtype,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_nunchaku_linear(
module,
svdq_linear_cls,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
if is_nunchaku_available():
from nunchaku.models.linear import SVDQW4A4Linear
model, _ = _replace_with_nunchaku_linear(
model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
)
has_been_replaced = any(
isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
)
if not has_been_replaced:
logger.warning(
"You are loading your model in the SVDQuant method but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model

View File

@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
NUNCHAKU = "nunchaku"
if is_torchao_available():
@@ -724,3 +725,72 @@ class QuantoConfig(QuantizationConfigMixin):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
class NunchakuConfig(QuantizationConfigMixin):
"""
This is a wrapper class about all possible attributes and features that you can play with a model that has been
loaded using `nunchaku`.
Args:
TODO
modules_to_not_convert (`list`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision (e.g. `norm` layers in Qwen-Image).
"""
def __init__(
self,
method: str = "svdquant",
weight_dtype: str = "int4",
weight_scale_dtype: str = None,
weight_group_size: int = 64,
activation_dtype: str = "int4",
activation_scale_dtype: str = None,
activation_group_size: int = 64,
rank: int = 32,
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.NUNCHAKU
self.method = method
self.weight_dtype = weight_dtype
self.weight_scale_dtype = weight_scale_dtype
self.weight_group_size = weight_group_size
self.activation_dtype = activation_dtype
self.activation_scale_dtype = activation_scale_dtype
self.activation_group_size = activation_group_size
self.rank = rank
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct. Hardware checks were largely adapted from the official `nunchaku`
codebase.
"""
from ..utils.torch_utils import get_device
device = get_device()
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
sm = f"{capability[0]}{capability[1]}"
if sm == "120": # you can only use the fp4 models
if self.weight_dtype != "fp4_e2m1_all":
raise ValueError('Please use "fp4" quantization for Blackwell GPUs.')
elif sm in ["75", "80", "86", "89"]:
if self.weight_dtype != "int4":
raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.')
else:
raise ValueError(
f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. "
"Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration."
)
# TODO: should there be a check for rank?
def __repr__(self):
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

View File

@@ -89,6 +89,7 @@ from .import_utils import (
is_matplotlib_available,
is_nltk_available,
is_note_seq_available,
is_nunchaku_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class NunchakuConfig(metaclass=DummyObject):
_backends = ["nunchaku"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["nunchaku"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["nunchaku"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["nunchaku"])

View File

@@ -217,6 +217,7 @@ _gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku", get_dist_name=True)
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
@@ -363,6 +364,10 @@ def is_optimum_quanto_available():
return _optimum_quanto_available
def is_nunchaku_available():
return _nunchaku_available
def is_timm_available():
return _timm_available
@@ -816,7 +821,7 @@ def is_k_diffusion_version(operation: str, version: str):
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
Compares the current quanto version to a given reference with an operation.
Args:
operation (`str`):
@@ -829,6 +834,21 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
def is_nunchaku_version(operation: str, version: str):
"""
Compares the current nunchaku version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _nunchaku_available:
return False
return compare_versions(parse(_nunchaku_version), operation, version)
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.

View File

@@ -197,3 +197,7 @@ def device_synchronize(device_type: Optional[str] = None):
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()
def is_fp8_available():
return getattr(torch, "float8_e4m3fn", None) is None