#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
Utility script to generate test suites for diffusers model classes.
|
|
|
|
Usage:
|
|
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
|
|
|
|
This will analyze the model file and generate a test file with appropriate
|
|
test classes based on the model's mixins and attributes.
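
Optional testers (quantization, single-file loading, IP adapters) can be requested with
--include, and --dry-run prints the generated module instead of writing it, for example:

    python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py --include bnb gguf --dry-run

For a class such as FluxTransformer2DModel the output roughly consists of a
FluxTransformerTesterConfig class with TODO stubs, followed by one
Test<ModelName><Tester> class per selected tester mixin.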
"""

import argparse
import ast
import sys
from pathlib import Path


# Model base classes that map directly to a tester mixin
MIXIN_TO_TESTER = {
    "ModelMixin": "ModelTesterMixin",
    "PeftAdapterMixin": "LoraTesterMixin",
}

# Class attributes that, when set to something other than None/False, add a tester mixin
ATTRIBUTE_TO_TESTER = {
    "_cp_plan": "ContextParallelTesterMixin",
    "_supports_gradient_checkpointing": "TrainingTesterMixin",
}

# Testers generated for every model
ALWAYS_INCLUDE_TESTERS = [
    "ModelTesterMixin",
    "MemoryTesterMixin",
    "TorchCompileTesterMixin",
]

# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
    "AttentionMixin",
    "AttentionModuleMixin",
}

# (tester mixin, --include flag) pairs that are only generated when requested
OPTIONAL_TESTERS = [
    ("BitsAndBytesTesterMixin", "bnb"),
    ("QuantoTesterMixin", "quanto"),
    ("TorchAoTesterMixin", "torchao"),
    ("GGUFTesterMixin", "gguf"),
    ("ModelOptTesterMixin", "modelopt"),
    ("SingleFileTesterMixin", "single_file"),
    ("IPAdapterTesterMixin", "ip_adapter"),
]


class ModelAnalyzer(ast.NodeVisitor):
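    """AST visitor that records the names imported in a model file and, for every class
    deriving from ModelMixin, its private attributes and __init__/forward parameters."""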
    def __init__(self):
        self.model_classes = []
        self.current_class = None
        self.imports = set()

    def visit_Import(self, node: ast.Import):
        for alias in node.names:
            self.imports.add(alias.name.split(".")[-1])
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        for alias in node.names:
            self.imports.add(alias.name)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        base_names = []
        for base in node.bases:
            if isinstance(base, ast.Name):
                base_names.append(base.id)
            elif isinstance(base, ast.Attribute):
                base_names.append(base.attr)

        # Only classes that list ModelMixin among their bases are treated as model classes
        if "ModelMixin" in base_names:
            class_info = {
                "name": node.name,
                "bases": base_names,
                "attributes": {},
                "has_forward": False,
                "init_params": [],
            }

            # Collect private class attributes and __init__/forward signatures
            for item in node.body:
                if isinstance(item, ast.Assign):
                    for target in item.targets:
                        if isinstance(target, ast.Name):
                            attr_name = target.id
                            if attr_name.startswith("_"):
                                class_info["attributes"][attr_name] = self._get_value(item.value)

                elif isinstance(item, ast.FunctionDef):
                    if item.name == "forward":
                        class_info["has_forward"] = True
                        class_info["forward_params"] = self._extract_func_params(item)
                    elif item.name == "__init__":
                        class_info["init_params"] = self._extract_func_params(item)

            self.model_classes.append(class_info)

        self.generic_visit(node)

    def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
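        """Return name/annotation/default info for each positional parameter of func_node, skipping self."""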
        params = []
        args = func_node.args

        # Defaults are right-aligned against the argument list
        num_defaults = len(args.defaults)
        num_args = len(args.args)
        first_default_idx = num_args - num_defaults

        for i, arg in enumerate(args.args):
            if arg.arg == "self":
                continue

            param_info = {"name": arg.arg, "type": None, "default": None}

            if arg.annotation:
                param_info["type"] = self._get_annotation_str(arg.annotation)

            default_idx = i - first_default_idx
            if default_idx >= 0 and default_idx < len(args.defaults):
                param_info["default"] = self._get_value(args.defaults[default_idx])

            params.append(param_info)

        return params

    def _get_annotation_str(self, node) -> str:
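        """Best-effort rendering of a type-annotation AST node back into a source-like string."""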
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Constant):
            return repr(node.value)
        elif isinstance(node, ast.Subscript):
            base = self._get_annotation_str(node.value)
            if isinstance(node.slice, ast.Tuple):
                args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
            else:
                args = self._get_annotation_str(node.slice)
            return f"{base}[{args}]"
        elif isinstance(node, ast.Attribute):
            return f"{self._get_annotation_str(node.value)}.{node.attr}"
        elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
            # PEP 604 unions written as X | Y
            left = self._get_annotation_str(node.left)
            right = self._get_annotation_str(node.right)
            return f"{left} | {right}"
        elif isinstance(node, ast.Tuple):
            return ", ".join(self._get_annotation_str(el) for el in node.elts)
        return "Any"

    def _get_value(self, node):
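        """Evaluate simple literal nodes (constants, names, lists, dicts); unsupported nodes become "<complex>"."""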
        if isinstance(node, ast.Constant):
            return node.value
        elif isinstance(node, ast.Name):
            if node.id == "None":
                return None
            elif node.id == "True":
                return True
            elif node.id == "False":
                return False
            return node.id
        elif isinstance(node, ast.List):
            return [self._get_value(el) for el in node.elts]
        elif isinstance(node, ast.Dict):
            return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
        return "<complex>"


def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
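    """Parse a model file and return the ModelMixin subclasses found plus the set of imported names."""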
    with open(filepath) as f:
        source = f.read()

    tree = ast.parse(source)
    analyzer = ModelAnalyzer()
    analyzer.visit(tree)

    return analyzer.model_classes, analyzer.imports


def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
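    """Select tester mixins from the model's base classes, attributes, and imports, plus any requested optional testers."""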
    testers = list(ALWAYS_INCLUDE_TESTERS)

    for base in model_info["bases"]:
        if base in MIXIN_TO_TESTER:
            tester = MIXIN_TO_TESTER[base]
            if tester not in testers:
                testers.append(tester)

    for attr, tester in ATTRIBUTE_TO_TESTER.items():
        if attr in model_info["attributes"]:
            value = model_info["attributes"][attr]
            if value is not None and value is not False:
                if tester not in testers:
                    testers.append(tester)

    if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
        if "ContextParallelTesterMixin" not in testers:
            testers.append("ContextParallelTesterMixin")

    # Include AttentionTesterMixin if the model imports attention-related classes
    if imports & ATTENTION_INDICATORS:
        testers.append("AttentionTesterMixin")

    for tester, flag in OPTIONAL_TESTERS:
        if flag in include_optional:
            if tester not in testers:
                testers.append(tester)

    return testers


def generate_config_class(model_info: dict, model_name: str) -> str:
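    """Render the <ModelName>TesterConfig class, echoing the model's __init__ and forward() parameters as TODO hints."""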
    class_name = f"{model_name}TesterConfig"
    model_class = model_info["name"]
    forward_params = model_info.get("forward_params", [])
    init_params = model_info.get("init_params", [])

    lines = [
        f"class {class_name}:",
        f"    model_class = {model_class}",
        '    pretrained_model_name_or_path = ""',
        '    pretrained_model_kwargs = {"subfolder": "transformer"}',
        "",
        "    @property",
        "    def generator(self):",
        '        return torch.Generator("cpu").manual_seed(0)',
        "",
        "    def get_init_dict(self) -> dict[str, int | list[int]]:",
    ]

    if init_params:
        lines.append("        # __init__ parameters:")
        for param in init_params:
            type_str = f": {param['type']}" if param["type"] else ""
            default_str = f" = {param['default']}" if param["default"] is not None else ""
            lines.append(f"        # {param['name']}{type_str}{default_str}")

    lines.extend(
        [
            "        return {}",
            "",
            "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
        ]
    )

    if forward_params:
        lines.append("        # forward() parameters:")
        for param in forward_params:
            type_str = f": {param['type']}" if param["type"] else ""
            default_str = f" = {param['default']}" if param["default"] is not None else ""
            lines.append(f"        # {param['name']}{type_str}{default_str}")

    lines.extend(
        [
            "        # TODO: Fill in dummy inputs",
            "        return {}",
            "",
            "    @property",
            "    def input_shape(self) -> tuple[int, ...]:",
            "        return (1, 1)",
            "",
            "    @property",
            "    def output_shape(self) -> tuple[int, ...]:",
            "        return (1, 1)",
        ]
    )

    return "\n".join(lines)


def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
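    """Render one Test<ModelName><Tester> class, adding tester-specific stub attributes and methods where needed."""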
    tester_short = tester.replace("TesterMixin", "")
    class_name = f"Test{model_name}{tester_short}"

    lines = [f"class {class_name}({config_class}, {tester}):"]

    if tester == "TorchCompileTesterMixin":
        lines.extend(
            [
                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
                "",
                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
                "        # TODO: Implement dynamic input generation",
                "        return {}",
            ]
        )
    elif tester == "IPAdapterTesterMixin":
        lines.extend(
            [
                "    ip_adapter_processor_cls = None  # TODO: Set processor class",
                "",
                "    def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
                "        # TODO: Add IP adapter image embeds to inputs",
                "        return inputs_dict",
                "",
                "    def create_ip_adapter_state_dict(self, model):",
                "        # TODO: Create IP adapter state dict",
                "        return {}",
            ]
        )
    elif tester == "SingleFileTesterMixin":
        lines.extend(
            [
                '    ckpt_path = ""  # TODO: Set checkpoint path',
                "    alternate_keys_ckpt_paths = []",
                '    pretrained_model_name_or_path = ""',
                '    subfolder = "transformer"',
            ]
        )
    elif tester == "GGUFTesterMixin":
        lines.extend(
            [
                '    gguf_filename = ""  # TODO: Set GGUF filename',
                "",
                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
                "        # TODO: Override with larger inputs for quantization tests",
                "        return {}",
            ]
        )
    elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
        lines.extend(
            [
                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
                "        # TODO: Override with larger inputs for quantization tests",
                "        return {}",
            ]
        )
    elif tester == "LoraHotSwappingForModelTesterMixin":
        lines.extend(
            [
                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
                "",
                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
                "        # TODO: Implement dynamic input generation",
                "        return {}",
            ]
        )
    else:
        lines.append("    pass")

    return "\n".join(lines)


def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
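    """Assemble the complete test module: license header, imports, the config class, and one test class per selected tester."""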
    model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
    testers = determine_testers(model_info, include_optional, imports)
    tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})

    lines = [
        "# coding=utf-8",
        "# Copyright 2025 HuggingFace Inc.",
        "#",
        '# Licensed under the Apache License, Version 2.0 (the "License");',
        "# you may not use this file except in compliance with the License.",
        "# You may obtain a copy of the License at",
        "#",
        "#     http://www.apache.org/licenses/LICENSE-2.0",
        "#",
        "# Unless required by applicable law or agreed to in writing, software",
        '# distributed under the License is distributed on an "AS IS" BASIS,',
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
        "# See the License for the specific language governing permissions and",
        "# limitations under the License.",
        "",
        "import torch",
        "",
        f"from diffusers import {model_info['name']}",
        "from diffusers.utils.torch_utils import randn_tensor",
        "",
        "from ...testing_utils import enable_full_determinism, torch_device",
    ]

    if "LoraTesterMixin" in testers:
        lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")

    lines.extend(
        [
            "from ..testing_utils import (",
            *[f"    {tester}," for tester in sorted(tester_imports)],
            ")",
            "",
            "",
            "enable_full_determinism()",
            "",
            "",
        ]
    )

    config_class = f"{model_name}TesterConfig"
    lines.append(generate_config_class(model_info, model_name))
    lines.append("")
    lines.append("")

    for tester in testers:
        lines.append(generate_test_class(model_name, config_class, tester))
        lines.append("")
        lines.append("")

    if "LoraTesterMixin" in testers:
        lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
        lines.append("")
        lines.append("")

    return "\n".join(lines).rstrip() + "\n"


def get_test_output_path(model_filepath: str) -> str:
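    """Choose a default output path under tests/models/, mirroring the transformers/unets/autoencoders subfolder of the model file."""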
    path = Path(model_filepath)
    model_filename = path.stem

    if "transformers" in path.parts:
        return f"tests/models/transformers/test_models_{model_filename}.py"
    elif "unets" in path.parts:
        return f"tests/models/unets/test_models_{model_filename}.py"
    elif "autoencoders" in path.parts:
        return f"tests/models/autoencoders/test_models_{model_filename}.py"
    else:
        return f"tests/models/test_models_{model_filename}.py"


def main():
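    """Parse CLI arguments, analyze the model file, and print (--dry-run) or write the generated test module."""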
    parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
    parser.add_argument(
        "model_filepath",
        type=str,
        help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
    )
    parser.add_argument(
        "--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
    )
    parser.add_argument(
        "--include",
        "-i",
        type=str,
        nargs="*",
        default=[],
        choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"],
        help="Optional testers to include",
    )
    parser.add_argument(
        "--class-name",
        "-c",
        type=str,
        default=None,
        help="Specific model class to generate tests for (default: first model class found)",
    )
    parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")

    args = parser.parse_args()

    if not Path(args.model_filepath).exists():
        print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
        sys.exit(1)

    model_classes, imports = analyze_model_file(args.model_filepath)

    if not model_classes:
        print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
        sys.exit(1)

    if args.class_name:
        model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
        if not model_info:
            available = [m["name"] for m in model_classes]
            print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
            sys.exit(1)
    else:
        model_info = model_classes[0]
        if len(model_classes) > 1:
            print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
            print("Use --class-name to specify a different class", file=sys.stderr)

    include_optional = args.include
    if "all" in include_optional:
        include_optional = [flag for _, flag in OPTIONAL_TESTERS]

    generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)

    if args.dry_run:
        print(generated_code)
    else:
        output_path = args.output or get_test_output_path(args.model_filepath)
        output_dir = Path(output_path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w") as f:
            f.write(generated_code)

        print(f"Generated test file: {output_path}")
        print(f"Model class: {model_info['name']}")
        print(f"Detected attributes: {list(model_info['attributes'].keys())}")


if __name__ == "__main__":
    main()