Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-18 18:34:37 +08:00

Compare commits: `complete-s...feat/memor` (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 66d9dc539c | |
| | ac24ec7221 | |
| | 28dcb729af | |
| | 9232b71c9c | |
| | 761a1aa2f4 | |
New model-level method on `ModelMixin`:

```diff
@@ -921,6 +921,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
 
+    def get_memory_footprint(self, return_buffers=True) -> float:
+        r"""
+        Returns the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
+
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
```
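For intuition, the new method is pure `numel() × element_size()` arithmetic. Below is a minimal sketch reproducing the same computation on a toy PyTorch module; the module and the printed byte counts are illustrative, not part of the diff:

```py
import torch

# Toy module: Linear contributes only parameters; BatchNorm1d also registers
# buffers (running_mean, running_var, num_batches_tracked).
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.BatchNorm1d(8))

# Same arithmetic as get_memory_footprint: bytes = numel * element_size.
param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())

# Parameters: Linear 4*8 + 8, BatchNorm 8 + 8 -> 56 fp32 elements * 4 bytes.
print(param_bytes)  # 224
# Buffers: 16 fp32 running-stat elements (64 bytes) + one int64 counter (8 bytes).
print(param_bytes + buffer_bytes)  # 296, i.e. the footprint with return_buffers=True
```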
New pipeline-level method on `DiffusionPipeline`:

```diff
@@ -887,6 +887,38 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         )
         return self
 
+    def get_memory_footprint(self, return_buffers=True) -> float:
+        """Returns the memory footprint of a pipeline in bytes.
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+
+        <Tip>
+
+        Only considers the model-level components of the underlying pipeline and not the scheduler.
+
+        </Tip>
+
+        Examples:
+
+        ```py
+        >>> from diffusers import DiffusionPipeline
+
+        >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> print(pipeline.get_memory_footprint())
+
+        ```
+
+        """
+        model_level_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
+        mem = sum(
+            [model.get_memory_footprint(return_buffers=return_buffers) for _, model in model_level_components.items()]
+        )
+        return mem
+
     @property
     def device(self) -> torch.device:
         r"""
```
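One practical use of the pipeline-level method is comparing precisions. A hedged sketch (requires downloading the checkpoint; exact byte counts depend on the checkpoint, but fp16 should come out at roughly half of fp32 since `element_size()` drops from 4 to 2):

```py
import torch
from diffusers import DiffusionPipeline

repo = "runwayml/stable-diffusion-v1-5"

# Default fp32 weights vs. half-precision weights.
pipe_fp32 = DiffusionPipeline.from_pretrained(repo)
pipe_fp16 = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)

print(pipe_fp32.get_memory_footprint() / 2**30)  # size in GiB
print(pipe_fp16.get_memory_footprint() / 2**30)  # roughly half the fp32 figure
```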
Model test in `UNet2DConditionModelTests`:

```diff
@@ -270,6 +270,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def test_memory_footprint(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        memory_footprint = model.get_memory_footprint()
+        assert memory_footprint == 4424848
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
```
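The hardcoded 4424848 is consistent with an all-fp32 dummy UNet: 4424848 / 4 = 1106212 elements across parameters and buffers. A hypothetical helper that derives the expected constant instead of hardcoding it (the name `expected_footprint` and the all-fp32 assumption are mine, not from the diff):

```py
import torch

def expected_footprint(model: torch.nn.Module) -> int:
    # Assumes every parameter and buffer is fp32 (4 bytes per element).
    n_elements = sum(p.nelement() for p in model.parameters())
    n_elements += sum(b.nelement() for b in model.buffers())
    return 4 * n_elements

# e.g. assert model.get_memory_footprint() == expected_footprint(model)
```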
Pipeline test in `StableDiffusionPipelineFastTests`:

```diff
@@ -181,6 +181,12 @@ class StableDiffusionPipelineFastTests(
         }
         return inputs
 
+    def test_memory_footprint(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPipeline(**components)
+        pipeline_memory_footprint = sd_pipe.get_memory_footprint()
+        assert pipeline_memory_footprint == 362420
+
     def test_stable_diffusion_ddim(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
```
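A possible cross-check for this pipeline test, mirroring the implementation added above: the pipeline footprint should equal the sum over its `nn.Module` components, scheduler excluded (the helper name is hypothetical):

```py
import torch

def sum_of_component_footprints(pipe) -> int:
    # Mirrors DiffusionPipeline.get_memory_footprint: only nn.Module
    # components count; the scheduler holds no parameters and is skipped.
    return sum(
        m.get_memory_footprint()
        for m in pipe.components.values()
        if isinstance(m, torch.nn.Module)
    )

# e.g. assert sd_pipe.get_memory_footprint() == sum_of_component_footprints(sd_pipe)
```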