huggingface/diffusers (mirror of https://github.com/huggingface/diffusers.git)

Compare commits: modular-te...support-sf (4 commits)
| SHA1 |
|---|
| bae53f6004 |
| 41a56c80dc |
| f153ac10ed |
| 8d631e9684 |
```diff
@@ -50,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
 - [`StableCascadeUNet`]
 - [`AutoencoderKL`]
 - [`ControlNetModel`]
+- [`PixArtTransformer2DModel`]
 - [`SD3Transformer2DModel`]

 ## FromSingleFileMixin
```
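Once this lands, loading mirrors the new test at the bottom of this compare; a minimal sketch (repo id and checkpoint URL taken from that test):

```python
from diffusers import PixArtSigmaPipeline, PixArtTransformer2DModel

# Values below come from tests/single_file/test_pixart_sigma_single_file.py.
ckpt_path = "https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth"
transformer = PixArtTransformer2DModel.from_single_file(ckpt_path, original_config=True)
pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", transformer=transformer)
```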
```diff
@@ -24,9 +24,11 @@ from .single_file_utils import (
     convert_controlnet_checkpoint,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
+    convert_pixart_transformer_single_file_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
     create_controlnet_diffusers_config_from_ldm,
+    create_diffusers_config_from_pixart,
     create_unet_diffusers_config_from_ldm,
     create_vae_diffusers_config_from_ldm,
     fetch_diffusers_config,
```
```diff
@@ -65,6 +67,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_controlnet_checkpoint,
         "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
     },
+    "PixArtTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_pixart_transformer_single_file_to_diffusers,
+        "config_mapping_fn": create_diffusers_config_from_pixart,
+    },
     "SD3Transformer2DModel": {
         "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
```
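For orientation, `from_single_file` looks up the target class in this dict and applies the two mapping functions in sequence; a simplified sketch of that flow (the locals here are illustrative, the real method handles many more options):

```python
# Illustrative sketch only, not the actual loader code.
mapping = SINGLE_FILE_LOADABLE_CLASSES["PixArtTransformer2DModel"]
diffusers_config = mapping["config_mapping_fn"](original_config, checkpoint, sample_size=None)
diffusers_state_dict = mapping["checkpoint_mapping_fn"](checkpoint=checkpoint, config=diffusers_config)
```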
```diff
@@ -213,7 +219,7 @@ class FromOriginalModelMixin:
                 )
             )

-        if isinstance(original_config, str):
+        if isinstance(original_config, str) and "PixArt" not in class_name:
             # If original_config is a URL or filepath fetch the original_config dict
             original_config = fetch_original_config(original_config, local_files_only=local_files_only)

```
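This guard matters because PixArt checkpoints have no LDM-style YAML to fetch: the new test passes `original_config=True`, and treating that value as a URL or filepath would fail. Roughly, the two call styles look like this (the checkpoint paths and YAML name are illustrative placeholders):

```python
from diffusers import PixArtTransformer2DModel, UNet2DConditionModel

# LDM-style models: original_config is a YAML URL/filepath that gets fetched and parsed.
unet = UNet2DConditionModel.from_single_file("sd15.ckpt", original_config="v1-inference.yaml")

# PixArt: original_config is forwarded as-is to create_diffusers_config_from_pixart.
transformer = PixArtTransformer2DModel.from_single_file("pixart.pth", original_config=True)
```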
```diff
@@ -244,6 +244,9 @@ SCHEDULER_DEFAULT_CONFIG = {
     "timestep_spacing": "leading",
 }

+# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
+PIXART_INTERPOLATION_SCALE = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
+
 LDM_VAE_KEY = "first_stage_model."
 LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
```
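The dict is keyed by the checkpoint's native pixel resolution, i.e. the latent `sample_size` times the VAE downscale factor of 8:

```python
# A 1024px checkpoint uses a latent sample_size of 128, so its interpolation scale is:
assert PIXART_INTERPOLATION_SCALE[128 * 8] == 2
```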
@@ -1601,6 +1604,152 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):

```python
    return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}


def convert_pixart_transformer_single_file_to_diffusers(checkpoint, **kwargs):
    checkpoint = checkpoint.pop("state_dict") if "state_dict" in checkpoint else checkpoint
    converted_state_dict = {}

    # Patch embeddings.
    x_embedder_present = any("x_embedder" in key for key in checkpoint)
    if x_embedder_present:
        converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
        converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

    # Caption projection.
    y_embedder_present = any("y_embedder" in key for key in checkpoint)
    if y_embedder_present:
        converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
        converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
        converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
        converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single timestep embedder.
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")

    # Micro-conditioning weights keep the original csize_embedder/ar_embedder names in the checkpoint.
    micro_condition = any(("csize_embedder" in key or "ar_embedder" in key) for key in checkpoint)
    if micro_condition:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = checkpoint.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = checkpoint.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = checkpoint.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = checkpoint.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = checkpoint.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = checkpoint.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = checkpoint.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = checkpoint.pop(
            "ar_embedder.mlp.2.bias"
        )

    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = checkpoint.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = checkpoint.pop("t_block.1.bias")

    depths = len({key.split(".")[1] for key in checkpoint if "blocks" in key})
    for depth in range(depths):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = checkpoint.pop(
            f"blocks.{depth}.scale_shift_table"
        )
        # Attention is all you need 🤘

        # Self attention: the original checkpoint stores q, k, v fused along dim 0.
        q, k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(checkpoint.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = checkpoint.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = checkpoint.pop(
            f"blocks.{depth}.attn.proj.bias"
        )
        qk_norm = any("q_norm" in key for key in checkpoint)
        if qk_norm:
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = checkpoint.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = checkpoint.pop(
                f"blocks.{depth}.attn.q_norm.bias"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = checkpoint.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = checkpoint.pop(
                f"blocks.{depth}.attn.k_norm.bias"
            )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = checkpoint.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = checkpoint.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = checkpoint.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = checkpoint.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention: q is separate; k and v are fused along dim 0.
        q = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = checkpoint.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = checkpoint.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")

    # These keys have no diffusers counterpart.
    try:
        checkpoint.pop("y_embedder.y_embedding")
        checkpoint.pop("pos_embed")
    except KeyError as e:
        logger.debug(f"Skipping {str(e)}")

    return converted_state_dict


# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into
# shift, scale; in diffusers it is split into scale, shift. Swap the linear projection weights so the
# diffusers implementation can be used.
def swap_scale_shift(weight, dim):
```
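The `torch.chunk` calls in the new converter assume the original checkpoint stores self-attention q/k/v as one fused matrix stacked along dim 0 (and k/v fused for cross-attention); a toy shape check under that assumption:

```python
import torch

hidden = 1152  # PixArt hidden size: 16 heads x 72 head dim, matching the config below
fused_qkv = torch.randn(3 * hidden, hidden)  # stand-in for a blocks.N.attn.qkv.weight tensor
q, k, v = torch.chunk(fused_qkv, 3, dim=0)
assert q.shape == k.shape == v.shape == (hidden, hidden)
```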
@@ -1753,6 +1902,35 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):

```python
    return converted_state_dict


def create_diffusers_config_from_pixart(original_config, checkpoint, sample_size=None):
    # Micro-conditioning weights keep the original csize_embedder/ar_embedder names in the checkpoint.
    micro_condition = any(("csize_embedder" in key or "ar_embedder" in key) for key in checkpoint)
    if sample_size is None:
        sample_size = 1024 // 8  # default to the 1024px checkpoints (latent resolution)

    num_layers = len({key.split(".")[1] for key in checkpoint if "blocks" in key})

    config = {
        "sample_size": sample_size,
        "num_layers": num_layers,
        "attention_head_dim": 72,
        "in_channels": 4,
        "out_channels": 8,
        "patch_size": 2,
        "attention_bias": True,
        "num_attention_heads": 16,
        "cross_attention_dim": 1152,
        "activation_fn": "gelu-approximate",
        "num_embeds_ada_norm": 1000,
        "norm_type": "ada_norm_single",
        "norm_elementwise_affine": False,
        "norm_eps": 1e-6,
        "caption_channels": 4096,
        "interpolation_scale": PIXART_INTERPOLATION_SCALE[sample_size * 8],
        "use_additional_conditions": micro_condition,
    }
    return config


def is_t5_in_single_file(checkpoint):
    if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
        return True
```
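With `sample_size=None` the config defaults to the 1024px checkpoints (`sample_size` 128, `interpolation_scale` 2); other resolutions must be requested explicitly. A quick sketch, with `state_dict` standing in for a loaded 512px checkpoint:

```python
# `state_dict` is a hypothetical loaded PixArt checkpoint.
config = create_diffusers_config_from_pixart(original_config=True, checkpoint=state_dict, sample_size=512 // 8)
assert config["interpolation_scale"] == 1  # PIXART_INTERPOLATION_SCALE[512]
```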
```diff
@@ -17,6 +17,7 @@ import torch
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -28,7 +29,7 @@ from ..normalization import AdaLayerNormSingle
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+class PixArtTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
     https://arxiv.org/abs/2403.04692).
```
tests/single_file/test_pixart_sigma_single_file.py (new file, 66 lines)

@@ -0,0 +1,66 @@
```python
import gc
import unittest

import torch

from diffusers import PixArtSigmaPipeline, PixArtTransformer2DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)


enable_full_determinism()


@slow
@require_torch_gpu
class PixArtSigmaPipelineSingleFileSlowTests(unittest.TestCase):
    pipeline_class = PixArtSigmaPipeline
    repo_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
    ckpt_path = "https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        inputs = {
            "prompt": "a fantasy landscape, concept art, high resolution",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 7.5,
            "output_type": "np",
        }
        return inputs

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        transformer = PixArtTransformer2DModel.from_single_file(self.ckpt_path, original_config=True)
        sf_pipe = self.pipeline_class.from_pretrained(self.repo_id, transformer=transformer)
        sf_pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        del sf_pipe

        pipe = self.pipeline_class.from_pretrained(self.repo_id)
        pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff
```
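The assertion compares cosine distance over the flattened images, so identical outputs score 0.0 and the 1e-4 budget only tolerates tiny numerical drift. Assuming the helper is equivalent to the usual definition:

```python
import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: 0.0 for identical images, approaching 2.0 for opposite ones.
    a, b = a.flatten(), b.flatten()
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```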