mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-16 04:37:07 +08:00
Compare commits
1 Commits
flux-spmd-
...
bnb-test-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe38d77603 |
@@ -445,10 +445,14 @@ class WanAnimateFaceBlockAttnProcessor:
|
||||
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
|
||||
B, T, N, C = encoder_hidden_states.shape
|
||||
|
||||
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
|
||||
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
|
||||
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
|
||||
|
||||
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
|
||||
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
|
||||
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
|
||||
value = value.view(B, T, N, attn.heads, -1)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
|
||||
@@ -205,6 +205,11 @@ class BaseModelTesterConfig:
|
||||
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||
return {}
|
||||
|
||||
@property
def torch_dtype(self) -> torch.dtype:
    """Dtype used when building dummy inputs and casting them for the model under test."""
    # Default testers run in full precision; quantization testers override this.
    return torch.float32
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Optional[tuple]:
|
||||
"""Expected output shape for output validation tests."""
|
||||
|
||||
@@ -359,15 +359,7 @@ class QuantizationTesterMixin:
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
|
||||
|
||||
# Get model dtype from first parameter
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
# Cast inputs to model dtype
|
||||
inputs = {
|
||||
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
assert output is not None, "Model output is None after dequantization"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
|
||||
@@ -575,33 +567,28 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
|
||||
|
||||
@torch.no_grad()
def test_bnb_keep_modules_in_fp32(self):
    """Check that Linear modules named in ``_keep_in_fp32_modules`` stay FP32 under 4-bit NF4.

    The scraped diff rendering mashed the old and new versions of this test together,
    leaving a stray ``_keep_in_fp32_modules = ["proj_out"]`` monkeypatch with its
    try/finally restore, a duplicated verification loop, and a doubled forward pass.
    This is the coherent post-change version: it uses the class's own declared
    FP32 module list, verifies dtypes once, and runs a single smoke forward pass.
    """
    # Skip when the model class declares no FP32-pinned modules (missing attr or empty list).
    fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
    if not fp32_modules:
        pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")

    config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]

    model = self._create_quantized_model(config_kwargs)
    model.to(torch_device)

    # Linear weights whose qualified name matches an FP32 entry must remain float32;
    # every other Linear weight should have been packed to uint8 storage by bnb 4-bit.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if any(fp32_name in name for fp32_name in fp32_modules):
                assert module.weight.dtype == torch.float32, (
                    f"Module {name} should be FP32 but is {module.weight.dtype}"
                )
            else:
                assert module.weight.dtype == torch.uint8, (
                    f"Module {name} should be uint8 but is {module.weight.dtype}"
                )

    # Smoke-test a forward pass through the mixed-precision model.
    inputs = self.get_dummy_inputs()
    _ = model(**inputs)
|
||||
|
||||
def test_bnb_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
|
||||
@@ -320,6 +320,51 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux Transformer."""
|
||||
|
||||
@property
def torch_dtype(self) -> torch.dtype:
    """Half-precision compute dtype for the BitsAndBytes-quantized Flux tests."""
    # BnB quantization paths expect fp16 activations, so override the fp32 default.
    return torch.float16
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
    """Build a minimal set of random Flux transformer inputs on ``torch_device``.

    All floating-point tensors share ``self.generator`` and ``self.torch_dtype``
    so runs are reproducible and dtype-consistent with the quantized model.
    """
    height = width = 4
    num_latent_channels = 4
    num_image_channels = 3
    sequence_length = 48
    embedding_dim = 32

    def _randn(shape):
        # Single place for the shared generator/device/dtype kwargs.
        return randn_tensor(
            shape,
            generator=self.generator,
            device=torch_device,
            dtype=self.torch_dtype,
        )

    inputs = {
        "hidden_states": _randn((batch_size, height * width, num_latent_channels)),
        "encoder_hidden_states": _randn((batch_size, sequence_length, embedding_dim)),
        "pooled_projections": _randn((batch_size, embedding_dim)),
        "img_ids": _randn((height * width, num_image_channels)),
        "txt_ids": _randn((sequence_length, num_image_channels)),
        "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
    }
    return inputs
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
"""Quanto quantization tests for Flux Transformer."""
|
||||
|
||||
Reference in New Issue
Block a user