Compare commits

...

6 Commits

Author       SHA1        Message                                                     Date
Sayak Paul   8a12103891  Merge branch 'main' into training-group-offloading-tests   2025-05-15 10:08:59 +05:30
Sayak Paul   d84dbabba6  Merge branch 'main' into training-group-offloading-tests   2025-05-12 12:58:40 +05:30
Sayak Paul   90f934df8e  Merge branch 'main' into training-group-offloading-tests   2025-05-09 09:34:59 +05:30
Sayak Paul   50507ea7c4  Merge branch 'main' into training-group-offloading-tests   2025-05-08 18:42:05 +05:30
Sayak Paul   fa493e376b  Merge branch 'main' into training-group-offloading-tests   2025-05-08 09:12:29 +05:30
sayakpaul    131ed8ed16  add test for group_offloading with training.               2025-05-07 14:37:35 +05:30


@@ -1580,6 +1580,30 @@ class ModelTesterMixin:
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
        self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))

    @parameterized.expand([False, True])
    @require_torch_accelerator
    def test_group_offloading_with_training(self, use_stream):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        if not getattr(model, "_supports_group_offloading", True):
            return

        model.enable_group_offload(
            torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=use_stream
        )
        model.train()

        output = model(**inputs_dict)
        if isinstance(output, dict):
            output = output.to_tuple()[0]

        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    def test_auto_model(self, expected_max_diff=5e-5):
        if self.forward_requires_fresh_args:
            model = self.model_class(**self.init_dict)
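
For readers outside the test suite, here is a minimal, hypothetical sketch of the behaviour the new test covers: enabling block-level group offloading on a diffusers model and then running an ordinary training step. UNet2DModel, its config values, and the dummy MSE target are illustrative assumptions and not part of this PR; the enable_group_offload call mirrors the one in the test and needs an accelerator (CUDA here).

# Hypothetical usage sketch (not part of the PR): a training step with group offloading enabled.
# Assumes a CUDA device; UNet2DModel and its sizes are arbitrary illustrative choices.
import torch
from diffusers import UNet2DModel

device = torch.device("cuda")

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

# Same call the test makes: parameters stay on CPU and are onloaded block by block.
model.enable_group_offload(device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)
model.train()

sample = torch.randn(1, 3, 32, 32, device=device)
timestep = torch.tensor([10], device=device)
target = torch.randn_like(sample)  # dummy regression target, analogous to the test's noise tensor

output = model(sample, timestep).sample
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()  # what the test verifies: backward runs through the offloaded blocks

Because test_group_offloading_with_training lives in a mixin, it runs once per model-specific test class; something like pytest tests/models -k test_group_offloading_with_training (path assumed) should select the new parameterized cases.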