#!/usr/bin/env python3
"""
Script to convert a PRX checkpoint from the original codebase to the diffusers format.
"""

import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Dict, Tuple

import torch
from safetensors.torch import save_file

from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline


DEFAULT_RESOLUTION = 512


@dataclass(frozen=True)
class PRXBase:
    context_in_dim: int = 2304
    hidden_size: int = 1792
    mlp_ratio: float = 3.5
    num_heads: int = 28
    depth: int = 16
    axes_dim: Tuple[int, int] = (32, 32)
    theta: int = 10_000
    time_factor: float = 1000.0
    time_max_period: int = 10_000
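
# Note: hidden_size / num_heads = 1792 / 28 = 64 per attention head, and
# axes_dim = (32, 32) sums to that head dimension, consistent with a 2D
# rotary embedding that splits each head across the two spatial axes.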


@dataclass(frozen=True)
class PRXFlux(PRXBase):
    in_channels: int = 16
    patch_size: int = 2


@dataclass(frozen=True)
class PRXDCAE(PRXBase):
    in_channels: int = 32
    patch_size: int = 1


def build_config(vae_type: str) -> dict:
    """Build the transformer config dict for the requested VAE type."""
    if vae_type == "flux":
        cfg = PRXFlux()
    elif vae_type == "dc-ae":
        cfg = PRXDCAE()
    else:
        raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")

    config_dict = asdict(cfg)
    # JSON has no tuple type, so store axes_dim as a list for serialization.
    config_dict["axes_dim"] = list(config_dict["axes_dim"])
    return config_dict
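
# For example, build_config("flux") returns the shared PRXBase fields plus
# in_channels=16 and patch_size=2, while build_config("dc-ae") swaps in
# in_channels=32 and patch_size=1 for the 32-channel DC-AE latent space.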


def create_parameter_mapping(depth: int) -> dict:
    """Create a mapping from old parameter names to new diffusers names."""
    # Key mappings for structural changes
    mapping = {}

    # Map the old structure (layers in PRXBlock) to the new structure (layers in PRXAttention)
    for i in range(depth):
        # QKV projections moved to the attention module
        mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
        mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"

        # QK norm moved to the attention module and renamed to match Attention's qk_norm
        # structure; the ".scale" and ".weight" spellings map to the same target, so either
        # naming convention in the source checkpoint is covered.
        mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
        mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
        mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
        mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"

        # K norm for text tokens moved to the attention module
        mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
        mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"

        # Attention output projection
        mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"

    return mapping
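
# With depth=16, for example, "blocks.0.img_qkv_proj.weight" is rewritten to
# "blocks.0.attention.img_qkv_proj.weight", and likewise for blocks 1-15.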


def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
    """Convert old checkpoint parameters to the new diffusers format."""
    print("Converting checkpoint parameters...")

    mapping = create_parameter_mapping(depth)
    converted_state_dict = {}

    for key, value in old_state_dict.items():
        new_key = key

        # Apply a specific mapping if one exists; otherwise keep the key as-is.
        if key in mapping:
            new_key = mapping[key]
            print(f"  Mapped: {key} -> {new_key}")

        converted_state_dict[new_key] = value

    print(f"✓ Converted {len(converted_state_dict)} parameters")
    return converted_state_dict
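
# Keys without an entry in the mapping pass through unchanged, so a checkpoint
# already saved in the new layout converts losslessly.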


def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
    """Create and load a PRXTransformer2DModel from an old checkpoint."""
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Load the old checkpoint
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    old_checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Handle different checkpoint formats: raw state dicts as well as dicts
    # that nest the weights under a "model" or "state_dict" key.
    if isinstance(old_checkpoint, dict):
        if "model" in old_checkpoint:
            state_dict = old_checkpoint["model"]
        elif "state_dict" in old_checkpoint:
            state_dict = old_checkpoint["state_dict"]
        else:
            state_dict = old_checkpoint
    else:
        state_dict = old_checkpoint

    print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")

    # Convert parameter names if needed
    model_depth = int(config.get("depth", 16))
    converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)

    # Create the transformer with the config
    print("Creating PRXTransformer2DModel...")
    transformer = PRXTransformer2DModel(**config)

    # Load the state dict; strict=False so that mismatched keys are reported
    # below instead of raising immediately.
    print("Loading converted parameters...")
    missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)

    if missing_keys:
        print(f"⚠ Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"⚠ Unexpected keys: {unexpected_keys}")
    if not missing_keys and not unexpected_keys:
        print("✓ All parameters loaded successfully!")

    return transformer


def create_scheduler_config(output_path: str, shift: float):
    """Create the FlowMatchEulerDiscreteScheduler config."""
    scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}

    scheduler_path = os.path.join(output_path, "scheduler")
    os.makedirs(scheduler_path, exist_ok=True)

    with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
        json.dump(scheduler_config, f, indent=2)

    print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
    """Download and save the VAE to a local directory."""
    from diffusers import AutoencoderDC, AutoencoderKL

    vae_path = os.path.join(output_path, "vae")
    os.makedirs(vae_path, exist_ok=True)

    if vae_type == "flux":
        print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
        vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
    else:  # dc-ae
        print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
        vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")

    vae.save_pretrained(vae_path)
    print(f"✓ Saved VAE to {vae_path}")


def download_and_save_text_encoder(output_path: str):
    """Download and save the T5Gemma text encoder and tokenizer."""
    from transformers import GemmaTokenizerFast
    from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel

    text_encoder_path = os.path.join(output_path, "text_encoder")
    tokenizer_path = os.path.join(output_path, "tokenizer")
    os.makedirs(text_encoder_path, exist_ok=True)
    os.makedirs(tokenizer_path, exist_ok=True)

    print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
    t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")

    # Extract and save only the encoder; the pipeline uses T5GemmaEncoder as its
    # text encoder, so the decoder is not needed.
    t5gemma_encoder = t5gemma_model.encoder
    t5gemma_encoder.save_pretrained(text_encoder_path)
    print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")

    print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
    tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
    tokenizer.model_max_length = 256
    tokenizer.save_pretrained(tokenizer_path)
    print(f"✓ Saved tokenizer to {tokenizer_path}")


def create_model_index(vae_type: str, default_image_size: int, output_path: str):
    """Create model_index.json for the pipeline."""
    if vae_type == "flux":
        vae_class = "AutoencoderKL"
    else:  # dc-ae
        vae_class = "AutoencoderDC"

    model_index = {
        "_class_name": "PRXPipeline",
        "_diffusers_version": "0.31.0.dev0",
        "_name_or_path": os.path.basename(output_path),
        "default_sample_size": default_image_size,
        "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
        "text_encoder": ["prx", "T5GemmaEncoder"],
        "tokenizer": ["transformers", "GemmaTokenizerFast"],
        "transformer": ["diffusers", "PRXTransformer2DModel"],
        "vae": ["diffusers", vae_class],
    }

    model_index_path = os.path.join(output_path, "model_index.json")
    with open(model_index_path, "w") as f:
        json.dump(model_index, f, indent=2)


def main(args):
    # Validate inputs
    if not os.path.exists(args.checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")

    config = build_config(args.vae_type)

    # Create the output directory
    os.makedirs(args.output_path, exist_ok=True)
    print(f"✓ Output directory: {args.output_path}")

    # Create the transformer from the checkpoint
    transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)

    # Save the transformer
    transformer_path = os.path.join(args.output_path, "transformer")
    os.makedirs(transformer_path, exist_ok=True)

    # Save the config
    with open(os.path.join(transformer_path, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    # Save the model weights as safetensors
    state_dict = transformer.state_dict()
    save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
    print(f"✓ Saved transformer to {transformer_path}")

    # Create the scheduler config
    create_scheduler_config(args.output_path, args.shift)

    # Download and save the VAE, text encoder, and tokenizer
    download_and_save_vae(args.vae_type, args.output_path)
    download_and_save_text_encoder(args.output_path)

    # Create model_index.json
    create_model_index(args.vae_type, args.resolution, args.output_path)

    # Verify that the pipeline can be loaded
    try:
        pipeline = PRXPipeline.from_pretrained(args.output_path)
        print("Pipeline loaded successfully!")
        print(f"Transformer: {type(pipeline.transformer).__name__}")
        print(f"VAE: {type(pipeline.vae).__name__}")
        print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
        print(f"Scheduler: {type(pipeline.scheduler).__name__}")

        # Display model info
        num_params = sum(p.numel() for p in pipeline.transformer.parameters())
        print(f"✓ Transformer parameters: {num_params:,}")
    except Exception as e:
        print(f"Pipeline verification failed: {e}")
        return False

    print("Conversion completed successfully!")
    print(f"Converted pipeline saved to: {args.output_path}")
    print(f"VAE type: {args.vae_type}")
    return True
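
# Once converted, the pipeline can be used like any diffusers text-to-image
# pipeline. A minimal sketch (the prompt and the standard __call__ usage are
# assumptions, not taken from this script):
#
#   pipe = PRXPipeline.from_pretrained("./prx-diffusers")
#   image = pipe("a photo of an astronaut riding a horse").images[0]
#   image.save("prx_sample.png")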


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert a PRX checkpoint to the diffusers format")

    parser.add_argument(
        "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file)"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
    )
    parser.add_argument(
        "--vae_type",
        type=str,
        choices=["flux", "dc-ae"],
        required=True,
        help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        choices=[256, 512, 1024],
        default=DEFAULT_RESOLUTION,
        help="Target resolution for the model (256, 512, or 1024). Sets the pipeline's default_sample_size.",
    )
    parser.add_argument(
        "--shift",
        type=float,
        default=3.0,
        help="Shift for the scheduler",
    )

    args = parser.parse_args()

    try:
        success = main(args)
        if not success:
            sys.exit(1)
    except Exception as e:
        print(f"Conversion failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)