Compare commits

...

1 Commits

Author SHA1 Message Date
Aryan
6832e6a592 update 2025-07-27 00:16:40 +02:00

View File

@@ -15,6 +15,7 @@
PyTorch utilities: Utilities related to PyTorch
"""
import re
from typing import List, Optional, Tuple, Union
from . import logging
@@ -195,3 +196,17 @@ def device_synchronize(device_type: Optional[str] = None):
device_type = get_device()
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.synchronize()
def _find_modules_by_class_name(module: "torch.nn.Module", class_name: str) -> List[Tuple[str, "torch.nn.Module"]]:
"""
Recursively find all modules in a PyTorch module that match the specified class name. The class
name could be partial/full name or a regex pattern.
"""
pattern = re.compile(class_name)
matching_name_module_pairs = []
for name, submodule in module.named_modules():
submodule_cls = unwrap_module(submodule).__class__
if pattern.search(submodule_cls.__name__):
matching_name_module_pairs.append((name, submodule))
return matching_name_module_pairs