mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-22 02:39:51 +08:00
[gguf][torch.compile time] Convert to plain tensor earlier in dequantize_gguf_tensor (#13166)
[gguf] Convert to plain tensor earlier in dequantize_gguf_tensor Once dequantize_gguf_tensor fetches the quant_type attribute from the GGUFParameter tensor subclass, there is no further need to run the actual dequantize operations on the Tensor subclass; we can just convert to a plain tensor right away. This not only makes PyTorch eager mode faster, but also reduces torch.compile tracer compile time from 36 seconds to 10 seconds, because there is a lot less code to trace now.
This commit is contained in:
@@ -516,6 +516,9 @@ def dequantize_gguf_tensor(tensor):
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
|
||||
# Convert to plain tensor to avoid unnecessary __torch_function__ overhead.
|
||||
tensor = tensor.as_tensor()
|
||||
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
|
||||
|
||||
@@ -525,7 +528,7 @@ def dequantize_gguf_tensor(tensor):
|
||||
dequant = dequant_fn(blocks, block_size, type_size)
|
||||
dequant = dequant.reshape(shape)
|
||||
|
||||
return dequant.as_tensor()
|
||||
return dequant
|
||||
|
||||
|
||||
class GGUFParameter(torch.nn.Parameter):
|
||||
|
||||
Reference in New Issue
Block a user