Compare commits

...

4 Commits

Author    | SHA1       | Message                                          | Date
sayakpaul | bae53f6004 | resolve conflicts.                               | 2024-06-23 17:14:50 +05:30
sayakpaul | 41a56c80dc | add: test                                        | 2024-06-07 11:36:06 +05:30
sayakpaul | f153ac10ed | fix single file support                          | 2024-06-07 11:32:15 +05:30
sayakpaul | 8d631e9684 | add single file loading support to PixArt Sigma  | 2024-06-07 10:10:13 +05:30
5 changed files with 254 additions and 2 deletions

View File

@@ -50,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`ControlNetModel`]
- [`PixArtTransformer2DModel`]
- [`SD3Transformer2DModel`]
## FromSingleFileMixin
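
For context, a minimal sketch of the loading path this new doc entry covers, mirroring the slow test added at the bottom of this PR (the checkpoint URL, repo id, and the `original_config=True` argument are taken from that test; anything else here is illustrative):

```python
from diffusers import PixArtSigmaPipeline, PixArtTransformer2DModel

# Load only the transformer from the original single-file .pth checkpoint,
# then drop it into the pretrained pipeline (same pattern as the new slow test).
ckpt_path = "https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth"
transformer = PixArtTransformer2DModel.from_single_file(ckpt_path, original_config=True)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", transformer=transformer
)
```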

View File

@@ -24,9 +24,11 @@ from .single_file_utils import (
convert_controlnet_checkpoint,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_pixart_transformer_single_file_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_diffusers_config_from_pixart,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
fetch_diffusers_config,
@@ -65,6 +67,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
},
"PixArtTransformer2DModel": {
"checkpoint_mapping_fn": convert_pixart_transformer_single_file_to_diffusers,
"config_mapping_fn": create_diffusers_config_from_pixart,
},
"SD3Transformer2DModel": {
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
@@ -213,7 +219,7 @@ class FromOriginalModelMixin:
)
)
if isinstance(original_config, str):
if isinstance(original_config, str) and "PixArt" not in class_name:
# If original_config is a URL or filepath fetch the original_config dict
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
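
For orientation, a simplified sketch of how a `SINGLE_FILE_LOADABLE_CLASSES` entry like the PixArt one above is consumed: the registered `config_mapping_fn` builds a diffusers config directly from the checkpoint (which is why the `"PixArt" not in class_name` branch above skips fetching an `original_config` file), and the `checkpoint_mapping_fn` converts the state dict. The helper name and call shape below are illustrative assumptions, not the actual `FromOriginalModelMixin` code:

```python
# Hedged sketch of the registry pattern only; the real from_single_file has many
# more branches (config fetching, quantization, dtype handling, ...).
def _resolve_from_registry(class_name, checkpoint, **kwargs):
    entry = SINGLE_FILE_LOADABLE_CLASSES[class_name]   # e.g. "PixArtTransformer2DModel"
    convert_fn = entry["checkpoint_mapping_fn"]        # convert_pixart_transformer_single_file_to_diffusers
    config_fn = entry.get("config_mapping_fn")         # create_diffusers_config_from_pixart

    config = config_fn(original_config=None, checkpoint=checkpoint) if config_fn else None
    diffusers_state_dict = convert_fn(checkpoint, **kwargs)
    return config, diffusers_state_dict
```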

View File

@@ -244,6 +244,9 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing": "leading",
}
# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
PIXART_INTERPOLATION_SCALE = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
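
The new `PIXART_INTERPOLATION_SCALE` constant is keyed by pixel resolution; `create_diffusers_config_from_pixart` below indexes it with the latent `sample_size * 8`. A quick illustration using the values from the constant above:

```python
PIXART_INTERPOLATION_SCALE = {256: 0.5, 512: 1, 1024: 2, 2048: 4}

sample_size = 1024 // 8                              # latent sample size of the 1024px Sigma checkpoint
scale = PIXART_INTERPOLATION_SCALE[sample_size * 8]  # 1024 -> 2
assert scale == 2
```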
@@ -1601,6 +1604,152 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
def convert_pixart_transformer_single_file_to_diffusers(checkpoint, **kwargs):
checkpoint = checkpoint.pop("state_dict") if "state_dict" in checkpoint else checkpoint
converted_state_dict = {}
# Patch embeddings.
x_embedder_present = any("x_embedder" in key for key in checkpoint)
if x_embedder_present:
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
# Caption projection.
y_embedder_present = any("y_embedder" in key for key in checkpoint)
if y_embedder_present:
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
micro_condition = any(("resolution_embedder" in key or "aspect_ratio_embedder" in key) for key in checkpoint)
if micro_condition:
# Resolution.
converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = checkpoint.pop(
"csize_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = checkpoint.pop(
"csize_embedder.mlp.0.bias"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = checkpoint.pop(
"csize_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = checkpoint.pop(
"csize_embedder.mlp.2.bias"
)
# Aspect ratio.
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = checkpoint.pop(
"ar_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = checkpoint.pop(
"ar_embedder.mlp.0.bias"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = checkpoint.pop(
"ar_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = checkpoint.pop(
"ar_embedder.mlp.2.bias"
)
# Shared norm.
converted_state_dict["adaln_single.linear.weight"] = checkpoint.pop("t_block.1.weight")
converted_state_dict["adaln_single.linear.bias"] = checkpoint.pop("t_block.1.bias")
depths = len({key.split(".")[1] for key in checkpoint if "blocks" in key})
for depth in range(depths):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = checkpoint.pop(
f"blocks.{depth}.scale_shift_table"
)
# Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(checkpoint.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = checkpoint.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = checkpoint.pop(
f"blocks.{depth}.attn.proj.bias"
)
qk_norm = any("q_norm" in key for key in checkpoint)
if qk_norm:
converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = checkpoint.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = checkpoint.pop(
f"blocks.{depth}.attn.q_norm.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = checkpoint.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = checkpoint.pop(
f"blocks.{depth}.attn.k_norm.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = checkpoint.pop(
f"blocks.{depth}.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = checkpoint.pop(
f"blocks.{depth}.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = checkpoint.pop(
f"blocks.{depth}.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = checkpoint.pop(
f"blocks.{depth}.mlp.fc2.bias"
)
# Cross-attention.
q = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = checkpoint.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(checkpoint.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = checkpoint.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = checkpoint.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
try:
checkpoint.pop("y_embedder.y_embedding")
checkpoint.pop("pos_embed")
except Exception as e:
logger.debug(f"Skipping {str(e)}")
pass
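
The self-attention conversion above assumes the original checkpoint stores Q, K, V as one fused `attn.qkv` projection and splits it with `torch.chunk` along dim 0. A standalone illustration of that split, using the 1152 hidden size implied by the config below (16 heads × 72 head dim); the tensors are random stand-ins, not real checkpoint weights:

```python
import torch

hidden_size = 1152                                   # 16 heads * 72 head dim (see config below)
fused_w = torch.randn(3 * hidden_size, hidden_size)  # stand-in for blocks.N.attn.qkv.weight
fused_b = torch.randn(3 * hidden_size)               # stand-in for blocks.N.attn.qkv.bias

q_w, k_w, v_w = torch.chunk(fused_w, 3, dim=0)       # same split as in the converter above
q_b, k_b, v_b = torch.chunk(fused_b, 3, dim=0)
assert q_w.shape == (hidden_size, hidden_size) and q_b.shape == (hidden_size,)
```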
# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale,
# while in diffusers it is split into scale, shift. Swap the linear projection weights here so the diffusers implementation can be used.
def swap_scale_shift(weight, dim):
@@ -1753,6 +1902,35 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
return converted_state_dict
def create_diffusers_config_from_pixart(original_config, checkpoint, sample_size=None):
micro_condition = any(("resolution_embedder" in key or "aspect_ratio_embedder" in key) for key in checkpoint)
if sample_size is None:
sample_size = 1024 // 8
num_layers = len({key.split(".")[1] for key in checkpoint if "blocks" in key})
config = {
"sample_size": sample_size,
"num_layers": num_layers,
"attention_head_dim": 72,
"in_channels": 4,
"out_channels": 8,
"patch_size": 2,
"attention_bias": True,
"num_attention_heads": 16,
"cross_attention_dim": 1152,
"activation_fn": "gelu-approximate",
"num_embeds_ada_norm": 1000,
"norm_type": "ada_norm_single",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 4096,
"interpolation_scale": PIXART_INTERPOLATION_SCALE[sample_size * 8],
"use_additional_conditions": micro_condition,
}
return config
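
The `num_layers` value above is inferred from the checkpoint keys themselves rather than from any external config. A toy illustration of that derivation (the keys are made up for the example):

```python
checkpoint_keys = [
    "blocks.0.attn.qkv.weight",
    "blocks.0.mlp.fc1.weight",
    "blocks.1.attn.qkv.weight",
    "blocks.1.mlp.fc1.weight",
]
# Same expression as in create_diffusers_config_from_pixart above.
num_layers = len({key.split(".")[1] for key in checkpoint_keys if "blocks" in key})
assert num_layers == 2
```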
def is_t5_in_single_file(checkpoint):
if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
return True

View File

@@ -17,6 +17,7 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -28,7 +29,7 @@ from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
class PixArtTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
https://arxiv.org/abs/2403.04692).

View File

@@ -0,0 +1,66 @@
import gc
import unittest
import torch
from diffusers import PixArtSigmaPipeline, PixArtTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
enable_full_determinism()
@slow
@require_torch_gpu
class PixArtSigmaPipelineSingleFileSlowTests(unittest.TestCase):
pipeline_class = PixArtSigmaPipeline
repo_id = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
ckpt_path = "https://huggingface.co/PixArt-alpha/PixArt-Sigma/blob/main/PixArt-Sigma-XL-2-1024-MS.pth"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"generator": generator,
"num_inference_steps": 2,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
transformer = PixArtTransformer2DModel.from_single_file(self.ckpt_path, original_config=True)
sf_pipe = self.pipeline_class.from_pretrained(self.repo_id, transformer=transformer)
sf_pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image_single_file = sf_pipe(**inputs).images[0]
del sf_pipe
pipe = self.pipeline_class.from_pretrained(self.repo_id)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < expected_max_diff