Compare commits

...

11 Commits

Author SHA1 Message Date
Sayak Paul
8d7ef7f32c Merge branch 'main' into higgs 2025-08-06 20:18:44 +05:30
Aryan
cfd6ec7465 [refactor] condense group offloading (#11990)
* update

* update

* refactor

* add test

* address review comment

* nit
2025-08-06 20:01:02 +05:30
jiqing-feng
1082c46afa fix input shape for WanGGUFTexttoVideoSingleFileTests (#12081)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-08-06 14:12:40 +05:30
Isotr0py
ba2ba9019f Add cuda kernel support for GGUF inference (#11869)
* add gguf kernel support

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix

Signed-off-by: Isotr0py <2037008807@qq.com>

* optimize

Signed-off-by: Isotr0py <2037008807@qq.com>

* update

* update

* update

* update

* update

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DN6 <dhruv.nair@gmail.com>
2025-08-05 21:36:48 +05:30
C
fa4c0e5e2e optimize QwenImagePipeline to reduce unnecessary CUDA synchronization (#12072) 2025-08-05 04:12:47 -10:00
Sayak Paul
b793debd9d [tests] deal with the failing AudioLDM2 tests (#12069)
up
2025-08-05 15:54:25 +05:30
Sayak Paul
644bc18cc3 Merge branch 'main' into higgs 2025-08-01 08:14:41 +05:30
sayakpaul
34f0ef37cb updates 2025-06-25 13:15:47 +05:30
sayakpaul
c312812eae updates 2025-06-24 17:50:31 +05:30
sayakpaul
f82de3339e updates 2025-06-24 16:25:20 +05:30
sayakpaul
ea6c364485 start higgs 2025-06-24 16:06:13 +05:30
18 changed files with 1332 additions and 121 deletions

View File

@@ -333,7 +333,7 @@ jobs:
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
additional_deps: ["peft"]
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []

View File

@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
## Using Optimized CUDA Kernels with GGUF
Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
```shell
pip install -U kernels
```
Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
## Supported Quantization Types
- BF16

View File

@@ -95,7 +95,7 @@ class ModuleGroup:
self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
if self.offload_to_disk_path is not None:
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ class ModuleGroup:
else:
self.cpu_param_dict = self._init_cpu_param_dict()
self._torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -138,112 +144,76 @@ class ModuleGroup:
@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict
finally:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
def _transfer_tensor_to_device(self, tensor, source_tensor):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream and current_stream is not None:
tensor.data.record_stream(current_stream)
if self.record_stream:
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
def _process_tensors_from_modules(self, pinned_memory=None):
for group_module in self.modules:
for param in group_module.parameters():
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in group_module.buffers():
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
self._transfer_tensor_to_device(buffer, source)
for param in self.parameters:
source = pinned_memory[param] if pinned_memory else param.data
self._transfer_tensor_to_device(param, source, current_stream)
self._transfer_tensor_to_device(param, source)
for buffer in self.buffers:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, current_stream)
def _onload_from_disk(self, current_stream):
if self.stream is not None:
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
with self._pinned_memory_tensors() as pinned_memory:
for key, tensor_obj in self.key_to_tensor.items():
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
self.cpu_param_dict.clear()
else:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self, current_stream):
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory, current_stream)
else:
self._process_tensors_from_modules(None, current_stream)
@torch.compiler.disable()
def onload_(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.offload_to_disk_path:
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
# Load directly to the target device (synchronous)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
self._transfer_tensor_to_device(buffer, source)
def _onload_from_disk(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
with context:
if self.offload_to_disk_path:
self._onload_from_disk(current_stream)
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
if self.stream is not None:
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
else:
self._onload_from_memory(current_stream)
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
def _onload_from_memory(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
with context:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
self._process_tensors_from_modules(pinned_memory)
else:
self._process_tensors_from_modules(None)
def _offload_to_disk(self):
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
@@ -264,14 +234,10 @@ class ModuleGroup:
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch_accelerator_module.current_stream().synchronize()
self._torch_accelerator_module.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ class ModuleGroup:
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):
r"""Onloads the group of parameters to the onload_device."""
if self.offload_to_disk_path is not None:
self._onload_from_disk()
else:
self._onload_from_memory()
@torch.compiler.disable()
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
r"""Offloads the group of parameters to the offload_device."""
if self.offload_to_disk_path:
self._offload_to_disk()
else:
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
self.group = group
self.next_group = next_group
self.next_group: Optional[ModuleGroup] = None
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -459,8 +431,8 @@ class LayerExecutionTrackerHook(ModelHook):
def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
onload_device: Union[str, torch.device],
offload_device: Union[str, torch.device] = torch.device("cpu"),
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
```
"""
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
offload_type = GroupOffloadingType(offload_type)
stream = None
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None, config=config)
_apply_group_offloading_hook(group_module, group, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
_apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
_apply_group_offloading_hook(parent_module, group, config=config)
if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group, config=config)
hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()

View File

@@ -30,6 +30,7 @@ from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError
from ..quantizers import DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
@@ -231,6 +232,7 @@ def load_model_dict_into_meta(
"""
is_quantized = hf_quantizer is not None
is_higgs = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HIGGS
empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
@@ -280,7 +282,8 @@ def load_model_dict_into_meta(
# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
# higgs quants repack the weights so they will have different shapes
if empty_state_dict[param_name].shape != param.shape and not is_higgs:
if (
is_quantized
and hf_quantizer.pre_quantized
@@ -304,7 +307,7 @@ def load_model_dict_into_meta(
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
)
else:
elif hf_quantizer is not None:
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
return offload_index, state_dict_index

View File

@@ -312,15 +312,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.0.dev0"):
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
cache_position_kwargs["model_kwargs"] = model_kwargs
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)

View File

@@ -636,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -654,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
@@ -668,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]

View File

@@ -21,9 +21,11 @@ from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .higgs import HiggsQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
HiggsConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@ AUTO_QUANTIZER_MAPPING = {
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"higgs": HiggsQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"higgs": HiggsConfig,
}

View File

@@ -12,15 +12,15 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
import inspect
import os
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available
if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
can_use_cuda_kernels = (
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 7
)
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,
# so we disabled it now.
# elif qweight_type in MMVQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
if qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.to(x.dtype).T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y.as_tensor()
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs):
def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output += self.bias.to(self.compute_dtype)
return output

View File

@@ -0,0 +1 @@
from .higgs_quantizer import HiggsQuantizer

View File

@@ -0,0 +1,205 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/quantizers/quantizer_higgs.py
"""
from typing import TYPE_CHECKING, Any, Optional
from ...utils import get_module_from_name
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
from ...utils import is_accelerate_available, is_torch_available, logging
from ...utils.logging import tqdm
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class HiggsQuantizer(DiffusersQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of
full-precision models.
"""
requires_calibration = False
requires_parameters_quantization = True
required_packages = ["flute-kernel", "fast_hadamard_transform"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config
def validate_environment(self, device_map, **kwargs):
if not torch.cuda.is_available():
raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
if not is_accelerate_available():
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
# TODO: enable this.
# if not is_flute_available():
# raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
# if not is_hadamard_available():
# raise ImportError(
# "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
# )
if device_map is None:
raise ValueError(
"You are attempting to load a HIGGS model without setting device_map."
" Please set device_map comprised of 'cuda' devices."
)
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
torch_dtype = torch.float16
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
raise ValueError(
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
)
return torch_dtype
def create_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: dict[str, Any],
unexpected_keys: Optional[list[str]] = None,
):
from .utils import quantize_with_higgs
"""
Quantizes weights into weight and weight_scale
"""
flute_dict = quantize_with_higgs(
param_value.to(target_device),
self.quantization_config.bits,
self.quantization_config.p,
self.quantization_config.group_size,
self.quantization_config.hadamard_size,
)
del param_value
module, _ = get_module_from_name(model, param_name)
module_name = ".".join(param_name.split(".")[:-1])
for key, value in flute_dict.items():
if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value)
elif key == "tune_metadata":
module.tune_metadata = value
self.quantization_config.tune_metadata[module_name] = value.to_dict()
else:
raise ValueError(f"Unexpected key {key} in module {module}")
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
from .utils import HiggsLinear
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
def should_update(key: str) -> bool:
if key.endswith(".weight") or key.endswith(".bias"):
return False
full_key = f"{prefix}.{key}"
return any(name in key or name in full_key for name in higgs_names)
return [key for key in missing_keys if not should_update(key)]
@property
def is_trainable(self):
return False
def is_serializable(self):
return True
def check_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
param_name: str,
state_dict: dict[str, Any],
**kwargs,
) -> bool:
from .utils import HiggsLinear
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
# Only quantize weights of HiggsLinear modules that are not already quantized
return True
else:
return False
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
**kwargs,
):
from .utils import replace_with_higgs_linear
replace_with_higgs_linear(model, quantization_config=self.quantization_config)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
from flute.tune import TuneMetaData, maybe_tune_and_repack
from flute.utils import make_workspace_streamk
from .utils import HiggsLinear
flute_workspaces = {}
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
module.workspace = flute_workspaces[module.weight.device]
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
weight=module.weight.data,
scales=module.scales.data,
metadata=module.tune_metadata,
)
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
def _dequantize(self, model):
from .utils import dequantize_higgs
model = dequantize_higgs(model)
return model

View File

@@ -0,0 +1,690 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file.
Taken from:
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/integrations/higgs.py
"""
from math import sqrt
from ...utils import (
# TODO enable:
# is_flute_available,
# is_hadamard_available,
is_torch_available,
logging,
)
if is_torch_available():
import torch
from torch import nn
# if is_flute_available():
# if is_hadamard_available():
from fast_hadamard_transform import hadamard_transform
from flute.integrations.higgs import prepare_data_transposed
from flute.tune import TuneMetaData, qgemm_v2
logger = logging.get_logger(__name__)
def pad_to_block(tensor, dims, had_block_size, value=0):
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
for dim in dims:
size = tensor.shape[dim]
next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size
delta = next_multiple_of_1024 - size
pad_dims[-2 * dim - 1] = delta
return nn.functional.pad(tensor, pad_dims, "constant", value)
def get_higgs_grid(p: int, n: int):
if (p, n) == (2, 256):
return torch.tensor(
[
[-2.501467704772949, 0.17954708635807037],
[-0.6761789321899414, 1.2728623151779175],
[-1.8025816679000854, 0.7613157629966736],
[-0.538287878036499, -2.6028504371643066],
[0.8415029644966125, -0.8600977659225464],
[0.7023013234138489, 3.3138747215270996],
[0.5699077844619751, 2.5782253742218018],
[3.292393207550049, -0.6016128063201904],
[0.5561617016792297, -1.7723814249038696],
[-2.1012380123138428, 0.020958125591278076],
[0.46085724234580994, 0.8428705334663391],
[1.4548040628433228, -0.6156039237976074],
[3.210029363632202, 0.3546904921531677],
[0.8893890976905823, -0.5967988967895508],
[0.8618854284286499, -3.2061192989349365],
[1.1360996961593628, -0.23852407932281494],
[1.6646337509155273, -0.9265465140342712],
[1.4767773151397705, 1.2476022243499756],
[-1.0511897802352905, 1.94503915309906],
[-1.56318998336792, -0.3264186680316925],
[-0.1829211413860321, 0.2922491431236267],
[-0.8950616717338562, -1.3887052536010742],
[-0.08206957578659058, -1.329533576965332],
[-0.487422913312912, 1.4817842245101929],
[-1.6769757270812988, -2.8269758224487305],
[-1.5057679414749146, 1.8905963897705078],
[1.8335362672805786, 1.0515104532241821],
[0.3273945450782776, 1.0491033792495728],
[-3.295924186706543, -0.7021600008010864],
[-1.8428784608840942, -1.2315762042999268],
[-0.8575026392936707, -1.7005949020385742],
[-1.120667815208435, 0.6467998027801514],
[-0.1588846743106842, -1.804071068763733],
[-0.8539647459983826, 0.5645008683204651],
[-1.4192019701004028, -0.6175029873847961],
[1.0799058675765991, 1.7871345281600952],
[1.171311855316162, 0.7511613965034485],
[2.162078380584717, 0.8044339418411255],
[1.3969420194625854, -1.243762493133545],
[-0.23818807303905487, 0.053944624960422516],
[2.304199457168579, -1.2667627334594727],
[1.4225027561187744, 0.568610668182373],
[0.376836895942688, -0.7134661674499512],
[2.0404467582702637, 0.4087389409542084],
[0.7639489769935608, -1.1367933750152588],
[0.3622530400753021, -1.4827953577041626],
[0.4100743532180786, 0.36108437180519104],
[-1.5867475271224976, -1.618212342262268],
[-2.2769672870635986, -1.2132309675216675],
[0.9184022545814514, -0.34428009390830994],
[-0.3902314603328705, 0.21785245835781097],
[3.120687484741211, 1.3077973127365112],
[1.587440848350525, -1.6506884098052979],
[-1.718808889389038, -0.038405973464250565],
[-0.6888407468795776, -0.8402308821678162],
[-0.7981445789337158, -1.1117373704910278],
[-2.4124443531036377, 1.3419722318649292],
[-0.6611530184745789, 0.9939885139465332],
[-0.33103418350219727, -0.16702833771705627],
[-2.4091389179229736, -2.326857566833496],
[1.6610108613967896, -2.159703254699707],
[0.014884627424180508, 0.3887578248977661],
[0.029668325558304787, 1.8786455392837524],
[1.180362582206726, 2.699317216873169],
[1.821286678314209, -0.5960053205490112],
[-0.44835323095321655, 3.327436685562134],
[-0.3714401423931122, -2.1466753482818604],
[-1.1103475093841553, -2.4536871910095215],
[-0.39110705256462097, 0.6670510172843933],
[0.474752813577652, -1.1959707736968994],
[-0.013110585510730743, -2.52519154548645],
[-2.0836575031280518, -1.703289270401001],
[-1.1077687740325928, -0.1252644956111908],
[-0.4138077199459076, 1.1837692260742188],
[-1.977599024772644, 1.688241720199585],
[-1.659559965133667, -2.1387736797332764],
[0.03242531046271324, 0.6526556015014648],
[0.9127950072288513, 0.6099498867988586],
[-0.38478314876556396, 0.433487206697464],
[0.27454206347465515, -0.27719801664352417],
[0.10388526320457458, 2.2812814712524414],
[-0.014394169673323631, -3.177137613296509],
[-1.2871228456497192, -0.8961855173110962],
[0.5720916986465454, -0.921597957611084],
[1.1159656047821045, -0.7609877586364746],
[2.4383342266082764, -2.2983546257019043],
[-0.294057160615921, -0.9770799875259399],
[-0.9342701435089111, 1.107579231262207],
[-1.549338698387146, 3.090520143508911],
[2.6076579093933105, 2.051239013671875],
[-0.9259037375450134, 1.407211184501648],
[-0.1747353971004486, 0.540488600730896],
[-0.8963701725006104, 0.8271111249923706],
[0.6480194926261902, 1.0128909349441528],
[0.980783998966217, -0.06156221032142639],
[-0.16883476078510284, 1.0601658821105957],
[0.5839992761611938, 0.004697148688137531],
[-0.34228450059890747, -1.2423977851867676],
[2.500824451446533, 0.3665279746055603],
[-0.17641609907150269, 1.3529551029205322],
[0.05378641560673714, 2.817232847213745],
[-1.2391047477722168, 2.354328155517578],
[0.630434513092041, -0.668536365032196],
[1.7576488256454468, 0.6738647818565369],
[0.4435231387615204, 0.6000469326972961],
[-0.08794835954904556, -0.11511358618736267],
[1.6540337800979614, 0.33995017409324646],
[-0.04202975332736969, -0.5375117063522339],
[-0.4247745871543884, -0.7897617220878601],
[0.06695003807544708, 1.2000739574432373],
[-3.2508881092071533, 0.28734830021858215],
[-1.613816261291504, 0.4944162368774414],
[1.3598989248275757, 0.26117825508117676],
[2.308382511138916, 1.3462618589401245],
[-1.2137469053268433, -1.9254342317581177],
[-0.4889402985572815, 1.8136259317398071],
[-0.1870335340499878, -0.3480615019798279],
[1.0766386985778809, -1.0627082586288452],
[0.4651014506816864, 2.131748914718628],
[-0.1306295394897461, -0.7811847925186157],
[0.06433182954788208, -1.5397958755493164],
[-0.2894323468208313, -0.5789554715156555],
[-0.6081662178039551, 0.4845278263092041],
[2.697964668273926, -0.18515698611736298],
[0.1277363896369934, -0.7221432328224182],
[0.8700758218765259, 0.35042452812194824],
[0.22088994085788727, 0.495242178440094],
[-2.5843818187713623, -0.8000828623771667],
[0.6732649803161621, -1.4362232685089111],
[-1.5286413431167603, 1.0417330265045166],
[-1.1222513914108276, -0.6269875764846802],
[-0.9752035140991211, -0.8750635385513306],
[-2.6369473934173584, 0.6918523907661438],
[0.14478731155395508, -0.041986867785453796],
[-1.5629483461380005, 1.4369450807571411],
[0.38952457904815674, -2.16428804397583],
[-0.16885095834732056, 0.7976621985435486],
[-3.12416934967041, 1.256506085395813],
[0.6843105554580688, -0.4203019142150879],
[1.9345275163650513, 1.934950351715088],
[0.012184220366179943, -2.1080918312072754],
[-0.6350273489952087, 0.7358828186988831],
[-0.837304949760437, -0.6214472651481628],
[0.08211923390626907, -0.9472538232803345],
[2.9332995414733887, -1.4956780672073364],
[1.3806978464126587, -0.2916182279586792],
[0.06773144006729126, 0.9285762310028076],
[-1.1943119764328003, 1.5963770151138306],
[1.6395620107650757, -0.32285431027412415],
[-1.390851378440857, -0.08273141086101532],
[1.816330909729004, -1.2812227010726929],
[0.7921574711799622, -2.1135804653167725],
[0.5817914605140686, 1.2644577026367188],
[1.929347038269043, -0.2386285960674286],
[0.8877345323562622, 1.190008521080017],
[1.4732073545455933, 0.8935023546218872],
[-2.8518524169921875, -1.5478795766830444],
[0.2439267635345459, 0.7576767802238464],
[0.5246709585189819, -2.606659412384033],
[1.150876760482788, 1.4073830842971802],
[-0.2643202245235443, 2.0634236335754395],
[1.555483341217041, -0.0023102816194295883],
[2.0830578804016113, -1.7225427627563477],
[-0.5424830317497253, -1.070199728012085],
[0.9168899655342102, 0.8955540060997009],
[-0.8120972514152527, 2.696739912033081],
[-0.29908373951911926, -1.5310651063919067],
[1.2320337295532227, -1.556247353553772],
[1.8612544536590576, 0.08704725652933121],
[0.22133447229862213, -1.8091708421707153],
[-0.4403655230998993, -0.38571012020111084],
[-1.88539457321167, 1.192205786705017],
[2.239687919616699, 0.004709010478109121],
[1.139495611190796, 0.45733731985092163],
[-1.507995367050171, 0.19716016948223114],
[0.46986445784568787, 1.5422041416168213],
[-1.2573751211166382, -0.35984551906585693],
[-1.7415345907211304, -0.6020717024803162],
[1.0751984119415283, 0.19006384909152985],
[2.24186635017395, -0.46343153715133667],
[0.3610347509384155, -0.07658443599939346],
[-1.3111497163772583, 0.432013601064682],
[0.6164408326148987, 0.24538464844226837],
[-1.9266542196273804, -0.3256155550479889],
[-0.5870336890220642, -0.1879584938287735],
[-1.0476511716842651, 0.3677721917629242],
[-1.229940414428711, 1.2433830499649048],
[0.18550436198711395, 0.22753673791885376],
[-0.017921989783644676, 0.12625974416732788],
[1.1659504175186157, -0.5020995736122131],
[-0.5983408093452454, -1.40438973903656],
[0.7519024014472961, -0.16282692551612854],
[0.9920787811279297, -1.344896912574768],
[-0.8103678226470947, 0.3064485788345337],
[0.6956969499588013, 1.8208192586898804],
[-2.7830491065979004, -0.2299390584230423],
[-0.34681546688079834, 2.4890666007995605],
[-1.4452646970748901, -1.2216600179672241],
[-2.1872897148132324, 0.8926076292991638],
[1.706072211265564, -2.8440372943878174],
[1.1119003295898438, -2.4923460483551025],
[-2.582794666290283, 2.0973289012908936],
[0.04987720400094986, -0.2964983284473419],
[-2.063807487487793, -0.7847916483879089],
[-0.4068813621997833, 0.9135897755622864],
[-0.9814359545707703, -0.3874954879283905],
[-1.4227229356765747, 0.7337291240692139],
[0.3065044581890106, 1.3125417232513428],
[1.2160996198654175, -1.9643305540084839],
[-1.2163853645324707, 0.14608727395534515],
[-2.3030710220336914, -0.37558120489120483],
[0.9232977628707886, 2.1843791007995605],
[-0.1989777386188507, 1.651851773262024],
[-0.714374840259552, -0.39365994930267334],
[-0.7805715799331665, -2.099881887435913],
[0.9015759229660034, -1.7053706645965576],
[0.1033422127366066, 1.5256654024124146],
[-1.8773194551467896, 2.324174165725708],
[1.9227174520492554, 2.7441604137420654],
[-0.5994020104408264, 0.23984014987945557],
[1.3496100902557373, -0.9126054644584656],
[-0.8765304088592529, -3.1877026557922363],
[-1.2040035724639893, -1.5169521570205688],
[1.4261796474456787, 2.150200128555298],
[1.463774561882019, 1.6656692028045654],
[0.20364105701446533, -0.4988172650337219],
[0.5195154547691345, -0.24067887663841248],
[-1.1116786003112793, -1.1599653959274292],
[-0.8490808606147766, -0.1681060940027237],
[0.3189965784549713, -0.9641751646995544],
[-0.5664751529693604, -0.5951744318008423],
[-1.6347930431365967, -0.9137664437294006],
[0.44048091769218445, -0.47259435057640076],
[-2.147747039794922, 0.47442489862442017],
[1.834734320640564, 1.4462147951126099],
[1.1777573823928833, 1.0659226179122925],
[-0.9568989872932434, 0.09495053440332413],
[-1.838529348373413, 0.2950586676597595],
[-0.4800611734390259, 0.014894310384988785],
[-0.5235516428947449, -1.7687653303146362],
[2.0735011100769043, -0.8825281262397766],
[2.637502431869507, 0.8455678224563599],
[2.606602907180786, -0.7848446369171143],
[-1.1886937618255615, 0.9330510497093201],
[0.38082656264305115, 0.13328030705451965],
[0.6847941875457764, 0.7384101152420044],
[1.2638574838638306, -0.007309418171644211],
[0.18292222917079926, -1.22371244430542],
[0.8143821954727173, 1.4976691007614136],
[0.6571850776672363, 0.48368802666664124],
[-0.6991601586341858, 2.150190830230713],
[0.8101756572723389, 0.10206498205661774],
[-0.08768226951360703, -1.084917664527893],
[-0.7208092212677002, 0.03657956421375275],
[0.3211449086666107, 1.803687334060669],
[-0.7835946083068848, 1.6869111061096191],
]
)
if (p, n) == (2, 64):
return torch.tensor(
[
[-2.7216711044311523, 0.14431366324424744],
[-0.766914427280426, 1.7193410396575928],
[-2.2575762271881104, 1.2476624250411987],
[1.233758807182312, -2.3560616970062256],
[0.8701965808868408, -0.2649352252483368],
[1.4506438970565796, 2.1776366233825684],
[-0.06305818259716034, 1.9049758911132812],
[2.536226511001587, 0.563927412033081],
[0.4599496126174927, -1.8745561838150024],
[-1.900517225265503, -0.30703988671302795],
[0.09386251866817474, 0.8755807280540466],
[1.946500539779663, -0.6743080615997314],
[2.1338934898376465, 1.4581491947174072],
[0.9429940581321716, -0.8038390278816223],
[2.0697755813598633, -1.614896535873413],
[0.772676408290863, 0.22017823159694672],
[1.0689979791641235, -1.525044322013855],
[0.6813604831695557, 1.1345642805099487],
[0.4706456661224365, 2.606626272201538],
[-1.294018030166626, -0.4372096061706543],
[-0.09134224057197571, 0.4610418677330017],
[-0.7907772064208984, -0.48412787914276123],
[0.060459110885858536, -0.9172890186309814],
[-0.5855047702789307, 2.56172513961792],
[0.11484206467866898, -2.659848213195801],
[-1.5893300771713257, 2.188580274581909],
[1.6750942468643188, 0.7089915871620178],
[-0.445697546005249, 0.7452405095100403],
[-1.8539940118789673, -1.8377939462661743],
[-1.5791912078857422, -1.017285943031311],
[-1.030419945716858, -1.5746369361877441],
[-1.9511750936508179, 0.43696075677871704],
[-0.3446580767631531, -1.8953213691711426],
[-1.4219647645950317, 0.7676230669021606],
[-0.9191089272499084, 0.5021472573280334],
[0.20464491844177246, 1.3684605360031128],
[0.5402919054031372, 0.6699410676956177],
[1.8903915882110596, 0.03638288006186485],
[0.4723062515258789, -0.6216739416122437],
[-0.41345009207725525, -0.22752176225185394],
[2.7119064331054688, -0.5111885070800781],
[1.065286636352539, 0.6950305700302124],
[0.40629103779792786, -0.14339995384216309],
[1.2815024852752686, 0.17108257114887238],
[0.01785222627222538, -0.43778058886528015],
[0.054590027779340744, -1.4225547313690186],
[0.3076786696910858, 0.30697619915008545],
[-0.9498570561408997, -0.9576997756958008],
[-2.4640724658966064, -0.9660449028015137],
[1.3714425563812256, -0.39760473370552063],
[-0.4857747256755829, 0.2386789172887802],
[1.2797833681106567, 1.3097363710403442],
[0.5508887767791748, -1.1777795553207397],
[-1.384316325187683, 0.1465839296579361],
[-0.46556955575942993, -1.2442727088928223],
[-0.3915477693080902, -0.7319604158401489],
[-1.4005504846572876, 1.3890998363494873],
[-0.8647305965423584, 1.0617644786834717],
[-0.8901953101158142, -0.01650036871433258],
[-0.9893633723258972, -2.4662880897521973],
[1.445534110069275, -1.049334168434143],
[-0.041650623083114624, 0.012734669260680676],
[-0.3302375078201294, 1.26217782497406],
[0.6934980154037476, 1.7714335918426514],
]
)
elif (p, n) == (2, 16):
return torch.tensor(
[
[-0.8996632695198059, -1.6360418796539307],
[-0.961183488368988, 1.5999565124511719],
[-1.882026195526123, 0.678778350353241],
[0.36300793290138245, -1.9667866230010986],
[-0.6814072728157043, -0.576818585395813],
[0.7270012497901917, 0.6186859607696533],
[0.3359416127204895, 1.8371193408966064],
[1.859930396080017, 0.036668598651885986],
[0.17208248376846313, -0.9401724338531494],
[-1.7599700689315796, -0.6244229674339294],
[-0.8993809223175049, 0.32267823815345764],
[0.839488685131073, -0.3017036020755768],
[1.5314953327178955, 1.2942044734954834],
[-0.0011779458727687597, 0.00022069070837460458],
[1.4274526834487915, -1.207889199256897],
[-0.16123905777931213, 0.8787511587142944],
]
)
elif (p, n) == (1, 16):
return torch.tensor(
[
[-2.7325894832611084],
[-2.069017171859741],
[-1.6180464029312134],
[-1.2562311887741089],
[-0.9423404335975647],
[-0.6567591428756714],
[-0.38804829120635986],
[-0.12839503586292267],
[0.12839503586292267],
[0.38804829120635986],
[0.6567591428756714],
[0.9423404335975647],
[1.2562311887741089],
[1.6180464029312134],
[2.069017171859741],
[2.7325894832611084],
]
)
elif (p, n) == (1, 8):
return torch.tensor(
[
[-2.1519455909729004],
[-1.3439092636108398],
[-0.7560052871704102],
[-0.2450941801071167],
[0.2450941801071167],
[0.7560052871704102],
[1.3439092636108398],
[2.1519455909729004],
]
)
elif (p, n) == (1, 4):
return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
else:
raise NotImplementedError(f"Unsupported p={p}, n={n}")
def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
assert len(weight.shape) == 2, "Only 2D weights are supported for now"
grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2
device = weight.device
dtype = weight.dtype
weight = weight.to(copy=True, dtype=torch.float32)
# Pad to Hadamard transform size
weight = pad_to_block(weight, [1], hadamard_size)
# Scale and Hadamard transform
mult = weight.shape[1] // hadamard_size
weight = weight.reshape(-1, mult, hadamard_size)
scales = torch.linalg.norm(weight, axis=-1)
weight = hadamard_transform(weight, 1) / scales[:, :, None]
# Pad to edenn_d and project
weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)
# Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
for i in range(0, weight.shape[0], 16):
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight
codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size)
weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype),
num_bits=bits,
group_size=group_size,
vector_size=p,
dtype=dtype,
device=device,
check_correctness=False,
)
return {
"weight": weight,
"scales": scales,
"tables": tables,
"tables2": tables2.view(dtype=torch.float16),
"tune_metadata": tune_metadata,
}
class HiggsLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
num_bits: int,
bias=True,
dtype: torch.dtype = None,
device: torch.device = None,
group_size: int = 256,
hadamard_size: int = 1024,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_bits = num_bits
self.group_size = group_size
self.hadamard_size = hadamard_size
assert in_features % group_size == 0
assert num_bits in [2, 3, 4]
self.weight = nn.Parameter(
torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
requires_grad=False,
)
self.scales = nn.Parameter(
torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
)
self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
self.tables2 = nn.Parameter(
torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
)
if bias:
self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
else:
self.register_parameter("bias", None)
self.workspace = None # must be set externally to be reused among layers
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent
def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size)
if self.workspace is None:
raise Exception("Workspace must be set before calling forward")
return qgemm_v2(
x,
self.weight,
self.scales,
self.tables,
self.tables2.view(dtype=torch.float32),
self.workspace,
self.tune_metadata,
hadamard_size=self.hadamard_size,
)
def _replace_with_higgs_linear(
model, quantization_config=None, current_key_name=None, modules_to_not_convert=None, has_been_replaced=False
):
from accelerate import init_empty_weights
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear):
current_key_name_str = ".".join(current_key_name)
if not any(current_key_name_str.endswith(key) for key in modules_to_not_convert):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
# Original size is [3072, 4096]. But after `HiggsLinear`, this is
# [768, 4096]. 🤯
if name == "context_embedder":
print(f"{in_features=}, {out_features=}")
model._modules[name] = HiggsLinear(
in_features,
out_features,
bias=module.bias is not None,
num_bits=quantization_config.bits,
hadamard_size=quantization_config.hadamard_size,
group_size=quantization_config.group_size,
)
if name == "context_embedder":
print(model._modules[name].weight.shape)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_higgs_linear(
module,
quantization_config=quantization_config,
current_key_name=current_key_name,
modules_to_not_convert=modules_to_not_convert,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_higgs_linear(
model,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`HiggsConfig`):
The quantization config object that contains the quantization parameters.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
modules_to_not_convert = quantization_config.modules_to_not_convert or []
model, _ = _replace_with_higgs_linear(
model, quantization_config, current_key_name, modules_to_not_convert, has_been_replaced
)
has_been_replaced = any(isinstance(replaced_module, HiggsLinear) for _, replaced_module in model.named_modules())
if not has_been_replaced:
logger.warning(
"You are loading your model in Higgs but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model
def dequantize_higgs(model, current_key_name=None):
"""
Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.
Args:
model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
current_key_name (list, optional):
A list to keep track of the current module names during recursion. Defaults to None.
Returns:
torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
"""
with torch.no_grad():
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, HiggsLinear):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = torch.nn.Linear(
in_features,
out_features,
bias=module.bias is not None,
device=module.scales.device,
dtype=module.scales.dtype,
)
model._modules[name].weight.data = module(
torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
).T.contiguous()
if len(list(module.children())) > 0:
_ = dequantize_higgs(
module,
current_key_name=current_key_name,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model

View File

@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
HIGGS = "higgs"
if is_torchao_available():
@@ -724,3 +725,62 @@ class QuantoConfig(QuantizationConfigMixin):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
@dataclass
class HiggsConfig(QuantizationConfigMixin):
"""
HiggsConfig is a configuration class for quantization using the HIGGS method.
Args:
bits (int, *optional*, defaults to 4):
Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4.
p (int, *optional*, defaults to 2):
Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2.
modules_to_not_convert (`list`, *optional*, default to ["lm_head"]):
List of linear layers that should not be quantized.
hadamard_size (int, *optional*, defaults to 512):
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value.
Decreasing this below 512 will reduce the quality of the quantization.
group_size (int, *optional*, defaults to 256):
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance.
Default is 256. Must be a divisor of hadamard_size.
tune_metadata ('dict', *optional*, defaults to {}):
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default
is an empty dictionary. Is set automatically during tuning.
"""
def __init__(
self,
bits: int = 4,
p: int = 2,
modules_to_not_convert: Optional[list[str]] = None,
hadamard_size: int = 512,
group_size: int = 256,
tune_metadata: Optional[dict[str, Any]] = None,
**kwargs,
):
if tune_metadata is None:
tune_metadata = {}
self.quant_method = QuantizationMethod.HIGGS
self.bits = bits
self.p = p
self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size
self.group_size = group_size
self.tune_metadata = tune_metadata
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if self.bits not in [2, 3, 4]:
raise ValueError("bits must be 2, 3, or 4")
if self.p not in [1, 2]:
raise ValueError("p must be 1 or 2. 2 is always better in practice")
if self.group_size not in [64, 128, 256]:
raise ValueError("group_size must be 64, 128, or 256")
if self.hadamard_size % self.group_size != 0:
raise ValueError("hadamard_size must be divisible by group_size")

View File

@@ -81,6 +81,7 @@ from .import_utils import (
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,

View File

@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -277,6 +278,10 @@ def is_accelerate_available():
return _accelerate_available
def is_kernels_available():
return _kernels_available
def is_k_diffusion_available():
return _k_diffusion_available

View File

@@ -36,6 +36,7 @@ from .import_utils import (
is_compel_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

View File

@@ -17,7 +17,9 @@ import gc
import unittest
import torch
from parameterized import parameterized
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
return x
# Test for https://github.com/huggingface/diffusers/pull/12077
class DummyModelWithLayerNorm(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.layer_norm(x)
x = self.linear_2(x)
return x
class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
return x
class LayerOutputTrackerHook(ModelHook):
def __init__(self):
super().__init__()
self.outputs = []
def post_forward(self, module, output):
self.outputs.append(output)
return output
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithMultipleBlocks(
in_features=self.in_features,
hidden_features=self.hidden_features,
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):
with context:
model(self.input)
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
for name, module in model.named_modules():
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerOutputTrackerHook()
registry.register_hook(hook, "layer_output_tracker")
model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
model = DummyModelWithLayerNorm(128, 256, 128, 2)
model.load_state_dict(model_ref.state_dict(), strict=True)
model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model)
x = torch.randn(2, 128).to(torch_device)
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
num_repeats = 4
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
ref_outputs = (
HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
)
outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
cumulated_absmax = 0.0
for i in range(len(outputs)):
diff = ref_outputs[0] - outputs[i]
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)

View File

@@ -45,6 +45,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_transformers_version
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -220,6 +221,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -312,7 +318,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -371,6 +376,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(audio_1 - audio_2).max() < 1e-2
@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

View File

@@ -30,8 +30,10 @@ from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_accelerator,
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_kernels_version_greater_or_equal,
require_peft_backend,
require_torch_version_greater,
torch_device,
@@ -41,11 +43,66 @@ from ..test_torch_compile_utils import QuantCompileTests
if is_gguf_available():
import gguf
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
enable_full_determinism()
@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
def setUp(self):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
def test_cuda_kernels_vs_native(self):
if torch_device != "cuda":
self.skipTest("CUDA kernels test requires CUDA device")
from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
if not can_use_cuda_kernels:
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
test_quant_types = ["Q4_0", "Q4_K"]
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
compute_dtype = torch.bfloat16
for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
in_features, out_features = 512, 512
torch.manual_seed(42)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
linear.weight = weight
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
linear = linear.to(torch_device)
with torch.no_grad():
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly
@require_big_accelerator
@require_accelerate
@@ -593,7 +650,7 @@ class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.Test
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(