mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
* OmniGen model.py * update OmniGenTransformerModel * omnigen pipeline * omnigen pipeline * update omnigen_pipeline * test case for omnigen * update omnigenpipeline * update docs * update docs * offload_transformer * enable_transformer_block_cpu_offload * update docs * reformat * reformat * reformat * update docs * update docs * make style * make style * Update docs/source/en/api/models/omnigen_transformer.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update docs * revert changes to examples/ * update OmniGen2DModel * make style * update test cases * Update docs/source/en/api/pipelines/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/using-diffusers/omnigen.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update docs * typo * Update src/diffusers/models/embeddings.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/models/attention.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/models/transformers/transformer_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/models/transformers/transformer_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update 
src/diffusers/models/transformers/transformer_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update tests/pipelines/omnigen/test_pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update tests/pipelines/omnigen/test_pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * Update src/diffusers/pipelines/omnigen/pipeline_omnigen.py Co-authored-by: hlky <hlky@hlky.ac> * consistent attention processor * updata * update * check_inputs * make style * update testpipeline * update testpipeline --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: hlky <hlky@hlky.ac> Co-authored-by: Aryan <aryan@huggingface.co>
204 lines
7.3 KiB
Python
204 lines
7.3 KiB
Python
import argparse
|
|
import os
|
|
|
|
import torch
|
|
from huggingface_hub import snapshot_download
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoTokenizer
|
|
|
|
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
|
|
|
|
|
|
def main(args):
    """Convert an original OmniGen checkpoint into a diffusers ``OmniGenPipeline``.

    Reads ``args.origin_ckpt_path`` and ``args.dump_path``. If
    ``args.origin_ckpt_path`` does not exist on disk it is used as a Hub repo
    id for ``snapshot_download`` and then overwritten (on ``args``) with the
    downloaded snapshot directory. The ``model.safetensors`` file in that
    directory is loaded, its keys are renamed, and the result is loaded
    strictly into a freshly built ``OmniGenTransformer2DModel``. The
    transformer, a scheduler, the VAE from the ``vae`` subfolder and the
    tokenizer are then assembled into a pipeline and written to
    ``args.dump_path``.
    """
    # checkpoint from https://huggingface.co/Shitao/OmniGen-v1

    if not os.path.exists(args.origin_ckpt_path):
        print("Model not found, downloading...")
        hub_cache_dir = os.getenv("HF_HUB_CACHE")
        skipped_files = ["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"]
        # The resolved snapshot directory replaces the repo id on ``args`` so
        # every later path lookup uses the local copy.
        args.origin_ckpt_path = snapshot_download(
            repo_id=args.origin_ckpt_path,
            cache_dir=hub_cache_dir,
            ignore_patterns=skipped_files,
        )
        print(f"Downloaded model to {args.origin_ckpt_path}")

    weights_file = os.path.join(args.origin_ckpt_path, "model.safetensors")
    original_state_dict = load_file(weights_file, device="cpu")
    converted_state_dict = _remap_omnigen_keys(original_state_dict)

    transformer = OmniGenTransformer2DModel(
        rope_scaling=_omnigen_rope_scaling(),
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True: any missing or unexpected key after remapping raises here.
    transformer.load_state_dict(converted_state_dict, strict=True)
    transformer.to(torch.bfloat16)

    param_total = sum(weight.numel() for weight in transformer.parameters())
    print(f"Total number of transformer parameters: {param_total}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)

    vae_dir = os.path.join(args.origin_ckpt_path, "vae")
    vae = AutoencoderKL.from_pretrained(vae_dir, torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(
        tokenizer=tokenizer,
        transformer=transformer,
        vae=vae,
        scheduler=scheduler,
    )
    pipeline.save_pretrained(args.dump_path)


def _remap_omnigen_keys(ckpt):
    """Return a new dict with the original checkpoint keys renamed for the diffusers model."""
    direct_renames = {
        "pos_embed": "patch_embedding.pos_embed",
        "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
        "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
        "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
        "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
        "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        "final_layer.linear.weight": "proj_out.weight",
        "final_layer.linear.bias": "proj_out.bias",
        "time_token.mlp.0.weight": "time_token.linear_1.weight",
        "time_token.mlp.0.bias": "time_token.linear_1.bias",
        "time_token.mlp.2.weight": "time_token.linear_2.weight",
        "time_token.mlp.2.bias": "time_token.linear_2.bias",
        "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
        "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
        "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
        "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
        "llm.embed_tokens.weight": "embed_tokens.weight",
    }

    remapped = {}
    for name, tensor in ckpt.items():
        # Keys with an explicit one-to-one rename.
        if name in direct_renames:
            remapped[direct_renames[name]] = tensor
            continue

        # Fused qkv tensor: split into three chunks (torch.chunk default dim)
        # and store them as separate q/k/v weights. The third dot-separated
        # field of the key is used as the layer index.
        if "qkv" in name:
            query, key, value = tensor.chunk(3)
            attn_prefix = f"layers.{name.split('.')[2]}.self_attn"
            remapped[f"{attn_prefix}.to_q.weight"] = query
            remapped[f"{attn_prefix}.to_k.weight"] = key
            remapped[f"{attn_prefix}.to_v.weight"] = value
            continue

        # Attention output projection keeps its tensor, only the key changes.
        if "o_proj" in name:
            remapped[f"layers.{name.split('.')[2]}.self_attn.to_out.0.weight"] = tensor
            continue

        # Everything else: drop the first four characters of the key
        # (presumably the "llm." prefix - verify against the checkpoint).
        remapped[name[4:]] = tensor

    return remapped


def _omnigen_rope_scaling():
    """Return the hard-coded ``rope_scaling`` configuration passed to the transformer."""
    long_factor = [
        1.0299999713897705,
        1.0499999523162842,
        1.0499999523162842,
        1.0799999237060547,
        1.2299998998641968,
        1.2299998998641968,
        1.2999999523162842,
        1.4499999284744263,
        1.5999999046325684,
        1.6499998569488525,
        1.8999998569488525,
        2.859999895095825,
        3.68999981880188,
        5.419999599456787,
        5.489999771118164,
        5.489999771118164,
        9.09000015258789,
        11.579999923706055,
        15.65999984741211,
        15.769999504089355,
        15.789999961853027,
        18.360000610351562,
        21.989999771118164,
        23.079999923706055,
        30.009998321533203,
        32.35000228881836,
        32.590003967285156,
        35.56000518798828,
        39.95000457763672,
        53.840003967285156,
        56.20000457763672,
        57.95000457763672,
        59.29000473022461,
        59.77000427246094,
        59.920005798339844,
        61.190006256103516,
        61.96000671386719,
        62.50000762939453,
        63.3700065612793,
        63.48000717163086,
        63.48000717163086,
        63.66000747680664,
        63.850006103515625,
        64.08000946044922,
        64.760009765625,
        64.80001068115234,
        64.81001281738281,
        64.81001281738281,
    ]
    short_factor = [
        1.05,
        1.05,
        1.05,
        1.1,
        1.1,
        1.1,
        1.2500000000000002,
        1.2500000000000002,
        1.4000000000000004,
        1.4500000000000004,
        1.5500000000000005,
        1.8500000000000008,
        1.9000000000000008,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.1000000000000005,
        2.1000000000000005,
        2.2,
        2.3499999999999996,
        2.3499999999999996,
        2.3499999999999996,
        2.3499999999999996,
        2.3999999999999995,
        2.3999999999999995,
        2.6499999999999986,
        2.6999999999999984,
        2.8999999999999977,
        2.9499999999999975,
        3.049999999999997,
        3.049999999999997,
        3.049999999999997,
    ]
    return {
        "long_factor": long_factor,
        "short_factor": short_factor,
        "type": "su",
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the two optional command-line arguments and
    # hand the resulting namespace to main().
    cli = argparse.ArgumentParser()

    # Existing local path, or the id passed to snapshot_download when the
    # path does not exist.
    cli.add_argument(
        "--origin_ckpt_path",
        type=str,
        default="Shitao/OmniGen-v1",
        required=False,
        help="Path to the checkpoint to convert.",
    )

    # Destination handed to pipeline.save_pretrained().
    cli.add_argument(
        "--dump_path",
        type=str,
        default="OmniGen-v1-diffusers",
        required=False,
        help="Path to the output pipeline.",
    )

    args = cli.parse_args()
    main(args)
|