Compare commits

...

2 Commits

Author SHA1 Message Date
sayakpaul
4e01e02395 add mslk for additional dependencies. 2026-03-25 09:41:04 +05:30
sayakpaul
5e5b575fb3 fix torchao tests 2026-03-25 09:38:49 +05:30
2 changed files with 7 additions and 1 deletion

View File

@@ -341,7 +341,7 @@ jobs:
additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
additional_deps: [mslk-cuda]
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []

View File

@@ -177,6 +177,11 @@ class QuantizationTesterMixin:
model_quantized.to(torch_device)
inputs = self.get_dummy_inputs()
model_dtype = next(model_quantized.parameters()).dtype
inputs = {
k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
for k, v in inputs.items()
}
output = model_quantized(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
@@ -930,6 +935,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
"""Test that device_map='auto' works correctly with quantization."""
self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
@pytest.mark.xfail(reason="dequantize is not implemented in torchao")
def test_torchao_dequantize(self):
    """Test that dequantize() works correctly.

    Delegates to the shared ``_test_dequantize`` helper with the int8
    weight-only quantization config. Marked ``xfail`` because torchao does
    not implement dequantize, so the helper's failure is the expected
    outcome until upstream support lands.
    """
    self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])