mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 13:34:27 +08:00)

Compare commits: kernelize...flux-contr (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | acc9c839bc | |
| | a7a14cc943 | |
| | 8229ef9eb9 | |
| | 38dedc75ba | |
```diff
@@ -81,12 +81,17 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         from ..quantizers.gguf.utils import dequantize_gguf_tensor
 
     is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+    is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
     is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
 
     if is_bnb_4bit_quantized and not is_bitsandbytes_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
         )
+    if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+        raise ValueError(
+            "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+        )
     if is_gguf_quantized and not is_gguf_available():
         raise ValueError(
             "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
```
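The new `Int8Params` check follows the same pattern as the existing ones: the backend is detected from the weight's class name alone, so no quantization library has to be imported just to probe a module. A minimal standalone sketch of that probe (the `detect_quant_backend` helper is hypothetical, but the class names are exactly the ones the diff tests for):

```python
from typing import Optional

import torch.nn as nn


def detect_quant_backend(module: nn.Module) -> Optional[str]:
    """Guess which quantization backend produced a module's weight, if any."""
    name = module.weight.__class__.__name__
    if name == "Params4bit":  # bitsandbytes 4-bit
        return "bnb-4bit"
    if name == "Int8Params":  # bitsandbytes 8-bit (the case this diff adds)
        return "bnb-8bit"
    if name == "GGUFParameter":  # gguf
        return "gguf"
    return None  # plain torch.nn.Parameter


print(detect_quant_backend(nn.Linear(4, 4)))  # None
```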
```diff
@@ -97,10 +102,10 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         weight_on_cpu = True
 
     device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
-    if is_bnb_4bit_quantized:
+    if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
         module_weight = dequantize_bnb_weight(
             module.weight.to(device) if weight_on_cpu else module.weight,
-            state=module.weight.quant_state,
+            state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
```
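This second hunk routes 8-bit weights through the same `dequantize_bnb_weight` call as 4-bit ones; the only difference is where bitsandbytes keeps the quantization state, which is why the new code picks `module.weight.quant_state` for 4-bit but `module.state` for 8-bit. A short sketch of that asymmetry, assuming the `bitsandbytes` package is installed:

```python
import bitsandbytes as bnb

# Both layers can be constructed on CPU; actual quantization happens later.
linear_4bit = bnb.nn.Linear4bit(64, 64)
linear_8bit = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)

# 4-bit: the Params4bit weight carries its own quant_state
# (populated once the layer is actually quantized, e.g. on .cuda()).
print(type(linear_4bit.weight).__name__)  # Params4bit
print(linear_4bit.weight.quant_state)     # None until quantized

# 8-bit: Linear8bitLt stores a MatmulLtState on the module, not on the weight.
print(type(linear_8bit.weight).__name__)  # Int8Params
print(type(linear_8bit.state).__name__)   # MatmulLtState
```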
```diff
@@ -19,15 +19,18 @@ import unittest
 import numpy as np
 import pytest
 from huggingface_hub import hf_hub_download
+from PIL import Image
 
 from diffusers import (
     BitsAndBytesConfig,
     DiffusionPipeline,
+    FluxControlPipeline,
     FluxTransformer2DModel,
     SanaTransformer2DModel,
     SD3Transformer2DModel,
     logging,
 )
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
```
```diff
@@ -39,6 +42,7 @@ from diffusers.utils.testing_utils import (
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_bitsandbytes_version_greater,
+    require_peft_backend,
     require_peft_version_greater,
     require_torch,
     require_torch_accelerator,
```
```diff
@@ -697,6 +701,50 @@ class SlowBnb8bitFluxTests(Base8bitTests):
         self.assertTrue(max_diff < 1e-3)
 
 
+@require_transformers_version_greater("4.44.0")
+@require_peft_backend
+class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+        self.pipeline_8bit = FluxControlPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            quantization_config=PipelineQuantizationConfig(
+                quant_backend="bitsandbytes_8bit",
+                quant_kwargs={"load_in_8bit": True},
+                components_to_quantize=["transformer", "text_encoder_2"],
+            ),
+            torch_dtype=torch.float16,
+        )
+        self.pipeline_8bit.enable_model_cpu_offload()
+
+    def tearDown(self):
+        del self.pipeline_8bit
+
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_lora_loading(self):
+        self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+        output = self.pipeline_8bit(
+            prompt=self.prompt,
+            control_image=Image.new(mode="RGB", size=(256, 256)),
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
 @slow
 class BaseBnb8bitSerializationTests(Base8bitTests):
     def setUp(self):
```
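Taken together, the new test exercises the code path end to end: a pipeline whose transformer is quantized to 8-bit loads a control LoRA that expands the transformer's input features, which forces the dequantize-then-expand logic from the first two hunks. A usage sketch distilled from the test above (same model IDs and quantization config; running it needs a CUDA-capable GPU plus the `bitsandbytes` and `peft` packages):

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Quantize the heavy components to 8-bit at load time via the
# pipeline-level quantization config.
pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_8bit",
        quant_kwargs={"load_in_8bit": True},
        components_to_quantize=["transformer", "text_encoder_2"],
    ),
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

# Loading the Canny control LoRA on the 8-bit transformer is what this
# change enables: the expanded input layer is dequantized before the
# LoRA weights are applied.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
```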