mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-10 06:24:19 +08:00

Compare commits: modular-di ... higgs

11 Commits:
- 8d7ef7f32c
- cfd6ec7465
- 1082c46afa
- ba2ba9019f
- fa4c0e5e2e
- b793debd9d
- 644bc18cc3
- 34f0ef37cb
- c312812eae
- f82de3339e
- ea6c364485

.github/workflows/nightly_tests.yml (vendored, 2 changes)
@@ -333,7 +333,7 @@ jobs:
           additional_deps: ["peft"]
         - backend: "gguf"
           test_location: "gguf"
-          additional_deps: ["peft"]
+          additional_deps: ["peft", "kernels"]
         - backend: "torchao"
           test_location: "torchao"
           additional_deps: []
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
 image.save("flux-gguf.png")
 ```
 
+## Using Optimized CUDA Kernels with GGUF
+
+Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability()[0] >= 7` and the kernels library:
+
+```shell
+pip install -U kernels
+```
+
+Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, which can cause subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
+
 ## Supported Quantization Types
 
 - BF16
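To make the new flag concrete, here is a minimal sketch of how it slots into the existing GGUF loading flow from this documentation page. The checkpoint repository and file name are illustrative assumptions, and the environment variable is read when the GGUF utilities are imported (see the module-level check later in this diff), so it is set before importing diffusers.

```py
import os

# Assumption: DIFFUSERS_GGUF_CUDA_KERNELS is read at import time of the GGUF utilities,
# so set it before importing diffusers.
os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "true"

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative GGUF checkpoint; any Flux single-file GGUF checkpoint works the same way.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_K_S.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```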
@@ -95,7 +95,7 @@ class ModuleGroup:
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
@@ -115,6 +115,12 @@ class ModuleGroup:
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
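As a side note, the `_torch_accelerator_module` attribute added above follows the device-agnostic accelerator pattern used throughout these hooks. A minimal standalone sketch of that fallback (assuming a recent PyTorch release that exposes the `torch.accelerator` API, otherwise CUDA) looks like this:

```py
import torch

# Prefer the device-agnostic torch.accelerator API when present; fall back to torch.cuda.
if hasattr(torch, "accelerator"):
    accelerator_module = getattr(torch, torch.accelerator.current_accelerator().type)
else:
    accelerator_module = torch.cuda

# The resolved module exposes the stream utilities the offloading hooks rely on,
# e.g. a side stream for async host-to-device copies and the current stream.
side_stream = accelerator_module.Stream()
current_stream = accelerator_module.current_stream()
```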
@@ -138,112 +144,76 @@ class ModuleGroup:
|
||||
|
||||
@contextmanager
|
||||
def _pinned_memory_tensors(self):
|
||||
pinned_dict = {}
|
||||
try:
|
||||
for param, tensor in self.cpu_param_dict.items():
|
||||
if not tensor.is_pinned():
|
||||
pinned_dict[param] = tensor.pin_memory()
|
||||
else:
|
||||
pinned_dict[param] = tensor
|
||||
|
||||
pinned_dict = {
|
||||
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
|
||||
for param, tensor in self.cpu_param_dict.items()
|
||||
}
|
||||
yield pinned_dict
|
||||
|
||||
finally:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor):
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream and current_stream is not None:
|
||||
tensor.data.record_stream(current_stream)
|
||||
if self.record_stream:
|
||||
tensor.data.record_stream(self._torch_accelerator_module.current_stream())
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
|
||||
def _process_tensors_from_modules(self, pinned_memory=None):
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
source = pinned_memory[param] if pinned_memory else param.data
|
||||
self._transfer_tensor_to_device(param, source, current_stream)
|
||||
self._transfer_tensor_to_device(param, source)
|
||||
for buffer in group_module.buffers():
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, current_stream)
|
||||
self._transfer_tensor_to_device(buffer, source)
|
||||
|
||||
for param in self.parameters:
|
||||
source = pinned_memory[param] if pinned_memory else param.data
|
||||
self._transfer_tensor_to_device(param, source, current_stream)
|
||||
self._transfer_tensor_to_device(param, source)
|
||||
|
||||
for buffer in self.buffers:
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, current_stream)
|
||||
|
||||
def _onload_from_disk(self, current_stream):
|
||||
if self.stream is not None:
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
|
||||
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
|
||||
|
||||
self.cpu_param_dict.clear()
|
||||
|
||||
else:
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
|
||||
def _onload_from_memory(self, current_stream):
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
self._process_tensors_from_modules(pinned_memory, current_stream)
|
||||
else:
|
||||
self._process_tensors_from_modules(None, current_stream)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
else torch.cuda
|
||||
)
|
||||
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
|
||||
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
|
||||
|
||||
if self.offload_to_disk_path:
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
if self.stream is not None:
|
||||
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
|
||||
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
tensor_obj.data.record_stream(current_stream)
|
||||
else:
|
||||
# Load directly to the target device (synchronous)
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
return
|
||||
self._transfer_tensor_to_device(buffer, source)
|
||||
|
||||
def _onload_from_disk(self):
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
||||
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
|
||||
|
||||
with context:
|
||||
if self.offload_to_disk_path:
|
||||
self._onload_from_disk(current_stream)
|
||||
# Load to CPU (if using streams) or directly to target device, pin, and async copy to device
|
||||
device = str(self.onload_device) if self.stream is None else "cpu"
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
|
||||
|
||||
if self.stream is not None:
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
pinned_tensor = loaded_tensors[key].pin_memory()
|
||||
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
tensor_obj.data.record_stream(current_stream)
|
||||
else:
|
||||
self._onload_from_memory(current_stream)
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
|
||||
def _onload_from_memory(self):
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
|
||||
context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
|
||||
with context:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
self._process_tensors_from_modules(pinned_memory)
|
||||
else:
|
||||
self._process_tensors_from_modules(None)
|
||||
|
||||
def _offload_to_disk(self):
|
||||
# TODO: we can potentially optimize this code path by checking if _all_ the desired
|
||||
@@ -264,14 +234,10 @@ class ModuleGroup:
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
 
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ class ModuleGroup:
 
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
 
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
+
     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
        self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -459,8 +431,8 @@ class LayerExecutionTrackerHook(ModelHook):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
     ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)
 
     stream = None
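With the widened type hints above, the entry point now also accepts plain device strings and normalizes them to `torch.device` internally. A small sketch with a toy module (a real pipeline would pass its denoiser here):

```py
import torch
from diffusers.hooks import apply_group_offloading


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(1024, 1024) for _ in range(8)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()

# Strings are now accepted in place of torch.device objects.
apply_group_offloading(
    model,
    onload_device="cuda",
    offload_device="cpu",
    offload_type="block_level",
    num_blocks_per_group=2,
)

with torch.no_grad():
    out = model(torch.randn(1, 1024, device="cuda"))
```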
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
         group_id=f"{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)
 
     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
             onload_self=True,
             group_id=_GROUP_ID_LAZY_LEAF,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -30,6 +30,7 @@ from huggingface_hub import DDUFEntry
 from huggingface_hub.utils import EntryNotFoundError
 
 from ..quantizers import DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -231,6 +232,7 @@ def load_model_dict_into_meta(
     """
 
     is_quantized = hf_quantizer is not None
+    is_higgs = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HIGGS
     empty_state_dict = model.state_dict()
 
     for param_name, param in state_dict.items():
@@ -280,7 +282,8 @@ def load_model_dict_into_meta(
 
         # bnb params are flattened.
         # gguf quants have a different shape based on the type of quantization applied
-        if empty_state_dict[param_name].shape != param.shape:
+        # higgs quants repack the weights so they will have different shapes
+        if empty_state_dict[param_name].shape != param.shape and not is_higgs:
             if (
                 is_quantized
                 and hf_quantizer.pre_quantized
@@ -304,7 +307,7 @@ def load_model_dict_into_meta(
                 hf_quantizer.create_quantized_param(
                     model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
                 )
-            else:
+            elif hf_quantizer is not None:
                 set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
 
     return offload_index, state_dict_index
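For intuition on why the shape check is skipped for HIGGS parameters: the repacked layout used by `HiggsLinear` later in this diff stores a 4-bit weight as `int16` with shape `(out_features * num_bits // 16, in_features)`. A quick sanity check with assumed layer sizes:

```py
import torch

# Assumed layer sizes for illustration only.
out_features, in_features, num_bits = 3072, 4096, 4

# Repacked HIGGS weight, as allocated by HiggsLinear in this diff.
packed = torch.empty(out_features * num_bits // 16, in_features, dtype=torch.int16)

print(packed.shape)  # torch.Size([768, 4096]), which no longer matches the original (3072, 4096)
```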
@@ -312,15 +312,14 @@ class AudioLDM2Pipeline(DiffusionPipeline):
                 The sequence of generated hidden-states.
         """
         cache_position_kwargs = {}
-        if is_transformers_version("<", "4.52.0.dev0"):
+        if is_transformers_version("<", "4.52.1"):
             cache_position_kwargs["input_ids"] = inputs_embeds
-            cache_position_kwargs["model_kwargs"] = model_kwargs
         else:
             cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
             cache_position_kwargs["device"] = (
                 self.language_model.device if getattr(self, "language_model", None) is not None else self.device
             )
-            cache_position_kwargs["model_kwargs"] = model_kwargs
+        cache_position_kwargs["model_kwargs"] = model_kwargs
         max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
         model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
@@ -636,6 +636,11 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if self.attention_kwargs is None:
             self._attention_kwargs = {}
 
+        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+        negative_txt_seq_lens = (
+            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+        )
+
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -654,7 +659,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                     encoder_hidden_states_mask=prompt_embeds_mask,
                     encoder_hidden_states=prompt_embeds,
                     img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+                    txt_seq_lens=txt_seq_lens,
                     attention_kwargs=self.attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -668,7 +673,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                     encoder_hidden_states_mask=negative_prompt_embeds_mask,
                     encoder_hidden_states=negative_prompt_embeds,
                     img_shapes=img_shapes,
-                    txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+                    txt_seq_lens=negative_txt_seq_lens,
                     attention_kwargs=self.attention_kwargs,
                     return_dict=False,
                 )[0]
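The hoisted `txt_seq_lens` computation is just a per-prompt count of valid tokens in the attention mask, now evaluated once instead of on every denoising step. A toy example of what the transformer receives:

```py
import torch

# Toy mask for two prompts of different lengths (1 = real token, 0 = padding).
prompt_embeds_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])

txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)  # [3, 5]
```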
@@ -21,9 +21,11 @@ from typing import Dict, Optional, Union
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
 from .gguf import GGUFQuantizer
+from .higgs import HiggsQuantizer
 from .quantization_config import (
     BitsAndBytesConfig,
     GGUFQuantizationConfig,
+    HiggsConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
     QuantoConfig,
@@ -39,6 +41,7 @@ AUTO_QUANTIZER_MAPPING = {
     "gguf": GGUFQuantizer,
     "quanto": QuantoQuantizer,
     "torchao": TorchAoHfQuantizer,
+    "higgs": HiggsQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "gguf": GGUFQuantizationConfig,
     "quanto": QuantoConfig,
     "torchao": TorchAoConfig,
+    "higgs": HiggsConfig,
 }
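For context, these two mappings are what the auto-quantizer machinery consults to turn a serialized quantization config into a quantizer instance. A simplified, illustrative version of that dispatch (not the diffusers implementation itself) looks like this:

```py
def resolve_quantizer(quantization_config_dict: dict):
    # Look up the config and quantizer classes registered above by quant_method,
    # e.g. "higgs", "gguf", "torchao".
    method = quantization_config_dict["quant_method"]
    config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[method]
    quantizer_cls = AUTO_QUANTIZER_MAPPING[method]

    config = config_cls.from_dict(quantization_config_dict)
    return quantizer_cls(config)
```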
@@ -12,15 +12,15 @@
 # # See the License for the specific language governing permissions and
 # # limitations under the License.
 
 
 import inspect
+import os
 from contextlib import nullcontext
 
 import gguf
 import torch
 import torch.nn as nn
 
-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, is_kernels_available
 
 
 if is_accelerate_available():
@@ -29,6 +29,82 @@ if is_accelerate_available():
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 
 
+can_use_cuda_kernels = (
+    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 7
+)
+if can_use_cuda_kernels and is_kernels_available():
+    from kernels import get_kernel
+
+    ops = get_kernel("Isotr0py/ggml")
+else:
+    ops = None
+
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
+STANDARD_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
+}
+KQUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+    # contiguous batching and inefficient with diffusers' batching,
+    # so we disabled it now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # If there is no available MMQ kernel, fallback to dequantize
+    if qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+        y = x @ weight.to(x.dtype).T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+    return y.as_tensor()
+
+
 # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
 def _create_accelerate_new_hook(old_hook):
     r"""
@@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear):
     ) -> None:
         super().__init__(in_features, out_features, bias, device)
         self.compute_dtype = compute_dtype
+        self.device = device
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
+        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+            return self.forward_cuda(inputs)
+        return self.forward_native(inputs)
+
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
+
+    def forward_cuda(self, inputs: torch.Tensor):
+        quant_type = self.weight.quant_type
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output += self.bias.to(self.compute_dtype)
+        return output
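As a sanity check on the dequantize fallback above, the shape arithmetic recovers the logical weight size from the packed byte layout; `gguf.GGML_QUANT_SIZES` maps each quantization type to its `(block_size, type_size)` pair:

```py
import gguf
import torch

qweight_type = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]  # 256 weights packed into 144 bytes

# Assumed layer size for illustration: each row of the packed tensor holds
# in_features // block_size blocks of type_size bytes each.
out_features, in_features = 3072, 4096
qweight = torch.empty(out_features, in_features // block_size * type_size, dtype=torch.uint8)

# Same arithmetic as _fused_mul_mat_gguf: recover (out_features, in_features).
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
print(shape)  # (3072, 4096)
```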
src/diffusers/quantizers/higgs/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .higgs_quantizer import HiggsQuantizer
src/diffusers/quantizers/higgs/higgs_quantizer.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/quantizers/quantizer_higgs.py
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...utils import get_module_from_name
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
from ...utils import is_accelerate_available, is_torch_available, logging
|
||||
from ...utils.logging import tqdm
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class HiggsQuantizer(DiffusersQuantizer):
|
||||
"""
|
||||
Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of
|
||||
full-precision models.
|
||||
"""
|
||||
|
||||
requires_calibration = False
|
||||
requires_parameters_quantization = True
|
||||
required_packages = ["flute-kernel", "fast_hadamard_transform"]
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def validate_environment(self, device_map, **kwargs):
|
||||
if not torch.cuda.is_available():
|
||||
raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
|
||||
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
|
||||
|
||||
# TODO: enable this.
|
||||
# if not is_flute_available():
|
||||
# raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
|
||||
|
||||
# if not is_hadamard_available():
|
||||
# raise ImportError(
|
||||
# "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
|
||||
# )
|
||||
|
||||
if device_map is None:
|
||||
raise ValueError(
|
||||
"You are attempting to load a HIGGS model without setting device_map."
|
||||
" Please set device_map comprised of 'cuda' devices."
|
||||
)
|
||||
elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
|
||||
raise ValueError(
|
||||
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
|
||||
" This is not supported. Please remove the CPU or disk device from the device_map."
|
||||
)
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
logger.info("`torch_dtype` is None. Setting `torch_dtype=torch.float16` for FLUTE compatibility.")
|
||||
torch_dtype = torch.float16
|
||||
elif torch_dtype != torch.float16 and torch_dtype != torch.bfloat16:
|
||||
raise ValueError(
|
||||
f"Invalid `torch_dtype` {torch_dtype}. HIGGS quantization only supports `torch_dtype=torch.float16` or `torch_dtype=torch.bfloat16`."
|
||||
)
|
||||
|
||||
return torch_dtype
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: dict[str, Any],
|
||||
unexpected_keys: Optional[list[str]] = None,
|
||||
):
|
||||
from .utils import quantize_with_higgs
|
||||
|
||||
"""
|
||||
Quantizes weights into weight and weight_scale
|
||||
"""
|
||||
flute_dict = quantize_with_higgs(
|
||||
param_value.to(target_device),
|
||||
self.quantization_config.bits,
|
||||
self.quantization_config.p,
|
||||
self.quantization_config.group_size,
|
||||
self.quantization_config.hadamard_size,
|
||||
)
|
||||
del param_value
|
||||
|
||||
module, _ = get_module_from_name(model, param_name)
|
||||
module_name = ".".join(param_name.split(".")[:-1])
|
||||
for key, value in flute_dict.items():
|
||||
if key in module._parameters:
|
||||
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
|
||||
elif key in module._buffers:
|
||||
module._buffers[key] = torch.nn.Buffer(value)
|
||||
elif key == "tune_metadata":
|
||||
module.tune_metadata = value
|
||||
self.quantization_config.tune_metadata[module_name] = value.to_dict()
|
||||
else:
|
||||
raise ValueError(f"Unexpected key {key} in module {module}")
|
||||
|
||||
if unexpected_keys is not None and param_name in unexpected_keys:
|
||||
unexpected_keys.remove(param_name)
|
||||
|
||||
def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
|
||||
from .utils import HiggsLinear
|
||||
|
||||
higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
|
||||
|
||||
def should_update(key: str) -> bool:
|
||||
if key.endswith(".weight") or key.endswith(".bias"):
|
||||
return False
|
||||
full_key = f"{prefix}.{key}"
|
||||
return any(name in key or name in full_key for name in higgs_names)
|
||||
|
||||
return [key for key in missing_keys if not should_update(key)]
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
return False
|
||||
|
||||
def is_serializable(self):
|
||||
return True
|
||||
|
||||
def check_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
from .utils import HiggsLinear
|
||||
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
|
||||
# Only quantize weights of HiggsLinear modules that are not already quantized
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
**kwargs,
|
||||
):
|
||||
from .utils import replace_with_higgs_linear
|
||||
|
||||
replace_with_higgs_linear(model, quantization_config=self.quantization_config)
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
|
||||
from flute.tune import TuneMetaData, maybe_tune_and_repack
|
||||
from flute.utils import make_workspace_streamk
|
||||
|
||||
from .utils import HiggsLinear
|
||||
|
||||
flute_workspaces = {}
|
||||
flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
|
||||
for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
|
||||
# Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
|
||||
# This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
|
||||
if module.weight.device not in flute_workspaces:
|
||||
flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
|
||||
module.workspace = flute_workspaces[module.weight.device]
|
||||
|
||||
# FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
|
||||
# If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
|
||||
module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
|
||||
module.weight.data, module.tune_metadata = maybe_tune_and_repack(
|
||||
weight=module.weight.data,
|
||||
scales=module.scales.data,
|
||||
metadata=module.tune_metadata,
|
||||
)
|
||||
self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
|
||||
|
||||
def _dequantize(self, model):
|
||||
from .utils import dequantize_higgs
|
||||
|
||||
model = dequantize_higgs(model)
|
||||
return model
|
||||
src/diffusers/quantizers/higgs/utils.py (new file, 690 lines)
@@ -0,0 +1,690 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file.
|
||||
|
||||
Taken from:
|
||||
https://github.com/huggingface/transformers/blob/d3d835d4fc145e5062d2153ac23ccd4b3e2c2cbd/src/transformers/integrations/higgs.py
|
||||
"""
|
||||
|
||||
from math import sqrt
|
||||
|
||||
from ...utils import (
|
||||
# TODO enable:
|
||||
# is_flute_available,
|
||||
# is_hadamard_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# if is_flute_available():
|
||||
# if is_hadamard_available():
|
||||
from fast_hadamard_transform import hadamard_transform
|
||||
from flute.integrations.higgs import prepare_data_transposed
|
||||
from flute.tune import TuneMetaData, qgemm_v2
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def pad_to_block(tensor, dims, had_block_size, value=0):
|
||||
pad_dims = [0 for _ in range(2 * len(tensor.shape))]
|
||||
for dim in dims:
|
||||
size = tensor.shape[dim]
|
||||
next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size
|
||||
delta = next_multiple_of_1024 - size
|
||||
pad_dims[-2 * dim - 1] = delta
|
||||
|
||||
return nn.functional.pad(tensor, pad_dims, "constant", value)
|
||||
|
||||
|
||||
def get_higgs_grid(p: int, n: int):
|
||||
if (p, n) == (2, 256):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.501467704772949, 0.17954708635807037],
|
||||
[-0.6761789321899414, 1.2728623151779175],
|
||||
[-1.8025816679000854, 0.7613157629966736],
|
||||
[-0.538287878036499, -2.6028504371643066],
|
||||
[0.8415029644966125, -0.8600977659225464],
|
||||
[0.7023013234138489, 3.3138747215270996],
|
||||
[0.5699077844619751, 2.5782253742218018],
|
||||
[3.292393207550049, -0.6016128063201904],
|
||||
[0.5561617016792297, -1.7723814249038696],
|
||||
[-2.1012380123138428, 0.020958125591278076],
|
||||
[0.46085724234580994, 0.8428705334663391],
|
||||
[1.4548040628433228, -0.6156039237976074],
|
||||
[3.210029363632202, 0.3546904921531677],
|
||||
[0.8893890976905823, -0.5967988967895508],
|
||||
[0.8618854284286499, -3.2061192989349365],
|
||||
[1.1360996961593628, -0.23852407932281494],
|
||||
[1.6646337509155273, -0.9265465140342712],
|
||||
[1.4767773151397705, 1.2476022243499756],
|
||||
[-1.0511897802352905, 1.94503915309906],
|
||||
[-1.56318998336792, -0.3264186680316925],
|
||||
[-0.1829211413860321, 0.2922491431236267],
|
||||
[-0.8950616717338562, -1.3887052536010742],
|
||||
[-0.08206957578659058, -1.329533576965332],
|
||||
[-0.487422913312912, 1.4817842245101929],
|
||||
[-1.6769757270812988, -2.8269758224487305],
|
||||
[-1.5057679414749146, 1.8905963897705078],
|
||||
[1.8335362672805786, 1.0515104532241821],
|
||||
[0.3273945450782776, 1.0491033792495728],
|
||||
[-3.295924186706543, -0.7021600008010864],
|
||||
[-1.8428784608840942, -1.2315762042999268],
|
||||
[-0.8575026392936707, -1.7005949020385742],
|
||||
[-1.120667815208435, 0.6467998027801514],
|
||||
[-0.1588846743106842, -1.804071068763733],
|
||||
[-0.8539647459983826, 0.5645008683204651],
|
||||
[-1.4192019701004028, -0.6175029873847961],
|
||||
[1.0799058675765991, 1.7871345281600952],
|
||||
[1.171311855316162, 0.7511613965034485],
|
||||
[2.162078380584717, 0.8044339418411255],
|
||||
[1.3969420194625854, -1.243762493133545],
|
||||
[-0.23818807303905487, 0.053944624960422516],
|
||||
[2.304199457168579, -1.2667627334594727],
|
||||
[1.4225027561187744, 0.568610668182373],
|
||||
[0.376836895942688, -0.7134661674499512],
|
||||
[2.0404467582702637, 0.4087389409542084],
|
||||
[0.7639489769935608, -1.1367933750152588],
|
||||
[0.3622530400753021, -1.4827953577041626],
|
||||
[0.4100743532180786, 0.36108437180519104],
|
||||
[-1.5867475271224976, -1.618212342262268],
|
||||
[-2.2769672870635986, -1.2132309675216675],
|
||||
[0.9184022545814514, -0.34428009390830994],
|
||||
[-0.3902314603328705, 0.21785245835781097],
|
||||
[3.120687484741211, 1.3077973127365112],
|
||||
[1.587440848350525, -1.6506884098052979],
|
||||
[-1.718808889389038, -0.038405973464250565],
|
||||
[-0.6888407468795776, -0.8402308821678162],
|
||||
[-0.7981445789337158, -1.1117373704910278],
|
||||
[-2.4124443531036377, 1.3419722318649292],
|
||||
[-0.6611530184745789, 0.9939885139465332],
|
||||
[-0.33103418350219727, -0.16702833771705627],
|
||||
[-2.4091389179229736, -2.326857566833496],
|
||||
[1.6610108613967896, -2.159703254699707],
|
||||
[0.014884627424180508, 0.3887578248977661],
|
||||
[0.029668325558304787, 1.8786455392837524],
|
||||
[1.180362582206726, 2.699317216873169],
|
||||
[1.821286678314209, -0.5960053205490112],
|
||||
[-0.44835323095321655, 3.327436685562134],
|
||||
[-0.3714401423931122, -2.1466753482818604],
|
||||
[-1.1103475093841553, -2.4536871910095215],
|
||||
[-0.39110705256462097, 0.6670510172843933],
|
||||
[0.474752813577652, -1.1959707736968994],
|
||||
[-0.013110585510730743, -2.52519154548645],
|
||||
[-2.0836575031280518, -1.703289270401001],
|
||||
[-1.1077687740325928, -0.1252644956111908],
|
||||
[-0.4138077199459076, 1.1837692260742188],
|
||||
[-1.977599024772644, 1.688241720199585],
|
||||
[-1.659559965133667, -2.1387736797332764],
|
||||
[0.03242531046271324, 0.6526556015014648],
|
||||
[0.9127950072288513, 0.6099498867988586],
|
||||
[-0.38478314876556396, 0.433487206697464],
|
||||
[0.27454206347465515, -0.27719801664352417],
|
||||
[0.10388526320457458, 2.2812814712524414],
|
||||
[-0.014394169673323631, -3.177137613296509],
|
||||
[-1.2871228456497192, -0.8961855173110962],
|
||||
[0.5720916986465454, -0.921597957611084],
|
||||
[1.1159656047821045, -0.7609877586364746],
|
||||
[2.4383342266082764, -2.2983546257019043],
|
||||
[-0.294057160615921, -0.9770799875259399],
|
||||
[-0.9342701435089111, 1.107579231262207],
|
||||
[-1.549338698387146, 3.090520143508911],
|
||||
[2.6076579093933105, 2.051239013671875],
|
||||
[-0.9259037375450134, 1.407211184501648],
|
||||
[-0.1747353971004486, 0.540488600730896],
|
||||
[-0.8963701725006104, 0.8271111249923706],
|
||||
[0.6480194926261902, 1.0128909349441528],
|
||||
[0.980783998966217, -0.06156221032142639],
|
||||
[-0.16883476078510284, 1.0601658821105957],
|
||||
[0.5839992761611938, 0.004697148688137531],
|
||||
[-0.34228450059890747, -1.2423977851867676],
|
||||
[2.500824451446533, 0.3665279746055603],
|
||||
[-0.17641609907150269, 1.3529551029205322],
|
||||
[0.05378641560673714, 2.817232847213745],
|
||||
[-1.2391047477722168, 2.354328155517578],
|
||||
[0.630434513092041, -0.668536365032196],
|
||||
[1.7576488256454468, 0.6738647818565369],
|
||||
[0.4435231387615204, 0.6000469326972961],
|
||||
[-0.08794835954904556, -0.11511358618736267],
|
||||
[1.6540337800979614, 0.33995017409324646],
|
||||
[-0.04202975332736969, -0.5375117063522339],
|
||||
[-0.4247745871543884, -0.7897617220878601],
|
||||
[0.06695003807544708, 1.2000739574432373],
|
||||
[-3.2508881092071533, 0.28734830021858215],
|
||||
[-1.613816261291504, 0.4944162368774414],
|
||||
[1.3598989248275757, 0.26117825508117676],
|
||||
[2.308382511138916, 1.3462618589401245],
|
||||
[-1.2137469053268433, -1.9254342317581177],
|
||||
[-0.4889402985572815, 1.8136259317398071],
|
||||
[-0.1870335340499878, -0.3480615019798279],
|
||||
[1.0766386985778809, -1.0627082586288452],
|
||||
[0.4651014506816864, 2.131748914718628],
|
||||
[-0.1306295394897461, -0.7811847925186157],
|
||||
[0.06433182954788208, -1.5397958755493164],
|
||||
[-0.2894323468208313, -0.5789554715156555],
|
||||
[-0.6081662178039551, 0.4845278263092041],
|
||||
[2.697964668273926, -0.18515698611736298],
|
||||
[0.1277363896369934, -0.7221432328224182],
|
||||
[0.8700758218765259, 0.35042452812194824],
|
||||
[0.22088994085788727, 0.495242178440094],
|
||||
[-2.5843818187713623, -0.8000828623771667],
|
||||
[0.6732649803161621, -1.4362232685089111],
|
||||
[-1.5286413431167603, 1.0417330265045166],
|
||||
[-1.1222513914108276, -0.6269875764846802],
|
||||
[-0.9752035140991211, -0.8750635385513306],
|
||||
[-2.6369473934173584, 0.6918523907661438],
|
||||
[0.14478731155395508, -0.041986867785453796],
|
||||
[-1.5629483461380005, 1.4369450807571411],
|
||||
[0.38952457904815674, -2.16428804397583],
|
||||
[-0.16885095834732056, 0.7976621985435486],
|
||||
[-3.12416934967041, 1.256506085395813],
|
||||
[0.6843105554580688, -0.4203019142150879],
|
||||
[1.9345275163650513, 1.934950351715088],
|
||||
[0.012184220366179943, -2.1080918312072754],
|
||||
[-0.6350273489952087, 0.7358828186988831],
|
||||
[-0.837304949760437, -0.6214472651481628],
|
||||
[0.08211923390626907, -0.9472538232803345],
|
||||
[2.9332995414733887, -1.4956780672073364],
|
||||
[1.3806978464126587, -0.2916182279586792],
|
||||
[0.06773144006729126, 0.9285762310028076],
|
||||
[-1.1943119764328003, 1.5963770151138306],
|
||||
[1.6395620107650757, -0.32285431027412415],
|
||||
[-1.390851378440857, -0.08273141086101532],
|
||||
[1.816330909729004, -1.2812227010726929],
|
||||
[0.7921574711799622, -2.1135804653167725],
|
||||
[0.5817914605140686, 1.2644577026367188],
|
||||
[1.929347038269043, -0.2386285960674286],
|
||||
[0.8877345323562622, 1.190008521080017],
|
||||
[1.4732073545455933, 0.8935023546218872],
|
||||
[-2.8518524169921875, -1.5478795766830444],
|
||||
[0.2439267635345459, 0.7576767802238464],
|
||||
[0.5246709585189819, -2.606659412384033],
|
||||
[1.150876760482788, 1.4073830842971802],
|
||||
[-0.2643202245235443, 2.0634236335754395],
|
||||
[1.555483341217041, -0.0023102816194295883],
|
||||
[2.0830578804016113, -1.7225427627563477],
|
||||
[-0.5424830317497253, -1.070199728012085],
|
||||
[0.9168899655342102, 0.8955540060997009],
|
||||
[-0.8120972514152527, 2.696739912033081],
|
||||
[-0.29908373951911926, -1.5310651063919067],
|
||||
[1.2320337295532227, -1.556247353553772],
|
||||
[1.8612544536590576, 0.08704725652933121],
|
||||
[0.22133447229862213, -1.8091708421707153],
|
||||
[-0.4403655230998993, -0.38571012020111084],
|
||||
[-1.88539457321167, 1.192205786705017],
|
||||
[2.239687919616699, 0.004709010478109121],
|
||||
[1.139495611190796, 0.45733731985092163],
|
||||
[-1.507995367050171, 0.19716016948223114],
|
||||
[0.46986445784568787, 1.5422041416168213],
|
||||
[-1.2573751211166382, -0.35984551906585693],
|
||||
[-1.7415345907211304, -0.6020717024803162],
|
||||
[1.0751984119415283, 0.19006384909152985],
|
||||
[2.24186635017395, -0.46343153715133667],
|
||||
[0.3610347509384155, -0.07658443599939346],
|
||||
[-1.3111497163772583, 0.432013601064682],
|
||||
[0.6164408326148987, 0.24538464844226837],
|
||||
[-1.9266542196273804, -0.3256155550479889],
|
||||
[-0.5870336890220642, -0.1879584938287735],
|
||||
[-1.0476511716842651, 0.3677721917629242],
|
||||
[-1.229940414428711, 1.2433830499649048],
|
||||
[0.18550436198711395, 0.22753673791885376],
|
||||
[-0.017921989783644676, 0.12625974416732788],
|
||||
[1.1659504175186157, -0.5020995736122131],
|
||||
[-0.5983408093452454, -1.40438973903656],
|
||||
[0.7519024014472961, -0.16282692551612854],
|
||||
[0.9920787811279297, -1.344896912574768],
|
||||
[-0.8103678226470947, 0.3064485788345337],
|
||||
[0.6956969499588013, 1.8208192586898804],
|
||||
[-2.7830491065979004, -0.2299390584230423],
|
||||
[-0.34681546688079834, 2.4890666007995605],
|
||||
[-1.4452646970748901, -1.2216600179672241],
|
||||
[-2.1872897148132324, 0.8926076292991638],
|
||||
[1.706072211265564, -2.8440372943878174],
|
||||
[1.1119003295898438, -2.4923460483551025],
|
||||
[-2.582794666290283, 2.0973289012908936],
|
||||
[0.04987720400094986, -0.2964983284473419],
|
||||
[-2.063807487487793, -0.7847916483879089],
|
||||
[-0.4068813621997833, 0.9135897755622864],
|
||||
[-0.9814359545707703, -0.3874954879283905],
|
||||
[-1.4227229356765747, 0.7337291240692139],
|
||||
[0.3065044581890106, 1.3125417232513428],
|
||||
[1.2160996198654175, -1.9643305540084839],
|
||||
[-1.2163853645324707, 0.14608727395534515],
|
||||
[-2.3030710220336914, -0.37558120489120483],
|
||||
[0.9232977628707886, 2.1843791007995605],
|
||||
[-0.1989777386188507, 1.651851773262024],
|
||||
[-0.714374840259552, -0.39365994930267334],
|
||||
[-0.7805715799331665, -2.099881887435913],
|
||||
[0.9015759229660034, -1.7053706645965576],
|
||||
[0.1033422127366066, 1.5256654024124146],
|
||||
[-1.8773194551467896, 2.324174165725708],
|
||||
[1.9227174520492554, 2.7441604137420654],
|
||||
[-0.5994020104408264, 0.23984014987945557],
|
||||
[1.3496100902557373, -0.9126054644584656],
|
||||
[-0.8765304088592529, -3.1877026557922363],
|
||||
[-1.2040035724639893, -1.5169521570205688],
|
||||
[1.4261796474456787, 2.150200128555298],
|
||||
[1.463774561882019, 1.6656692028045654],
|
||||
[0.20364105701446533, -0.4988172650337219],
|
||||
[0.5195154547691345, -0.24067887663841248],
|
||||
[-1.1116786003112793, -1.1599653959274292],
|
||||
[-0.8490808606147766, -0.1681060940027237],
|
||||
[0.3189965784549713, -0.9641751646995544],
|
||||
[-0.5664751529693604, -0.5951744318008423],
|
||||
[-1.6347930431365967, -0.9137664437294006],
|
||||
[0.44048091769218445, -0.47259435057640076],
|
||||
[-2.147747039794922, 0.47442489862442017],
|
||||
[1.834734320640564, 1.4462147951126099],
|
||||
[1.1777573823928833, 1.0659226179122925],
|
||||
[-0.9568989872932434, 0.09495053440332413],
|
||||
[-1.838529348373413, 0.2950586676597595],
|
||||
[-0.4800611734390259, 0.014894310384988785],
|
||||
[-0.5235516428947449, -1.7687653303146362],
|
||||
[2.0735011100769043, -0.8825281262397766],
|
||||
[2.637502431869507, 0.8455678224563599],
|
||||
[2.606602907180786, -0.7848446369171143],
|
||||
[-1.1886937618255615, 0.9330510497093201],
|
||||
[0.38082656264305115, 0.13328030705451965],
|
||||
[0.6847941875457764, 0.7384101152420044],
|
||||
[1.2638574838638306, -0.007309418171644211],
|
||||
[0.18292222917079926, -1.22371244430542],
|
||||
[0.8143821954727173, 1.4976691007614136],
|
||||
[0.6571850776672363, 0.48368802666664124],
|
||||
[-0.6991601586341858, 2.150190830230713],
|
||||
[0.8101756572723389, 0.10206498205661774],
|
||||
[-0.08768226951360703, -1.084917664527893],
|
||||
[-0.7208092212677002, 0.03657956421375275],
|
||||
[0.3211449086666107, 1.803687334060669],
|
||||
[-0.7835946083068848, 1.6869111061096191],
|
||||
]
|
||||
)
|
||||
if (p, n) == (2, 64):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.7216711044311523, 0.14431366324424744],
|
||||
[-0.766914427280426, 1.7193410396575928],
|
||||
[-2.2575762271881104, 1.2476624250411987],
|
||||
[1.233758807182312, -2.3560616970062256],
|
||||
[0.8701965808868408, -0.2649352252483368],
|
||||
[1.4506438970565796, 2.1776366233825684],
|
||||
[-0.06305818259716034, 1.9049758911132812],
|
||||
[2.536226511001587, 0.563927412033081],
|
||||
[0.4599496126174927, -1.8745561838150024],
|
||||
[-1.900517225265503, -0.30703988671302795],
|
||||
[0.09386251866817474, 0.8755807280540466],
|
||||
[1.946500539779663, -0.6743080615997314],
|
||||
[2.1338934898376465, 1.4581491947174072],
|
||||
[0.9429940581321716, -0.8038390278816223],
|
||||
[2.0697755813598633, -1.614896535873413],
|
||||
[0.772676408290863, 0.22017823159694672],
|
||||
[1.0689979791641235, -1.525044322013855],
|
||||
[0.6813604831695557, 1.1345642805099487],
|
||||
[0.4706456661224365, 2.606626272201538],
|
||||
[-1.294018030166626, -0.4372096061706543],
|
||||
[-0.09134224057197571, 0.4610418677330017],
|
||||
[-0.7907772064208984, -0.48412787914276123],
|
||||
[0.060459110885858536, -0.9172890186309814],
|
||||
[-0.5855047702789307, 2.56172513961792],
|
||||
[0.11484206467866898, -2.659848213195801],
|
||||
[-1.5893300771713257, 2.188580274581909],
|
||||
[1.6750942468643188, 0.7089915871620178],
|
||||
[-0.445697546005249, 0.7452405095100403],
|
||||
[-1.8539940118789673, -1.8377939462661743],
|
||||
[-1.5791912078857422, -1.017285943031311],
|
||||
[-1.030419945716858, -1.5746369361877441],
|
||||
[-1.9511750936508179, 0.43696075677871704],
|
||||
[-0.3446580767631531, -1.8953213691711426],
|
||||
[-1.4219647645950317, 0.7676230669021606],
|
||||
[-0.9191089272499084, 0.5021472573280334],
|
||||
[0.20464491844177246, 1.3684605360031128],
|
||||
[0.5402919054031372, 0.6699410676956177],
|
||||
[1.8903915882110596, 0.03638288006186485],
|
||||
[0.4723062515258789, -0.6216739416122437],
|
||||
[-0.41345009207725525, -0.22752176225185394],
|
||||
[2.7119064331054688, -0.5111885070800781],
|
||||
[1.065286636352539, 0.6950305700302124],
|
||||
[0.40629103779792786, -0.14339995384216309],
|
||||
[1.2815024852752686, 0.17108257114887238],
|
||||
[0.01785222627222538, -0.43778058886528015],
|
||||
[0.054590027779340744, -1.4225547313690186],
|
||||
[0.3076786696910858, 0.30697619915008545],
|
||||
[-0.9498570561408997, -0.9576997756958008],
|
||||
[-2.4640724658966064, -0.9660449028015137],
|
||||
[1.3714425563812256, -0.39760473370552063],
|
||||
[-0.4857747256755829, 0.2386789172887802],
|
||||
[1.2797833681106567, 1.3097363710403442],
|
||||
[0.5508887767791748, -1.1777795553207397],
|
||||
[-1.384316325187683, 0.1465839296579361],
|
||||
[-0.46556955575942993, -1.2442727088928223],
|
||||
[-0.3915477693080902, -0.7319604158401489],
|
||||
[-1.4005504846572876, 1.3890998363494873],
|
||||
[-0.8647305965423584, 1.0617644786834717],
|
||||
[-0.8901953101158142, -0.01650036871433258],
|
||||
[-0.9893633723258972, -2.4662880897521973],
|
||||
[1.445534110069275, -1.049334168434143],
|
||||
[-0.041650623083114624, 0.012734669260680676],
|
||||
[-0.3302375078201294, 1.26217782497406],
|
||||
[0.6934980154037476, 1.7714335918426514],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (2, 16):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-0.8996632695198059, -1.6360418796539307],
|
||||
[-0.961183488368988, 1.5999565124511719],
|
||||
[-1.882026195526123, 0.678778350353241],
|
||||
[0.36300793290138245, -1.9667866230010986],
|
||||
[-0.6814072728157043, -0.576818585395813],
|
||||
[0.7270012497901917, 0.6186859607696533],
|
||||
[0.3359416127204895, 1.8371193408966064],
|
||||
[1.859930396080017, 0.036668598651885986],
|
||||
[0.17208248376846313, -0.9401724338531494],
|
||||
[-1.7599700689315796, -0.6244229674339294],
|
||||
[-0.8993809223175049, 0.32267823815345764],
|
||||
[0.839488685131073, -0.3017036020755768],
|
||||
[1.5314953327178955, 1.2942044734954834],
|
||||
[-0.0011779458727687597, 0.00022069070837460458],
|
||||
[1.4274526834487915, -1.207889199256897],
|
||||
[-0.16123905777931213, 0.8787511587142944],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (1, 16):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.7325894832611084],
|
||||
[-2.069017171859741],
|
||||
[-1.6180464029312134],
|
||||
[-1.2562311887741089],
|
||||
[-0.9423404335975647],
|
||||
[-0.6567591428756714],
|
||||
[-0.38804829120635986],
|
||||
[-0.12839503586292267],
|
||||
[0.12839503586292267],
|
||||
[0.38804829120635986],
|
||||
[0.6567591428756714],
|
||||
[0.9423404335975647],
|
||||
[1.2562311887741089],
|
||||
[1.6180464029312134],
|
||||
[2.069017171859741],
|
||||
[2.7325894832611084],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (1, 8):
|
||||
return torch.tensor(
|
||||
[
|
||||
[-2.1519455909729004],
|
||||
[-1.3439092636108398],
|
||||
[-0.7560052871704102],
|
||||
[-0.2450941801071167],
|
||||
[0.2450941801071167],
|
||||
[0.7560052871704102],
|
||||
[1.3439092636108398],
|
||||
[2.1519455909729004],
|
||||
]
|
||||
)
|
||||
elif (p, n) == (1, 4):
|
||||
return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported p={p}, n={n}")
|
||||
|
||||
|
||||
def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024):
assert len(weight.shape) == 2, "Only 2D weights are supported for now"

grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device)
grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2

device = weight.device
dtype = weight.dtype
weight = weight.to(copy=True, dtype=torch.float32)
# Pad to Hadamard transform size
weight = pad_to_block(weight, [1], hadamard_size)

# Scale and Hadamard transform
mult = weight.shape[1] // hadamard_size
weight = weight.reshape(-1, mult, hadamard_size)
scales = torch.linalg.norm(weight, axis=-1)
weight = hadamard_transform(weight, 1) / scales[:, :, None]

# Pad to edenn_d and project
weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p)

# Quantize
codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8)
for i in range(0, weight.shape[0], 16):
codes[i : i + 16] = torch.argmax(2 * weight[i : i + 16] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8)
del weight

codes = codes.reshape(codes.shape[0], -1)
scales = scales / sqrt(hadamard_size)

weight, scales, tables, tables2, tune_metadata = prepare_data_transposed(
codes,
torch.repeat_interleave(scales.to(dtype), hadamard_size // group_size, dim=1),
grid.to(dtype),
num_bits=bits,
group_size=group_size,
vector_size=p,
dtype=dtype,
device=device,
check_correctness=False,
)

return {
"weight": weight,
"scales": scales,
"tables": tables,
"tables2": tables2.view(dtype=torch.float16),
"tune_metadata": tune_metadata,
}


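# Illustrative sketch (not part of this diff): quantizing a single 2-D weight with the helper above.
# It assumes a CUDA device and that the HIGGS kernel dependencies providing `pad_to_block`,
# `hadamard_transform` and `prepare_data_transposed` are importable; the key names follow the dict
# returned by `quantize_with_higgs`.
#
#   weight = torch.randn(3072, 4096, dtype=torch.bfloat16, device="cuda")
#   packed = quantize_with_higgs(weight, bits=4, p=2, group_size=256, hadamard_size=1024)
#   print(packed["weight"].shape, packed["scales"].shape, packed["tables"].shape)
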
class HiggsLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
num_bits: int,
bias=True,
dtype: torch.dtype = None,
device: torch.device = None,
group_size: int = 256,
hadamard_size: int = 1024,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_bits = num_bits
self.group_size = group_size
self.hadamard_size = hadamard_size

assert in_features % group_size == 0
assert num_bits in [2, 3, 4]

self.weight = nn.Parameter(
torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device),
requires_grad=False,
)
self.scales = nn.Parameter(
torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False
)
self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False)
self.tables2 = nn.Parameter(
torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False
)

if bias:
self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False)
else:
self.register_parameter("bias", None)

self.workspace = None # must be set externally to be reused among layers
self.tune_metadata: TuneMetaData = None # must be set externally because architecture dependent

def forward(self, x):
x = pad_to_block(x, [-1], self.hadamard_size)

if self.workspace is None:
raise Exception("Workspace must be set before calling forward")

return qgemm_v2(
x,
self.weight,
self.scales,
self.tables,
self.tables2.view(dtype=torch.float32),
self.workspace,
self.tune_metadata,
hadamard_size=self.hadamard_size,
)


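# Illustrative note (not part of this diff): the packed int16 weight above stores num_bits bits per
# output element, so its shape is (out_features * num_bits // 16, in_features). For a layer of size
# (3072, 4096) at num_bits=4 that is (3072 * 4 // 16, 4096) = (768, 4096), which is the shape the
# `context_embedder` debug prints in the helper below are checking.
#
#   assert (3072 * 4) // 16 == 768
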
def _replace_with_higgs_linear(
model, quantization_config=None, current_key_name=None, modules_to_not_convert=None, has_been_replaced=False
):
from accelerate import init_empty_weights

for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, nn.Linear):
current_key_name_str = ".".join(current_key_name)
if not any(current_key_name_str.endswith(key) for key in modules_to_not_convert):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
# Original size is [3072, 4096]. But after `HiggsLinear`, this is
# [768, 4096]. 🤯
if name == "context_embedder":
print(f"{in_features=}, {out_features=}")
model._modules[name] = HiggsLinear(
in_features,
out_features,
bias=module.bias is not None,
num_bits=quantization_config.bits,
hadamard_size=quantization_config.hadamard_size,
group_size=quantization_config.group_size,
)
if name == "context_embedder":
print(model._modules[name].weight.shape)
has_been_replaced = True

# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_higgs_linear(
module,
quantization_config=quantization_config,
current_key_name=current_key_name,
modules_to_not_convert=modules_to_not_convert,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)

return model, has_been_replaced


def replace_with_higgs_linear(
model,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers.
`accelerate` is needed to use this method. Returns the converted model.

Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`HiggsConfig`):
The quantization config object that contains the quantization parameters.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
modules_to_not_convert = quantization_config.modules_to_not_convert or []
model, _ = _replace_with_higgs_linear(
model, quantization_config, current_key_name, modules_to_not_convert, has_been_replaced
)

has_been_replaced = any(isinstance(replaced_module, HiggsLinear) for _, replaced_module in model.named_modules())
if not has_been_replaced:
logger.warning(
"You are loading your model in Higgs but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)

return model


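# Illustrative sketch (not part of this diff): converting a model's Linear layers in place. The
# config values below are just an example; any `torch.nn.Module` can be passed.
#
#   config = HiggsConfig(bits=4, p=2, group_size=256, hadamard_size=512)
#   model = replace_with_higgs_linear(model, quantization_config=config)
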
def dequantize_higgs(model, current_key_name=None):
"""
Dequantizes the HiggsLinear layers in the given model by replacing them with standard torch.nn.Linear layers.

Args:
model (torch.nn.Module): The model containing HiggsLinear layers to be dequantized.
current_key_name (list, optional):
A list to keep track of the current module names during recursion. Defaults to None.
Returns:
torch.nn.Module: The model with HiggsLinear layers replaced by torch.nn.Linear layers.
"""

with torch.no_grad():
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, HiggsLinear):
in_features = module.in_features
out_features = module.out_features

model._modules[name] = torch.nn.Linear(
in_features,
out_features,
bias=module.bias is not None,
device=module.scales.device,
dtype=module.scales.dtype,
)

model._modules[name].weight.data = module(
torch.eye(in_features, device=module.scales.device, dtype=module.scales.dtype)
).T.contiguous()

if len(list(module.children())) > 0:
_ = dequantize_higgs(
module,
current_key_name=current_key_name,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model
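
# Illustrative note (not part of this diff): `dequantize_higgs` recovers a dense weight by running
# each quantized layer on an identity matrix and transposing the result, then copies it into a fresh
# torch.nn.Linear. A hedged usage sketch:
#
#   model = dequantize_higgs(model)  # every HiggsLinear becomes a plain nn.Linear again
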
@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
HIGGS = "higgs"


if is_torchao_available():
@@ -724,3 +725,62 @@ class QuantoConfig(QuantizationConfigMixin):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")


@dataclass
class HiggsConfig(QuantizationConfigMixin):
"""
HiggsConfig is a configuration class for quantization using the HIGGS method.

Args:
bits (int, *optional*, defaults to 4):
Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4.
p (int, *optional*, defaults to 2):
Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2.
modules_to_not_convert (`list`, *optional*, defaults to ["lm_head"]):
List of linear layers that should not be quantized.
hadamard_size (int, *optional*, defaults to 512):
Hadamard size for the HIGGS method. Default is 512. Input dimension of matrices is padded to this value.
Decreasing this below 512 will reduce the quality of the quantization.
group_size (int, *optional*, defaults to 256):
Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance.
Default is 256. Must be a divisor of hadamard_size.
tune_metadata (`dict`, *optional*, defaults to {}):
Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results. Default
is an empty dictionary. Is set automatically during tuning.
"""

def __init__(
self,
bits: int = 4,
p: int = 2,
modules_to_not_convert: Optional[list[str]] = None,
hadamard_size: int = 512,
group_size: int = 256,
tune_metadata: Optional[dict[str, Any]] = None,
**kwargs,
):
if tune_metadata is None:
tune_metadata = {}
self.quant_method = QuantizationMethod.HIGGS
self.bits = bits
self.p = p
self.modules_to_not_convert = modules_to_not_convert
self.hadamard_size = hadamard_size
self.group_size = group_size
self.tune_metadata = tune_metadata

self.post_init()

def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if self.bits not in [2, 3, 4]:
raise ValueError("bits must be 2, 3, or 4")
if self.p not in [1, 2]:
raise ValueError("p must be 1 or 2. 2 is always better in practice")
if self.group_size not in [64, 128, 256]:
raise ValueError("group_size must be 64, 128, or 256")
if self.hadamard_size % self.group_size != 0:
raise ValueError("hadamard_size must be divisible by group_size")

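# Illustrative sketch (not part of this diff): the argument validation performed by `post_init`.
#
#   HiggsConfig(bits=4, group_size=256, hadamard_size=512)   # valid
#   HiggsConfig(bits=5)                                       # raises ValueError("bits must be 2, 3, or 4")
#   HiggsConfig(group_size=128, hadamard_size=100)            # raises: hadamard_size must be divisible by group_size
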
@@ -81,6 +81,7 @@ from .import_utils import (
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,

@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -277,6 +278,10 @@ def is_accelerate_available():
return _accelerate_available


def is_kernels_available():
return _kernels_available


def is_k_diffusion_available():
return _k_diffusion_available


@@ -36,6 +36,7 @@ from .import_utils import (
is_compel_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator


def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version)
return unittest.skipUnless(
correct_kernels_version, f"Test requires kernels with version greater than or equal to {kernels_version}."
)(test_case)

return decorator


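# Illustrative sketch (not part of this diff): gating a test on the `kernels` package, mirroring the
# existing require_* decorators in testing_utils.
#
#   @require_kernels_version_greater_or_equal("0.9.0")
#   def test_something_with_cuda_kernels(self):
#       ...
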
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

@@ -17,7 +17,9 @@ import gc
import unittest

import torch
from parameterized import parameterized

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
return x


# Test for https://github.com/huggingface/diffusers/pull/12077
class DummyModelWithLayerNorm(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()

self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.layer_norm(x)
x = self.linear_2(x)
return x


class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"

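# Note (not part of this diff): the standalone `layer_norm` above owns parameters that sit outside
# any child block, so block-level group offloading has to collect them into a parameter-only module
# group; the parameterized test added further below exercises exactly that path.
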
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
return x


class LayerOutputTrackerHook(ModelHook):
def __init__(self):
super().__init__()
self.outputs = []

def post_forward(self, module, output):
self.outputs.append(output)
return output


@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

model = DummyModelWithMultipleBlocks(
in_features=self.in_features,
hidden_features=self.hidden_features,
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):

with context:
model(self.input)

@parameterized.expand([("block_level",), ("leaf_level",)])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return

def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
for name, module in model.named_modules():
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerOutputTrackerHook()
registry.register_hook(hook, "layer_output_tracker")

model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
model = DummyModelWithLayerNorm(128, 256, 128, 2)

model.load_state_dict(model_ref.state_dict(), strict=True)

model_ref.to(torch_device)
model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)

apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model)

x = torch.randn(2, 128).to(torch_device)

out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")

num_repeats = 4
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)

self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")

for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
ref_outputs = (
HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
)
outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
cumulated_absmax = 0.0
for i in range(len(outputs)):
diff = ref_outputs[0] - outputs[i]
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)

@@ -45,6 +45,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_transformers_version
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
@@ -220,6 +221,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
}
return inputs

@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator

@@ -312,7 +318,6 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
@@ -371,6 +376,11 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):

assert np.abs(audio_1 - audio_2).max() < 1e-2

@pytest.mark.xfail(
condition=is_transformers_version(">=", "4.54.1"),
reason="Test currently fails on Transformers version 4.54.1.",
strict=False,
)
def test_audioldm2_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()

@@ -30,8 +30,10 @@ from diffusers.utils.testing_utils import (
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_accelerator,
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_kernels_version_greater_or_equal,
require_peft_backend,
require_torch_version_greater,
torch_device,
@@ -41,11 +43,66 @@ from ..test_torch_compile_utils import QuantCompileTests


if is_gguf_available():
import gguf

from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter

enable_full_determinism()


@nightly
@require_accelerate
@require_accelerator
@require_gguf_version_greater_or_equal("0.10.0")
@require_kernels_version_greater_or_equal("0.9.0")
class GGUFCudaKernelsTests(unittest.TestCase):
def setUp(self):
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)

def test_cuda_kernels_vs_native(self):
if torch_device != "cuda":
self.skipTest("CUDA kernels test requires CUDA device")

from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels

if not can_use_cuda_kernels:
self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")

test_quant_types = ["Q4_0", "Q4_K"]
test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
compute_dtype = torch.bfloat16

for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
in_features, out_features = 512, 512

torch.manual_seed(42)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)

x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)

linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
linear.weight = weight
linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
linear = linear.to(torch_device)

with torch.no_grad():
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)

assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)


@nightly
@require_big_accelerator
@require_accelerate
@@ -593,7 +650,7 @@ class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.Test

def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
"hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
