mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-18 22:48:08 +08:00
Compare commits
8 Commits
tests-cond
...
make-tiny-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d75ab7e35 | ||
|
|
b086b6da0a | ||
|
|
28c7516229 | ||
|
|
53b9b56059 | ||
|
|
65579667e9 | ||
|
|
cda1c36eeb | ||
|
|
f634485333 | ||
|
|
7820980959 |
210
utils/make_tiny_model.py
Normal file
210
utils/make_tiny_model.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "diffusers",
|
||||
# "torch",
|
||||
# "huggingface_hub",
|
||||
# "accelerate",
|
||||
# "transformers",
|
||||
# "sentencepiece",
|
||||
# "protobuf",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utility script to create tiny versions of diffusers models by reducing layer counts.
|
||||
|
||||
Can be run locally or submitted as an HF Job via `--launch`.
|
||||
|
||||
Usage:
|
||||
# Run locally
|
||||
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> [--subfolder transformer] [--num_layers 2]
|
||||
|
||||
# Push to Hub
|
||||
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --push_to_hub --token $HF_TOKEN
|
||||
|
||||
# Submit as an HF Job
|
||||
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --launch [--flavor cpu-basic]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
# Config keys that control model depth (e.g. "num_layers", "num_hidden_layers",
# "n_layers"). Matching integer values get clamped down to --num_layers.
LAYER_PARAM_PATTERN = re.compile(r"^(num_.*layers?|n_layers|n_refiner_layers)$")


# Config keys that control model width, mapped to the tiny value to clamp to.
# NOTE: ordering matters — the shrink loop applies the FIRST matching pattern
# and breaks, so specific patterns must come before the broader ones that
# also match them (e.g. "^num_attention_heads$" before "^num_.*attention_heads$",
# "^attention_head_dim$" before "^.*attention_head_dim$").
DIM_PARAM_PATTERNS = {
    re.compile(r"^num_attention_heads$"): 2,
    re.compile(r"^num_.*attention_heads$"): 2,
    re.compile(r"^num_key_value_heads$"): 2,
    re.compile(r"^num_kv_heads$"): 1,
    re.compile(r"^n_heads$"): 2,
    re.compile(r"^n_kv_heads$"): 2,
    re.compile(r"^attention_head_dim$"): 8,
    re.compile(r"^.*attention_head_dim$"): 4,
    re.compile(r"^cross_attention_dim.*$"): 8,
    re.compile(r"^joint_attention_dim$"): 32,
    re.compile(r"^pooled_projection_dim$"): 32,
    re.compile(r"^caption_projection_dim$"): 32,
    re.compile(r"^caption_channels$"): 8,
    re.compile(r"^cap_feat_dim$"): 16,
    re.compile(r"^hidden_size$"): 16,
    re.compile(r"^dim$"): 16,
    re.compile(r"^.*embed_dim$"): 16,
    re.compile(r"^.*embed_.*dim$"): 16,
    re.compile(r"^text_dim$"): 16,
    re.compile(r"^time_embed_dim$"): 4,
    re.compile(r"^ffn_dim$"): 32,
    re.compile(r"^intermediate_size$"): 32,
    re.compile(r"^sample_size$"): 32,
}
|
||||
|
||||
|
||||
def parse_args():
    """Parse CLI arguments; --token falls back to the HF_TOKEN env var."""
    p = argparse.ArgumentParser(description="Create a tiny version of a diffusers model.")
    p.add_argument(
        "--model_repo_id",
        type=str,
        required=True,
        help="HuggingFace repo ID of the source model.",
    )
    p.add_argument(
        "--output_repo_id",
        type=str,
        required=True,
        help="HuggingFace repo ID or local path to save the tiny model to.",
    )
    p.add_argument("--subfolder", type=str, default=None, help="Subfolder within the model repo.")
    p.add_argument("--num_layers", type=int, default=2, help="Number of layers to use for the tiny model.")
    p.add_argument(
        "--shrink_dims",
        action="store_true",
        help="Also reduce dimension parameters (attention heads, hidden size, embedding dims, etc.).",
    )
    p.add_argument("--push_to_hub", action="store_true", help="Push the tiny model to the HuggingFace Hub.")
    p.add_argument(
        "--token",
        type=str,
        default=None,
        help="HuggingFace token. Defaults to $HF_TOKEN env var if not provided.",
    )

    jobs = p.add_argument_group("HF Jobs launch options")
    jobs.add_argument("--launch", action="store_true", help="Submit as an HF Job instead of running locally.")
    jobs.add_argument("--flavor", type=str, default="cpu-basic", help="HF Jobs hardware flavor.")
    jobs.add_argument("--timeout", type=str, default="30m", help="HF Jobs timeout.")

    parsed = p.parse_args()
    # Environment fallback so CI/jobs can supply credentials without a flag.
    parsed.token = parsed.token if parsed.token is not None else os.environ.get("HF_TOKEN")
    return parsed
|
||||
|
||||
|
||||
def launch_job(args):
    """Submit this script as a uv-run HF Job and return the job handle.

    The token (if any) is forwarded as the HF_TOKEN secret rather than as a
    command-line flag, so it never appears in the job's argument list.
    """
    from huggingface_hub import run_uv_job

    forwarded = [
        "--model_repo_id", args.model_repo_id,
        "--output_repo_id", args.output_repo_id,
        "--num_layers", str(args.num_layers),
    ]
    if args.subfolder:
        forwarded += ["--subfolder", args.subfolder]
    for flag in ("shrink_dims", "push_to_hub"):
        if getattr(args, flag):
            forwarded.append(f"--{flag}")

    job = run_uv_job(
        __file__,
        script_args=forwarded,
        flavor=args.flavor,
        timeout=args.timeout,
        secrets={"HF_TOKEN": args.token} if args.token else {},
    )
    print(f"Job submitted: {job.url}")
    print(f"Job ID: {job.id}")
    return job
|
||||
|
||||
|
||||
def _shrink_config(config, num_layers, shrink_dims):
    """Shrink layer/dimension entries of ``config`` in place.

    Integer values whose keys match LAYER_PARAM_PATTERN are clamped to
    ``num_layers``; when ``shrink_dims`` is set, keys matching
    DIM_PARAM_PATTERNS are clamped to their tiny values as well (values are
    only ever reduced, never increased).

    Returns:
        dict mapping each modified key to an ``(old_value, new_value)`` tuple.
    """
    modified = {}

    for key, value in config.items():
        if LAYER_PARAM_PATTERN.match(key) and isinstance(value, int) and value > num_layers:
            modified[key] = (value, num_layers)
            config[key] = num_layers

    if shrink_dims:
        for key, value in config.items():
            # Skip non-int values and private/bookkeeping keys like "_class_name".
            if not isinstance(value, int) or key.startswith("_"):
                continue
            # First matching pattern wins; DIM_PARAM_PATTERNS is ordered from
            # specific to general for exactly this reason.
            for pattern, tiny_value in DIM_PARAM_PATTERNS.items():
                if pattern.match(key) and value > tiny_value:
                    modified[key] = (value, tiny_value)
                    config[key] = tiny_value
                    break

    return modified


def make_tiny_model(
    model_repo_id, output_repo_id, subfolder=None, num_layers=2, shrink_dims=False, push_to_hub=False, token=None
):
    """Create a tiny variant of a diffusers model by shrinking its config.

    Loads only the source model's config (never the full weights), clamps its
    layer-count (and optionally dimension) parameters, instantiates a fresh
    randomly-initialized model from the shrunken config, and saves it locally
    or pushes it to the Hub.

    Args:
        model_repo_id: Hub repo ID of the source model.
        output_repo_id: Hub repo ID or local path for the tiny model.
        subfolder: Optional subfolder within the source repo.
        num_layers: Target number of layers for the tiny model.
        shrink_dims: Also reduce width parameters (heads, hidden sizes, ...).
        token: HF auth token for private repos and for pushing.
        push_to_hub: Push the result to the Hub in addition to saving locally.
    """
    from diffusers import AutoModel

    config_kwargs = {"token": token} if token else {}
    config = AutoModel.load_config(model_repo_id, subfolder=subfolder, **config_kwargs)

    modified_keys = _shrink_config(config, num_layers, shrink_dims)
    if not modified_keys:
        # Nothing matched: surface the available keys so the user can extend
        # the patterns, and bail out instead of creating a full-size copy.
        print("WARNING: No config parameters were modified.")
        print(f"Config keys: {[k for k in config if not k.startswith('_')]}")
        return

    print("Modified config parameters:")
    for key, (old, new) in modified_keys.items():
        print(f"  {key}: {old} -> {new}")

    # Instantiate with random weights from the shrunken config.
    model = AutoModel.from_config(config)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Tiny model created with {total_params:,} parameters.")

    save_kwargs = {"token": token} if token else {}
    if push_to_hub:
        # save_pretrained uses repo_id (not the save directory) as the Hub
        # destination when pushing.
        save_kwargs["repo_id"] = output_repo_id
    model.save_pretrained(output_repo_id, push_to_hub=push_to_hub, **save_kwargs)
    if push_to_hub:
        print(f"Model pushed to https://huggingface.co/{output_repo_id}")
    else:
        print(f"Model saved to {output_repo_id}")
|
||||
|
||||
|
||||
def main():
    """Entry point: submit an HF Job with --launch, otherwise run locally."""
    args = parse_args()

    if args.launch:
        launch_job(args)
        return

    make_tiny_model(
        model_repo_id=args.model_repo_id,
        output_repo_id=args.output_repo_id,
        subfolder=args.subfolder,
        num_layers=args.num_layers,
        shrink_dims=args.shrink_dims,
        push_to_hub=args.push_to_hub,
        token=args.token,
    )
|
||||
|
||||
|
||||
# Allow running as a script, both locally and as the uv entry point of an HF Job.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user