Compare commits

...

5 Commits

Author SHA1 Message Date
sayakpaul
66d9dc539c close <Tip> 2024-01-22 08:03:04 +05:30
sayakpaul
ac24ec7221 correct pipeline memory footprint 2024-01-22 07:55:56 +05:30
sayakpaul
28dcb729af correct memory footprint for unet2d condition 2024-01-22 07:53:49 +05:30
sayakpaul
9232b71c9c test 2024-01-22 07:51:34 +05:30
sayakpaul
761a1aa2f4 feat: utility for getting memory footprint info like transformers. 2024-01-22 07:39:43 +05:30
4 changed files with 62 additions and 0 deletions

View File

@@ -921,6 +921,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def get_memory_footprint(self, return_buffers=True) -> float:
r"""
Returns the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
Arguments:
return_buffers (`bool`, *optional*, defaults to `True`):
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
"""
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = []

View File

@@ -887,6 +887,38 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
return self
def get_memory_footprint(self, return_buffers=True) -> float:
"""Returns the memory footprint of a pipeline in bytes.
Arguments:
return_buffers (`bool`, *optional*, defaults to `True`):
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
<Tip>
Only considers the model-level components of the underlying pipeline and not the scheduler.
</Tip>
Examples:
```py
>>> from diffusers import DiffusionPipeline
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> print(pipeline.get_memory_footprint())
```
"""
model_level_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
mem = sum(
[model.get_memory_footprint(return_buffers=return_buffers) for _, model in model_level_components.items()]
)
return mem
@property
def device(self) -> torch.device:
r"""

View File

@@ -270,6 +270,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_memory_footprint(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
memory_footprint = model.get_memory_footprint()
assert memory_footprint == 4424848
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",

View File

@@ -181,6 +181,12 @@ class StableDiffusionPipelineFastTests(
}
return inputs
def test_memory_footprint(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
pipeline_memory_footprint = sd_pipe.get_memory_footprint()
assert pipeline_memory_footprint == 362420
def test_stable_diffusion_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator