Compare commits

...

8 Commits

Author SHA1 Message Date
DN6
3668dde2b7 update 2025-08-05 18:10:01 +05:30
DN6
5c4eee56e5 update 2025-08-04 21:37:06 +05:30
Dhruv Nair
cb004ad5e6 update 2025-07-24 11:03:39 +02:00
Dhruv Nair
db94e2b5a7 update 2025-07-24 06:30:12 +02:00
DN6
de1fb4b615 update 2025-07-24 08:31:47 +05:30
Isotr0py
e46571a7aa optimize
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-07-06 01:47:13 +08:00
Isotr0py
66bd237bc5 fix
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-07-06 01:00:01 +08:00
Isotr0py
6c4d01def7 add gguf kernel support
Signed-off-by: Isotr0py <2037008807@qq.com>
2025-07-05 17:47:06 +08:00
7 changed files with 179 additions and 4 deletions

View File

@@ -333,7 +333,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: ["peft"]
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []

View File

@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Using Optimized CUDA Kernels with GGUF
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
```shell
pip install -U kernels
```
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
## Supported Quantization Types
- BF16

View File

@@ -12,15 +12,15 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
import inspect
import os
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available
if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
can_use_cuda_kernels = (
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 7
)
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,
# so we disabled it now.
# elif qweight_type in MMVQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
if qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.to(x.dtype).T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y.as_tensor()
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs):
def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output += self.bias.to(self.compute_dtype)
return output

View File

@@ -78,6 +78,7 @@ from .import_utils import (
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,

View File

@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -274,6 +275,10 @@ def is_accelerate_available():
return _accelerate_available
def is_kernels_available():
return _kernels_available
def is_k_diffusion_available():
return _k_diffusion_available

View File

@@ -35,6 +35,7 @@ from .import_utils import (
is_compel_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -629,6 +630,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

View File

@@ -29,19 +29,76 @@ from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_accelerator,
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_kernels_version_greater_or_equal,
require_peft_backend,
torch_device,
)
if is_gguf_available():
import gguf
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
enable_full_determinism()
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
def setUp(self):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
def test_cuda_kernels_vs_native(self):
if torch_device != "cuda":
self.skipTest("CUDA kernels test requires CUDA device")
from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
if not can_use_cuda_kernels:
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
test_quant_types = ["Q4_0", "Q4_K"]
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
compute_dtype = torch.bfloat16
for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
in_features, out_features = 512, 512
torch.manual_seed(42)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
linear.weight = weight
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
linear = linear.to(torch_device)
with torch.no_grad():
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly
@require_big_accelerator
@require_accelerate