Compare commits

...

5 Commits

Author SHA1 Message Date
Dhruv Nair
32476f9717 update 2025-03-26 16:37:57 +01:00
Dhruv Nair
a69dd6fdf0 Merge branch 'main' into dtype-fix 2025-03-26 16:33:21 +01:00
Dhruv Nair
e539cd32c4 update 2025-03-26 16:32:14 +01:00
Dhruv Nair
611d37549f update 2025-03-26 16:31:13 +01:00
Dhruv Nair
12eeb252d5 update 2025-03-25 18:41:45 +01:00
3 changed files with 8 additions and 4 deletions

View File

@@ -282,6 +282,7 @@ class FromOriginalModelMixin:
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
else:
hf_quantizer = None

View File

@@ -90,13 +90,16 @@ class Base8bitTests(unittest.TestCase):
def get_dummy_inputs(self):
prompt_embeds = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
+ map_location="cpu",
)
pooled_prompt_embeds = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt",
+ map_location="cpu",
)
latent_model_input = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt",
+ map_location="cpu",
)
input_dict_for_transformer = {

View File

@@ -57,7 +57,7 @@ class GGUFSingleFileTesterMixin:
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
assert module.weight.dtype == torch.uint8
if module.bias is not None:
- assert module.bias.dtype == torch.float32
+ assert module.bias.dtype == self.torch_dtype
def test_gguf_memory_usage(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)