Files
diffusers/utils/check_test_missing.py
Dhruv Nair d7bc233b4b [CI] Add PR/Issue Auto Labeler (#13380)
* update

* update

* update

* update

* update

* update

* update

* update

* Apply suggestion from @sayakpaul

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-04-07 10:02:18 +05:30

87 lines
2.3 KiB
Python

import ast
import json
import sys
# Repo-relative path prefixes; newly added .py files under these are scanned for new classes.
SRC_DIRS = ["src/diffusers/pipelines/", "src/diffusers/models/", "src/diffusers/schedulers/"]
# A class needs tests when one of its direct base classes has one of these names
# (matched on the bare name or on the last segment of a dotted reference).
MIXIN_BASES = {"ModelMixin", "SchedulerMixin", "DiffusionPipeline"}
def extract_classes_from_file(filepath: str, bases: "set[str] | None" = None) -> list[str]:
    """Return the names of classes in *filepath* that directly subclass a tracked base.

    Every class definition in the module is inspected, including classes nested inside
    functions or other classes. A base matches when any of these forms is in *bases*:

    - its bare name (``Foo``);
    - the last segment of a dotted reference (``pkg.mod.Foo``);
    - the name being subscripted (``Foo[T]``).

    Only direct bases are checked. A class deriving from an intermediate subclass
    (e.g. one pipeline inheriting from another pipeline) is not reported.

    Args:
        filepath: Path of the Python source file to parse.
        bases: Base-class names to look for. Defaults to ``MIXIN_BASES``.

    Returns:
        Matching class names, in ``ast.walk`` (breadth-first) order.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        SyntaxError: If the file is not valid Python or cannot be decoded.
    """
    targets = MIXIN_BASES if bases is None else set(bases)

    # Read raw bytes so the parser applies PEP 263 / BOM encoding detection (UTF-8 by
    # default) rather than the locale-dependent encoding of a text-mode open().
    with open(filepath, "rb") as f:
        tree = ast.parse(f.read(), filename=filepath)

    classes = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        base_names = set()
        for base in node.bases:
            # `Base[T]` -> look at `Base` (or `pkg.Base`) underneath the subscript.
            if isinstance(base, ast.Subscript):
                base = base.value
            if isinstance(base, ast.Name):
                base_names.add(base.id)
            elif isinstance(base, ast.Attribute):
                base_names.add(base.attr)
        if base_names & targets:
            classes.append(node.name)
    return classes
def extract_imports_from_file(filepath: str) -> set[str]:
    """Return the names a Python file imports, anywhere in the module.

    How each statement form is recorded:

    - ``from pkg import Name [as alias]``: the original ``Name`` is recorded, not the alias.
      A star import records the literal ``"*"``.
    - ``import pkg.mod [as alias]``: the last dotted segment (``mod``) is recorded.

    Imports nested inside functions, classes or ``try`` blocks are included.

    Args:
        filepath: Path of the Python source file to parse.

    Returns:
        The set of recorded names.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        SyntaxError: If the file is not valid Python or cannot be decoded.
    """
    # Read raw bytes so the parser applies PEP 263 / BOM encoding detection (UTF-8 by
    # default) rather than the locale-dependent encoding of a text-mode open().
    with open(filepath, "rb") as f:
        tree = ast.parse(f.read(), filename=filepath)

    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            names.update(alias.name for alias in node.names)
        elif isinstance(node, ast.Import):
            names.update(alias.name.rpartition(".")[2] for alias in node.names)
    return names
def main():
    """Flag a PR that adds model/scheduler/pipeline classes without tests importing them.

    Reads a JSON array of PR file entries from stdin. Each entry needs at least the
    ``"filename"`` and ``"status"`` keys; presumably this is the GitHub "list pull
    request files" payload (confirm against the calling workflow).

    Class names are collected from newly added ``.py`` files under ``SRC_DIRS`` that
    directly subclass one of ``MIXIN_BASES``. Each name is then looked up among the
    imports of every non-removed ``tests/**.py`` file in the PR. Files are read
    relative to the current working directory, so run this from the checked-out
    repository root.

    Exits with status 1 after printing ``missing-tests: <names>`` when some new class is
    not imported by any such test file. Otherwise exits with status 0.
    """
    # A best-effort labeler should skip files it cannot read rather than crash the job:
    # - OSError covers missing, unreadable and directory paths.
    # - SyntaxError covers invalid source.
    # - ValueError covers decode errors and, on older Pythons, null bytes in the source.
    unreadable = (OSError, SyntaxError, ValueError)

    pr_files = json.load(sys.stdin)

    new_classes = []
    for pr_file in pr_files:
        filename = pr_file["filename"]
        if pr_file["status"] != "added" or not filename.endswith(".py"):
            continue
        if not filename.startswith(tuple(SRC_DIRS)):
            continue
        try:
            new_classes.extend(extract_classes_from_file(filename))
        except unreadable:
            continue
    # Deduplicate while keeping first-seen order so each name is reported once.
    new_classes = list(dict.fromkeys(new_classes))
    if not new_classes:
        sys.exit(0)

    # Tests count whether they come in a new file or are added to an existing one,
    # so consider every test file the PR touches except deleted ones.
    test_files = [
        pr_file["filename"]
        for pr_file in pr_files
        if pr_file["status"] != "removed"
        and pr_file["filename"].startswith("tests/")
        and pr_file["filename"].endswith(".py")
    ]
    imported_names = set()
    for filepath in test_files:
        try:
            imported_names |= extract_imports_from_file(filepath)
        except unreadable:
            continue

    untested = [cls for cls in new_classes if cls not in imported_names]
    if untested:
        print(f"missing-tests: {', '.join(untested)}")
        sys.exit(1)
    sys.exit(0)
# Run the check only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()