mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
5 Commits
stable-cas
...
feat/memor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66d9dc539c | ||
|
|
ac24ec7221 | ||
|
|
28dcb729af | ||
|
|
9232b71c9c | ||
|
|
761a1aa2f4 |
@@ -921,6 +921,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
||||
|
||||
def get_memory_footprint(self, return_buffers=True) -> float:
|
||||
r"""
|
||||
Returns the memory footprint of a model. This will return the memory footprint of the current model in bytes.
|
||||
Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
|
||||
PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
|
||||
|
||||
Arguments:
|
||||
return_buffers (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
|
||||
are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
|
||||
norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
|
||||
"""
|
||||
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
|
||||
if return_buffers:
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
|
||||
mem = mem + mem_bufs
|
||||
return mem
|
||||
|
||||
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
||||
deprecated_attention_block_paths = []
|
||||
|
||||
|
||||
@@ -887,6 +887,38 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return self
|
||||
|
||||
def get_memory_footprint(self, return_buffers=True) -> float:
|
||||
"""Returns the memory footprint of a pipeline in bytes.
|
||||
|
||||
Arguments:
|
||||
return_buffers (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
|
||||
are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
|
||||
norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
|
||||
|
||||
<Tip>
|
||||
|
||||
Only considers the model-level components of the underlying pipeline and not the scheduler.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
>>> print(pipeline.get_memory_footprint())
|
||||
|
||||
```
|
||||
|
||||
"""
|
||||
model_level_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
||||
mem = sum(
|
||||
[model.get_memory_footprint(return_buffers=return_buffers) for _, model in model_level_components.items()]
|
||||
)
|
||||
return mem
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
r"""
|
||||
|
||||
@@ -270,6 +270,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_memory_footprint(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
memory_footprint = model.get_memory_footprint()
|
||||
assert memory_footprint == 4424848
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
|
||||
@@ -181,6 +181,12 @@ class StableDiffusionPipelineFastTests(
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_memory_footprint(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
pipeline_memory_footprint = sd_pipe.get_memory_footprint()
|
||||
assert pipeline_memory_footprint == 362420
|
||||
|
||||
def test_stable_diffusion_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
|
||||
Reference in New Issue
Block a user