Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Nair
fe38d77603 update 2026-04-15 11:48:47 +02:00
4 changed files with 72 additions and 31 deletions

View File

@@ -445,10 +445,14 @@ class WanAnimateFaceBlockAttnProcessor:
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
B, T, N, C = encoder_hidden_states.shape
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
value = value.view(B, T, N, attn.heads, -1)
query = attn.norm_q(query)

View File

@@ -205,6 +205,11 @@ class BaseModelTesterConfig:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def torch_dtype(self) -> torch.dtype:
"""Compute dtype used to build dummy inputs and cast inputs where needed."""
return torch.float32
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""

View File

@@ -359,15 +359,7 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
# Get model dtype from first parameter
model_dtype = next(model.parameters()).dtype
inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -575,33 +567,28 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
@torch.no_grad()
def test_bnb_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
if not fp32_modules:
pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
self.model_class._keep_in_fp32_modules = ["proj_out"]
model = self._create_quantized_model(config_kwargs)
model.to(torch_device)
try:
model = self._create_quantized_model(config_kwargs)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
inputs = self.get_dummy_inputs()
_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
inputs = self.get_dummy_inputs()
_ = model(**inputs)
def test_bnb_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""

View File

@@ -320,6 +320,51 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"pooled_projections": randn_tensor(
(batch_size, embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"img_ids": randn_tensor(
(height * width, num_image_channels),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""