mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
8 Commits
custom-cod
...
gguf-cuda-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3668dde2b7 | ||
|
|
5c4eee56e5 | ||
|
|
cb004ad5e6 | ||
|
|
db94e2b5a7 | ||
|
|
de1fb4b615 | ||
|
|
e46571a7aa | ||
|
|
66bd237bc5 | ||
|
|
6c4d01def7 |
2
.github/workflows/nightly_tests.yml
vendored
2
.github/workflows/nightly_tests.yml
vendored
@@ -333,7 +333,7 @@ jobs:
|
||||
additional_deps: ["peft"]
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
additional_deps: ["peft"]
|
||||
additional_deps: ["peft", "kernels"]
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
|
||||
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
## Using Optimized CUDA Kernels with GGUF
|
||||
|
||||
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
|
||||
|
||||
```shell
|
||||
pip install -U kernels
|
||||
```
|
||||
|
||||
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
- BF16
|
||||
|
||||
@@ -12,15 +12,15 @@
|
||||
# # See the License for the specific language governing permissions and
|
||||
# # limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import is_accelerate_available
|
||||
from ...utils import is_accelerate_available, is_kernels_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -29,6 +29,82 @@ if is_accelerate_available():
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
|
||||
|
||||
can_use_cuda_kernels = (
|
||||
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 7
|
||||
)
|
||||
if can_use_cuda_kernels and is_kernels_available():
|
||||
from kernels import get_kernel
|
||||
|
||||
ops = get_kernel("Isotr0py/ggml")
|
||||
else:
|
||||
ops = None
|
||||
|
||||
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
|
||||
STANDARD_QUANT_TYPES = {
|
||||
gguf.GGMLQuantizationType.Q4_0,
|
||||
gguf.GGMLQuantizationType.Q4_1,
|
||||
gguf.GGMLQuantizationType.Q5_0,
|
||||
gguf.GGMLQuantizationType.Q5_1,
|
||||
gguf.GGMLQuantizationType.Q8_0,
|
||||
gguf.GGMLQuantizationType.Q8_1,
|
||||
}
|
||||
KQUANT_TYPES = {
|
||||
gguf.GGMLQuantizationType.Q2_K,
|
||||
gguf.GGMLQuantizationType.Q3_K,
|
||||
gguf.GGMLQuantizationType.Q4_K,
|
||||
gguf.GGMLQuantizationType.Q5_K,
|
||||
gguf.GGMLQuantizationType.Q6_K,
|
||||
}
|
||||
IMATRIX_QUANT_TYPES = {
|
||||
gguf.GGMLQuantizationType.IQ1_M,
|
||||
gguf.GGMLQuantizationType.IQ1_S,
|
||||
gguf.GGMLQuantizationType.IQ2_XXS,
|
||||
gguf.GGMLQuantizationType.IQ2_XS,
|
||||
gguf.GGMLQuantizationType.IQ2_S,
|
||||
gguf.GGMLQuantizationType.IQ3_XXS,
|
||||
gguf.GGMLQuantizationType.IQ3_S,
|
||||
gguf.GGMLQuantizationType.IQ4_XS,
|
||||
gguf.GGMLQuantizationType.IQ4_NL,
|
||||
}
|
||||
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
|
||||
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
|
||||
# MMQ kernel for I-Matrix quantization.
|
||||
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
|
||||
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
|
||||
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
|
||||
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
|
||||
# contiguous batching and inefficient with diffusers' batching,
|
||||
# so we disabled it now.
|
||||
|
||||
# elif qweight_type in MMVQ_QUANT_TYPES:
|
||||
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
# elif qweight_type in MMQ_QUANT_TYPES:
|
||||
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
|
||||
# If there is no available MMQ kernel, fallback to dequantize
|
||||
if qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
|
||||
y = x @ weight.to(x.dtype).T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
# Might be useful if llama.cpp adds a new quantization type.
|
||||
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
|
||||
qweight_type = gguf.GGMLQuantizationType(qweight_type)
|
||||
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
return y.as_tensor()
|
||||
|
||||
|
||||
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
|
||||
def _create_accelerate_new_hook(old_hook):
|
||||
r"""
|
||||
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, device)
|
||||
self.compute_dtype = compute_dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor):
|
||||
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
|
||||
return self.forward_cuda(inputs)
|
||||
return self.forward_native(inputs)
|
||||
|
||||
def forward_native(self, inputs: torch.Tensor):
|
||||
weight = dequantize_gguf_tensor(self.weight)
|
||||
weight = weight.to(self.compute_dtype)
|
||||
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
|
||||
|
||||
output = torch.nn.functional.linear(inputs, weight, bias)
|
||||
return output
|
||||
|
||||
def forward_cuda(self, inputs: torch.Tensor):
|
||||
quant_type = self.weight.quant_type
|
||||
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
|
||||
if self.bias is not None:
|
||||
output += self.bias.to(self.compute_dtype)
|
||||
return output
|
||||
|
||||
@@ -78,6 +78,7 @@ from .import_utils import (
|
||||
is_invisible_watermark_available,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_kernels_available,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
is_nltk_available,
|
||||
|
||||
@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
|
||||
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
|
||||
_transformers_available, _transformers_version = _is_package_available("transformers")
|
||||
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
|
||||
_kernels_available, _kernels_version = _is_package_available("kernels")
|
||||
_inflect_available, _inflect_version = _is_package_available("inflect")
|
||||
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
|
||||
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
|
||||
@@ -274,6 +275,10 @@ def is_accelerate_available():
|
||||
return _accelerate_available
|
||||
|
||||
|
||||
def is_kernels_available():
|
||||
return _kernels_available
|
||||
|
||||
|
||||
def is_k_diffusion_available():
|
||||
return _k_diffusion_available
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from .import_utils import (
|
||||
is_compel_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_kernels_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
@@ -629,6 +630,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_kernels_version_greater_or_equal(kernels_version):
|
||||
def decorator(test_case):
|
||||
correct_kernels_version = is_kernels_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("kernels")).base_version
|
||||
) >= version.parse(kernels_version)
|
||||
return unittest.skipUnless(
|
||||
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
@@ -29,19 +29,76 @@ from diffusers.utils.testing_utils import (
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate,
|
||||
require_accelerator,
|
||||
require_big_accelerator,
|
||||
require_gguf_version_greater_or_equal,
|
||||
require_kernels_version_greater_or_equal,
|
||||
require_peft_backend,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_gguf_available():
|
||||
import gguf
|
||||
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@nightly
|
||||
@require_accelerate
|
||||
@require_accelerator
|
||||
@require_gguf_version_greater_or_equal("0.10.0")
|
||||
@require_kernels_version_greater_or_equal("0.9.0")
|
||||
class GGUFCudaKernelsTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_cuda_kernels_vs_native(self):
|
||||
if torch_device != "cuda":
|
||||
self.skipTest("CUDA kernels test requires CUDA device")
|
||||
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
|
||||
|
||||
if not can_use_cuda_kernels:
|
||||
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
|
||||
|
||||
test_quant_types = ["Q4_0", "Q4_K"]
|
||||
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
|
||||
compute_dtype = torch.bfloat16
|
||||
|
||||
for quant_type in test_quant_types:
|
||||
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
|
||||
in_features, out_features = 512, 512
|
||||
|
||||
torch.manual_seed(42)
|
||||
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
|
||||
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
|
||||
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
|
||||
weight = GGUFParameter(weight_data, quant_type=qtype)
|
||||
|
||||
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
|
||||
|
||||
linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
|
||||
linear.weight = weight
|
||||
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
|
||||
linear = linear.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
output_native = linear.forward_native(x)
|
||||
output_cuda = linear.forward_cuda(x)
|
||||
|
||||
assert torch.allclose(output_native, output_cuda, 1e-2), (
|
||||
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
|
||||
)
|
||||
|
||||
|
||||
@nightly
|
||||
@require_big_accelerator
|
||||
@require_accelerate
|
||||
|
||||
Reference in New Issue
Block a user