Compare commits

...

1 Commit

Author SHA1 Message Date
sayakpaul	1f67e4e7f6	add a test for checking effective custom gc.	2025-01-29 11:04:12 +05:30


@@ -966,6 +966,38 @@ class ModelTesterMixin:
        assert set(modules_with_gc_enabled.keys()) == expected_set
        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

    @require_torch_accelerator_with_training
    def test_apply_gradient_checkpointing_every_n_block(self, block_num=2):
        # Skip test if model does not support gradient checkpointing
        if not self.model_class._supports_gradient_checkpointing:
            return

        # For now, we only test for transformer models.
        if "transformer" not in self.model_class.__name__.lower():
            return

        # enable deterministic behavior for gradient checkpointing
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        torch.manual_seed(0)
        model = self.model_class(**init_dict)

        def gradient_checkpointing_func(model, *args):
            # `model` here is the individual block being called, not the full model.
            if model.layer_index % block_num == 0:
                return torch.utils.checkpoint.checkpoint(model.__call__, *args, use_reentrant=False)
            return model(*args)

        if getattr(model, "transformer_blocks", None) is not None:
            for index, layer in enumerate(model.transformer_blocks):
                layer.layer_index = index

            model.enable_gradient_checkpointing(gradient_checkpointing_func)
            assert model.training

            for index, layer in enumerate(model.transformer_blocks):
                if layer.layer_index % block_num == 0:
                    assert layer.is_gradient_checkpointing
                else:
                    assert not layer.is_gradient_checkpointing

    def test_deprecated_kwargs(self):
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
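
For reference, a minimal sketch of how the behaviour exercised by this test might look in user code. `SD3Transformer2DModel` and the checkpoint path are illustrative assumptions; the only API taken from the diff is that `enable_gradient_checkpointing` accepts a custom checkpointing callable which receives the block being called as its first argument.

import torch
from diffusers import SD3Transformer2DModel  # illustrative; any transformer model exposing `transformer_blocks`

# "<path-or-repo>" is a placeholder; substitute a checkpoint you have access to.
model = SD3Transformer2DModel.from_pretrained("<path-or-repo>", subfolder="transformer")

def checkpoint_every_other_block(module, *args):
    # `module` is the individual transformer block being invoked, not the full model.
    if module.layer_index % 2 == 0:
        return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=False)
    return module(*args)

# Tag each block with its index so the callable can decide per block.
for index, block in enumerate(model.transformer_blocks):
    block.layer_index = index

model.enable_gradient_checkpointing(checkpoint_every_other_block)
model.train()

Checkpointing only a subset of blocks trades some extra activation memory for less recomputation in the backward pass, compared with checkpointing every block.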