Fix PixArt 256px inference (#6789)

* fix 256px diffusers inference bug

* move the T5 `max_length` into the pipeline config file

* fix bug in convert_pixart_alpha_to_diffusers.py

* Update scripts/convert_pixart_alpha_to_diffusers.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove the `multi_scale_train` parser argument

* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* styling

* change `model_token_max_length` to a call argument.

* Refactoring

* add `max_sequence_length` to the docstring.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Author: Junsong Chen
Date: 2024-03-03 13:01:21 +08:00
Committed by: GitHub
Parent: ccb93dcad1
Commit: f55873b783
3 changed files with 59 additions and 11 deletions
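
For context, a minimal usage sketch of what this change enables. Hedged: the checkpoint path below is a placeholder for a 256px PixArt-Alpha checkpoint converted with the script touched in this PR; `max_sequence_length` and `use_resolution_binning` are the pipeline call arguments shown in the diff.

import torch
from diffusers import PixArtAlphaPipeline

# Placeholder path: point this at a 256px checkpoint converted with
# scripts/convert_pixart_alpha_to_diffusers.py (--image_size 256).
pipe = PixArtAlphaPipeline.from_pretrained(
    "path/to/pixart-alpha-256px-diffusers", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="A small cactus wearing a straw hat in the desert",
    max_sequence_length=120,      # new call argument added in this PR
    use_resolution_binning=True,  # snaps height/width to ASPECT_RATIO_256_BIN
).images[0]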

scripts/convert_pixart_alpha_to_diffusers.py

@@ -9,11 +9,11 @@ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPip
 ckpt_id = "PixArt-alpha/PixArt-alpha"
 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
-interpolation_scale = {512: 1, 1024: 2}
+interpolation_scale = {256: 0.5, 512: 1, 1024: 2}
 def main(args):
-    all_state_dict = torch.load(args.orig_ckpt_path)
+    all_state_dict = torch.load(args.orig_ckpt_path, map_location="cpu")
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}
@@ -22,7 +22,6 @@ def main(args):
     converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
     # Caption projection.
-    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
     converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
     converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
     converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
@@ -155,6 +154,7 @@ def main(args):
     assert transformer.pos_embed.pos_embed is not None
     state_dict.pop("pos_embed")
+    state_dict.pop("y_embedder.y_embedding")
     assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
     num_model_params = sum(p.numel() for p in transformer.parameters())
@@ -187,7 +187,7 @@ if __name__ == "__main__":
         "--image_size",
         default=1024,
         type=int,
-        choices=[512, 1024],
+        choices=[256, 512, 1024],
         required=False,
         help="Image size of pretrained model, either 512 or 1024.",
     )
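
A tiny illustrative sketch (not from the script) of how the extended table resolves for the newly supported size; the // 8 factor is the standard VAE downsampling and is shown only to connect the 256px checkpoint to its 32x32 latent grid.

# Illustrative only: resolving the table for a 256px checkpoint.
interpolation_scale = {256: 0.5, 512: 1, 1024: 2}

image_size = 256                      # new --image_size choice added above
latent_sample_size = image_size // 8  # 32x32 latent grid for the transformer
print(latent_sample_size, interpolation_scale[image_size])  # 32 0.5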

transformer_2d.py (Transformer2DModel)

@@ -97,6 +97,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         attention_type: str = "default",
         caption_channels: int = None,
+        interpolation_scale: float = None,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -168,8 +169,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             self.width = sample_size
             self.patch_size = patch_size
-            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
-            interpolation_scale = max(interpolation_scale, 1)
+            interpolation_scale = (
+                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
+            )
             self.pos_embed = PatchEmbed(
                 height=sample_size,
                 width=sample_size,
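
Why the new `interpolation_scale` argument is needed, evaluated by hand (illustrative helper mirroring the expression added above): for 1024px (sample_size 128) and 512px (sample_size 64) checkpoints the old fallback already gave the right value, but for 256px (sample_size 32) it clamps to 1 instead of 0.5, so the value now comes from the config when provided.

# Illustrative: the selection logic added above, for the three resolutions.
def resolve_scale(sample_size, interpolation_scale=None):
    return interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)

print(resolve_scale(128))      # 2   -> 1024px, fallback already correct
print(resolve_scale(64))       # 1   -> 512px, fallback already correct
print(resolve_scale(32))       # 1   -> 256px, fallback would be wrong
print(resolve_scale(32, 0.5))  # 0.5 -> explicit value from the converted config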

src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

@@ -133,6 +133,42 @@ ASPECT_RATIO_512_BIN = {
     "4.0": [1024.0, 256.0],
 }
+ASPECT_RATIO_256_BIN = {
+    "0.25": [128.0, 512.0],
+    "0.28": [128.0, 464.0],
+    "0.32": [144.0, 448.0],
+    "0.33": [144.0, 432.0],
+    "0.35": [144.0, 416.0],
+    "0.4": [160.0, 400.0],
+    "0.42": [160.0, 384.0],
+    "0.48": [176.0, 368.0],
+    "0.5": [176.0, 352.0],
+    "0.52": [176.0, 336.0],
+    "0.57": [192.0, 336.0],
+    "0.6": [192.0, 320.0],
+    "0.68": [208.0, 304.0],
+    "0.72": [208.0, 288.0],
+    "0.78": [224.0, 288.0],
+    "0.82": [224.0, 272.0],
+    "0.88": [240.0, 272.0],
+    "0.94": [240.0, 256.0],
+    "1.0": [256.0, 256.0],
+    "1.07": [256.0, 240.0],
+    "1.13": [272.0, 240.0],
+    "1.21": [272.0, 224.0],
+    "1.29": [288.0, 224.0],
+    "1.38": [288.0, 208.0],
+    "1.46": [304.0, 208.0],
+    "1.67": [320.0, 192.0],
+    "1.75": [336.0, 192.0],
+    "2.0": [352.0, 176.0],
+    "2.09": [368.0, 176.0],
+    "2.4": [384.0, 160.0],
+    "2.5": [400.0, 160.0],
+    "3.0": [432.0, 144.0],
+    "4.0": [512.0, 128.0],
+}
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
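
A rough sketch of how the new bin table is consumed. The helper below only approximates what the pipeline's `classify_height_width_bin` does and is not the library implementation; it assumes ASPECT_RATIO_256_BIN as defined above.

# Snap a requested size to the closest aspect-ratio bucket in the table.
def nearest_bin(height, width, ratios):
    aspect = float(height) / float(width)
    closest = min(ratios, key=lambda key: abs(float(key) - aspect))
    bin_height, bin_width = ratios[closest]
    return int(bin_height), int(bin_width)

print(nearest_bin(300, 200, ASPECT_RATIO_256_BIN))  # (304, 208) for ratio 1.5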
@@ -260,6 +296,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
+        max_sequence_length: int = 120,
         **kwargs,
     ):
         r"""
@@ -284,8 +321,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
                 string.
-            clean_caption (bool, defaults to `False`):
+            clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
         """
         if "mask_feature" in kwargs:
@@ -303,7 +341,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
         # See Section 3.1. of the paper.
-        max_length = 120
+        max_length = max_sequence_length
         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
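
For reference, the hard-coded 120 above was the T5 tokenizer's max_length; after this change the value is whatever the caller passes. A hedged sketch of the tokenization it feeds into, using standard transformers tokenizer kwargs; the exact surrounding code lives in encode_prompt, and pipe/prompt/max_sequence_length are assumed from the earlier usage sketch.

# Sketch: the prompt is padded/truncated to max_sequence_length tokens.
# pipe.tokenizer is the T5 tokenizer loaded with the pipeline.
text_inputs = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=max_sequence_length,  # previously the literal 120
    truncation=True,
    return_tensors="pt",
)
prompt_attention_mask = text_inputs.attention_mask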
@@ -688,6 +726,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
+        max_sequence_length: int = 120,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -757,6 +796,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                 the requested resolution. Useful for generating non-square images.
+            max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
         Examples:
@@ -772,9 +812,14 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
         if use_resolution_binning:
-            aspect_ratio_bin = (
-                ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
-            )
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+            elif self.transformer.config.sample_size == 64:
+                aspect_ratio_bin = ASPECT_RATIO_512_BIN
+            elif self.transformer.config.sample_size == 32:
+                aspect_ratio_bin = ASPECT_RATIO_256_BIN
+            else:
+                raise ValueError("Invalid sample size")
             orig_height, orig_width = height, width
             height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
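
Continuing the usage sketch from the top of this page (hypothetical sizes, with `pipe` as defined there): on a 256px checkpoint (sample_size 32), a request for 300x200 with `use_resolution_binning=True` is generated at the nearest ASPECT_RATIO_256_BIN bucket, 304x208, and then resized back to 300x200 as described in the docstring above.

image = pipe(
    prompt="A small cactus wearing a straw hat in the desert",
    height=300,   # snapped to 304 internally, resized back after decoding
    width=200,    # snapped to 208 internally, resized back after decoding
    use_resolution_binning=True,
).images[0]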
@@ -822,6 +867,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
+            max_sequence_length=max_sequence_length,
         )
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)