Fix PixArt 256px inference (#6789)
* feat 256px diffusers inference bug
* change the max_length of T5 to pipeline config file
* fix bug in convert_pixart_alpha_to_diffusers.py
* Update scripts/convert_pixart_alpha_to_diffusers.py
* remove multi_scale_train parser
* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
* Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
* styling
* change `model_token_max_length` to call argument.
* Refactoring
* add: max_sequence_length to the docstring.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
scripts/convert_pixart_alpha_to_diffusers.py
@@ -9,11 +9,11 @@ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPip
 ckpt_id = "PixArt-alpha/PixArt-alpha"
 # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
-interpolation_scale = {512: 1, 1024: 2}
+interpolation_scale = {256: 0.5, 512: 1, 1024: 2}


 def main(args):
-    all_state_dict = torch.load(args.orig_ckpt_path)
+    all_state_dict = torch.load(args.orig_ckpt_path, map_location="cpu")
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}
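For context: PixArt's VAE downsamples images 8x, and the 512px checkpoint's 64-wide latent grid is the positional-embedding reference, so each size's interpolation scale is its latent size divided by 64. The `map_location="cpu"` change additionally lets the checkpoint load on machines without a GPU. A quick sanity-check sketch (not part of the script):

```python
# Sanity check for the interpolation_scale mapping above (illustrative only).
# scale = latent_size / 64, where latent_size = image_size // 8.
for image_size, scale in {256: 0.5, 512: 1, 1024: 2}.items():
    latent_size = image_size // 8
    assert scale == latent_size / 64
```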
@@ -22,7 +22,6 @@ def main(args):
     converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

     # Caption projection.
-    converted_state_dict["caption_projection.y_embedding"] = state_dict.pop("y_embedder.y_embedding")
     converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
     converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
     converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
@@ -155,6 +154,7 @@ def main(args):

     assert transformer.pos_embed.pos_embed is not None
     state_dict.pop("pos_embed")
+    state_dict.pop("y_embedder.y_embedding")
     assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

     num_model_params = sum(p.numel() for p in transformer.parameters())
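These two hunks move `y_embedder.y_embedding` from the converted weights to the discarded keys: the diffusers module no longer stores it as `caption_projection.y_embedding`, so the checkpoint tensor is popped and dropped, and the final assert still guarantees every original key was accounted for. A minimal sketch of that pop-and-assert pattern (illustrative, not the full converter):

```python
# Pop-and-assert conversion sketch: every key is either mapped or
# deliberately discarded; anything left over fails loudly.
def convert(state_dict: dict) -> dict:
    converted = {}
    # Map each original key to its diffusers name, consuming it as we go.
    converted["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    # Tensors that diffusers recomputes or no longer stores are dropped outright.
    state_dict.pop("pos_embed")
    state_dict.pop("y_embedder.y_embedding")
    # Any leftover key means the mapping above is incomplete.
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
    return converted
```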
@@ -187,7 +187,7 @@ if __name__ == "__main__":
         "--image_size",
         default=1024,
         type=int,
-        choices=[512, 1024],
+        choices=[256, 512, 1024],
         required=False,
         help="Image size of pretrained model, either 512 or 1024.",
     )
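Note that the help string in the context above still reads "either 512 or 1024" even though 256 is now an accepted choice. A sketch of the parser entry with the help text brought in line (the wording here is a suggestion, not from the PR):

```python
# Updated argparse entry with help text matching the new choices.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--image_size",
    default=1024,
    type=int,
    choices=[256, 512, 1024],
    required=False,
    help="Image size of pretrained model: 256, 512, or 1024.",
)
print(parser.parse_args(["--image_size", "256"]).image_size)  # 256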
src/diffusers/models/transformer_2d.py
@@ -97,6 +97,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         attention_type: str = "default",
         caption_channels: int = None,
+        interpolation_scale: float = None,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -168,8 +169,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             self.width = sample_size

             self.patch_size = patch_size
-            interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
-            interpolation_scale = max(interpolation_scale, 1)
+            interpolation_scale = (
+                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
+            )
             self.pos_embed = PatchEmbed(
                 height=sample_size,
                 width=sample_size,
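This is why the new constructor argument is needed: the derived default floors at 1, so a 256px checkpoint (latent `sample_size` 32) would get scale 1 instead of the 0.5 it was trained with. The conversion script now writes `interpolation_scale` into the model config explicitly. A standalone sketch of the logic:

```python
# Mirrors the new default logic above (standalone sketch).
def resolve_interpolation_scale(interpolation_scale, sample_size):
    if interpolation_scale is not None:
        return interpolation_scale
    return max(sample_size // 64, 1)

assert resolve_interpolation_scale(None, 128) == 2   # 1024px checkpoint
assert resolve_interpolation_scale(None, 64) == 1    # 512px checkpoint
assert resolve_interpolation_scale(None, 32) == 1    # floor kicks in at 256px...
assert resolve_interpolation_scale(0.5, 32) == 0.5   # ...so 256px passes 0.5 explicitly
```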
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -133,6 +133,42 @@ ASPECT_RATIO_512_BIN = {
     "4.0": [1024.0, 256.0],
 }

+ASPECT_RATIO_256_BIN = {
+    "0.25": [128.0, 512.0],
+    "0.28": [128.0, 464.0],
+    "0.32": [144.0, 448.0],
+    "0.33": [144.0, 432.0],
+    "0.35": [144.0, 416.0],
+    "0.4": [160.0, 400.0],
+    "0.42": [160.0, 384.0],
+    "0.48": [176.0, 368.0],
+    "0.5": [176.0, 352.0],
+    "0.52": [176.0, 336.0],
+    "0.57": [192.0, 336.0],
+    "0.6": [192.0, 320.0],
+    "0.68": [208.0, 304.0],
+    "0.72": [208.0, 288.0],
+    "0.78": [224.0, 288.0],
+    "0.82": [224.0, 272.0],
+    "0.88": [240.0, 272.0],
+    "0.94": [240.0, 256.0],
+    "1.0": [256.0, 256.0],
+    "1.07": [256.0, 240.0],
+    "1.13": [272.0, 240.0],
+    "1.21": [272.0, 224.0],
+    "1.29": [288.0, 224.0],
+    "1.38": [288.0, 208.0],
+    "1.46": [304.0, 208.0],
+    "1.67": [320.0, 192.0],
+    "1.75": [336.0, 192.0],
+    "2.0": [352.0, 176.0],
+    "2.09": [368.0, 176.0],
+    "2.4": [384.0, 160.0],
+    "2.5": [400.0, 160.0],
+    "3.0": [432.0, 144.0],
+    "4.0": [512.0, 128.0],
+}
+

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
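Each bin maps an aspect-ratio key to a `[height, width]` pair whose area stays near 256x256. A sketch of how a requested size is snapped to its nearest bin, in the spirit of `classify_height_width_bin` (not the pipeline's exact implementation):

```python
# Snap a requested (height, width) to the nearest aspect-ratio bin.
ASPECT_RATIO_256_BIN = {"0.5": [176.0, 352.0], "1.0": [256.0, 256.0], "2.0": [352.0, 176.0]}  # abridged

def bin_height_width(height, width, ratios):
    aspect_ratio = height / width
    closest = min(ratios, key=lambda r: abs(float(r) - aspect_ratio))
    binned_height, binned_width = ratios[closest]
    return int(binned_height), int(binned_width)

print(bin_height_width(400, 200, ASPECT_RATIO_256_BIN))  # (352, 176): ratio 2.0
```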
@@ -260,6 +296,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         prompt_attention_mask: Optional[torch.FloatTensor] = None,
         negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
+        max_sequence_length: int = 120,
         **kwargs,
     ):
         r"""
@@ -284,8 +321,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the ""
                 string.
-            clean_caption (bool, defaults to `False`):
+            clean_caption (`bool`, defaults to `False`):
                 If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
         """

         if "mask_feature" in kwargs:
@@ -303,7 +341,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         # See Section 3.1. of the paper.
-        max_length = 120
+        max_length = max_sequence_length

         if prompt_embeds is None:
             prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
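The literal 120 (the T5 token budget from Section 3.1 of the PixArt-alpha paper) now comes from the new `max_sequence_length` argument. A hedged sketch of how the value reaches the tokenizer, using the public 1024px repo for illustration rather than the pipeline's exact code:

```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
inputs = tokenizer(
    "A small cactus with a happy face in the Sahara desert.",
    padding="max_length",
    max_length=120,  # now driven by `max_sequence_length` instead of a hard-coded literal
    truncation=True,
    return_tensors="pt",
)
print(inputs.input_ids.shape)  # torch.Size([1, 120])
```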
@@ -688,6 +726,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
+        max_sequence_length: int = 120,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -757,6 +796,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                 If set to `True`, the requested height and width are first mapped to the closest resolutions using
                 `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                 the requested resolution. Useful for generating non-square images.
+            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use with the `prompt`.

         Examples:
@@ -772,9 +812,14 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
         if use_resolution_binning:
-            aspect_ratio_bin = (
-                ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
-            )
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+            elif self.transformer.config.sample_size == 64:
+                aspect_ratio_bin = ASPECT_RATIO_512_BIN
+            elif self.transformer.config.sample_size == 32:
+                aspect_ratio_bin = ASPECT_RATIO_256_BIN
+            else:
+                raise ValueError("Invalid sample size")
             orig_height, orig_width = height, width
             height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
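The branch keys follow from the VAE's 8x downsampling: `sample_size` is the latent grid size, so 128 corresponds to 1024px, 64 to 512px, and the newly handled 32 to 256px. A one-liner to see the mapping:

```python
# sample_size is the latent grid size; the PixArt VAE downsamples 8x.
for sample_size, name in [(128, "ASPECT_RATIO_1024_BIN"), (64, "ASPECT_RATIO_512_BIN"), (32, "ASPECT_RATIO_256_BIN")]:
    print(f"sample_size {sample_size} -> {sample_size * 8}px -> {name}")
```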
@@ -822,6 +867,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             clean_caption=clean_caption,
+            max_sequence_length=max_sequence_length,
         )
         if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
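Putting it together, 256px inference now works end to end. A hedged usage sketch: the 256px repo id below is an assumption, so substitute your converted checkpoint's path or Hub id; the call pattern itself matches this PR.

```python
import torch
from diffusers import PixArtAlphaPipeline

# Hypothetical 256px repo id; point this at your converted checkpoint.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-256x256", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "A small cactus with a happy face in the Sahara desert.",
    use_resolution_binning=True,  # snaps to ASPECT_RATIO_256_BIN (sample_size 32)
    max_sequence_length=120,      # new call argument introduced by this PR
).images[0]
image.save("cactus_256.png")
```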