mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-13 11:13:33 +08:00
87 lines
2.3 KiB
Python
87 lines
2.3 KiB
Python
import ast
|
|
import json
|
|
import sys
|
|
|
|
|
|
# Repository path prefixes whose newly added Python modules are scanned for
# new classes that should come with tests.
SRC_DIRS = [
    "src/diffusers/pipelines/",
    "src/diffusers/models/",
    "src/diffusers/schedulers/",
]

# Base-class names marking a class as needing tests. A class qualifies when one
# of its direct bases is spelled as one of these names, either bare or as the
# last component of a dotted reference.
MIXIN_BASES = {
    "ModelMixin",
    "SchedulerMixin",
    "DiffusionPipeline",
}
|
|
|
|
|
|
def extract_classes_from_file(filepath: str, *, mixin_bases=None) -> list[str]:
    """Return names of classes in *filepath* that directly inherit from a mixin base.

    Every class definition in the module is inspected, including nested ones,
    in ``ast.walk`` order. A class matches when one of its bases matches a name
    in the mixin base set. A base counts when it is written either as a bare
    name (``ModelMixin``) or as a dotted attribute (``diffusers.ModelMixin``,
    matched on its final component). Other base expressions, such as
    subscripts like ``Generic[T]``, are ignored. Indirect inheritance
    (subclassing a class that itself subclasses a mixin) is not detected.

    Args:
        filepath: Path to the Python source file to parse.
        mixin_bases: Iterable of base-class names to match against. Defaults to
            the module-level ``MIXIN_BASES`` when omitted or ``None``.

    Returns:
        Names of the matching classes, in traversal order.

    Raises:
        FileNotFoundError: If *filepath* does not exist.
        SyntaxError: If the file is not valid Python. This includes bytes that
            cannot be decoded with the declared (or default UTF-8) encoding.
    """
    if mixin_bases is None:
        mixin_bases = MIXIN_BASES
    targets = set(mixin_bases)

    # Read raw bytes so ast.parse decodes the source per PEP 263 (coding
    # cookie / BOM, otherwise UTF-8) instead of the platform locale encoding.
    # Undecodable input then surfaces as SyntaxError, which callers already
    # handle, rather than an uncaught UnicodeDecodeError.
    with open(filepath, "rb") as f:
        tree = ast.parse(f.read(), filename=filepath)

    classes = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        base_names = set()
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_names.add(base.id)
            elif isinstance(base, ast.Attribute):
                # e.g. `diffusers.ModelMixin` -> "ModelMixin"
                base_names.add(base.attr)
        if base_names & targets:
            classes.append(node.name)
    return classes
|
|
|
|
|
|
def extract_imports_from_file(filepath: str) -> set[str]:
    """Return the short names introduced by import statements in *filepath*.

    All import statements are collected, including ones nested inside
    functions or classes:

    * ``from pkg import Foo`` and ``from pkg import Foo as Bar`` contribute
      ``"Foo"`` (the imported name, not the alias). ``from pkg import *``
      contributes ``"*"``.
    * ``import a.b.c`` and ``import a.b.c as d`` contribute the last dotted
      component, ``"c"``.

    Names used only via attribute access (e.g. ``diffusers.Foo``) are not
    collected.

    Args:
        filepath: Path to the Python source file to parse.

    Returns:
        The set of collected names.

    Raises:
        FileNotFoundError: If *filepath* does not exist.
        SyntaxError: If the file is not valid Python. This includes bytes that
            cannot be decoded with the declared (or default UTF-8) encoding.
    """
    # Read raw bytes so ast.parse applies PEP 263 decoding (coding cookie /
    # BOM, otherwise UTF-8) rather than the locale encoding. Decode failures
    # become SyntaxError, which callers already handle.
    with open(filepath, "rb") as f:
        tree = ast.parse(f.read(), filename=filepath)

    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            names.update(alias.name for alias in node.names)
        elif isinstance(node, ast.Import):
            # Keep only the final component of a dotted module path.
            names.update(alias.name.rsplit(".", 1)[-1] for alias in node.names)
    return names
|
|
|
|
|
|
def main():
    """Exit non-zero when a PR adds mixin-derived classes that no new test file imports.

    Reads a JSON array of changed-file entries from stdin. Each entry is
    indexed by ``"status"`` and ``"filename"`` (presumably the GitHub "list
    pull request files" payload; confirm against the calling workflow). Only
    entries whose status is ``"added"`` are considered, on both the source
    side and the test side.

    Exits with status 0 when nothing new needs testing. Otherwise it prints
    ``missing-tests: <comma-separated class names>`` and exits with status 1.
    Files that are missing on disk or fail to parse are skipped.
    """
    pr_files = json.load(sys.stdin)

    # Pass 1: gather mixin-derived classes from brand-new modules under SRC_DIRS.
    new_classes = []
    for entry in pr_files:
        # Status is checked first, so entries that are not "added" never need
        # a "filename" key.
        if entry["status"] == "added" and entry["filename"].endswith(".py"):
            source_path = entry["filename"]
            if any(source_path.startswith(prefix) for prefix in SRC_DIRS):
                try:
                    new_classes += extract_classes_from_file(source_path)
                except (FileNotFoundError, SyntaxError):
                    pass

    if not new_classes:
        sys.exit(0)

    # Pass 2: any name imported by a brand-new test module counts as tested.
    def is_new_test_module(entry):
        return (
            entry["status"] == "added"
            and entry["filename"].startswith("tests/")
            and entry["filename"].endswith(".py")
        )

    test_paths = [entry["filename"] for entry in pr_files if is_new_test_module(entry)]

    imported_names = set()
    for test_path in test_paths:
        try:
            found = extract_imports_from_file(test_path)
        except (FileNotFoundError, SyntaxError):
            continue
        imported_names |= found

    untested = [name for name in new_classes if name not in imported_names]
    if not untested:
        sys.exit(0)

    listing = ", ".join(untested)
    print(f"missing-tests: {listing}")
    sys.exit(1)


if __name__ == "__main__":
    main()
|