mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-20 07:28:13 +08:00
Compare commits
26 Commits
make-tiny-
...
chroma-fix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c5ac3df99 | ||
|
|
0d38346a55 | ||
|
|
fe877d760d | ||
|
|
8cafe5e316 | ||
|
|
7cd4e26ef5 | ||
|
|
3b3624b8ed | ||
|
|
57df5f9234 | ||
|
|
79b55007ef | ||
|
|
1999bffda8 | ||
|
|
8cc7de7c79 | ||
|
|
acc1a49250 | ||
|
|
414de99853 | ||
|
|
43d041adf4 | ||
|
|
03165b9269 | ||
|
|
544dad4c25 | ||
|
|
7cdd7d2df0 | ||
|
|
172b2ef73b | ||
|
|
d74985c160 | ||
|
|
ad13450cfe | ||
|
|
602af7411e | ||
|
|
188b0d2a2f | ||
|
|
9019e92899 | ||
|
|
6ac443d5f5 | ||
|
|
8bdb806816 | ||
|
|
96910d0a22 | ||
|
|
f6501cabb0 |
@@ -353,6 +353,7 @@ else:
|
|||||||
"AuraFlowPipeline",
|
"AuraFlowPipeline",
|
||||||
"BlipDiffusionControlNetPipeline",
|
"BlipDiffusionControlNetPipeline",
|
||||||
"BlipDiffusionPipeline",
|
"BlipDiffusionPipeline",
|
||||||
|
"ChromaImg2ImgPipeline",
|
||||||
"ChromaPipeline",
|
"ChromaPipeline",
|
||||||
"CLIPImageProjection",
|
"CLIPImageProjection",
|
||||||
"CogVideoXFunControlPipeline",
|
"CogVideoXFunControlPipeline",
|
||||||
@@ -945,6 +946,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
AudioLDM2UNet2DConditionModel,
|
AudioLDM2UNet2DConditionModel,
|
||||||
AudioLDMPipeline,
|
AudioLDMPipeline,
|
||||||
AuraFlowPipeline,
|
AuraFlowPipeline,
|
||||||
|
ChromaImg2ImgPipeline,
|
||||||
ChromaPipeline,
|
ChromaPipeline,
|
||||||
CLIPImageProjection,
|
CLIPImageProjection,
|
||||||
CogVideoXFunControlPipeline,
|
CogVideoXFunControlPipeline,
|
||||||
|
|||||||
@@ -2543,7 +2543,9 @@ class FusedFluxAttnProcessor2_0:
|
|||||||
query = apply_rotary_emb(query, image_rotary_emb)
|
query = apply_rotary_emb(query, image_rotary_emb)
|
||||||
key = apply_rotary_emb(key, image_rotary_emb)
|
key = apply_rotary_emb(key, image_rotary_emb)
|
||||||
|
|
||||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
@@ -2776,7 +2778,9 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
|||||||
query = apply_rotary_emb(query, image_rotary_emb)
|
query = apply_rotary_emb(query, image_rotary_emb)
|
||||||
key = apply_rotary_emb(key, image_rotary_emb)
|
key = apply_rotary_emb(key, image_rotary_emb)
|
||||||
|
|
||||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -250,15 +250,21 @@ class ChromaSingleTransformerBlock(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
temb: torch.Tensor,
|
temb: torch.Tensor,
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
|
||||||
|
|
||||||
attn_output = self.attn(
|
attn_output = self.attn(
|
||||||
hidden_states=norm_hidden_states,
|
hidden_states=norm_hidden_states,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
attention_mask=attention_mask,
|
||||||
**joint_attention_kwargs,
|
**joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -312,6 +318,7 @@ class ChromaTransformerBlock(nn.Module):
|
|||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
temb: torch.Tensor,
|
temb: torch.Tensor,
|
||||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
|
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
|
||||||
@@ -321,11 +328,15 @@ class ChromaTransformerBlock(nn.Module):
|
|||||||
encoder_hidden_states, emb=temb_txt
|
encoder_hidden_states, emb=temb_txt
|
||||||
)
|
)
|
||||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
|
||||||
|
|
||||||
# Attention.
|
# Attention.
|
||||||
attention_outputs = self.attn(
|
attention_outputs = self.attn(
|
||||||
hidden_states=norm_hidden_states,
|
hidden_states=norm_hidden_states,
|
||||||
encoder_hidden_states=norm_encoder_hidden_states,
|
encoder_hidden_states=norm_encoder_hidden_states,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
attention_mask=attention_mask,
|
||||||
**joint_attention_kwargs,
|
**joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -570,6 +581,7 @@ class ChromaTransformer2DModel(
|
|||||||
timestep: torch.LongTensor = None,
|
timestep: torch.LongTensor = None,
|
||||||
img_ids: torch.Tensor = None,
|
img_ids: torch.Tensor = None,
|
||||||
txt_ids: torch.Tensor = None,
|
txt_ids: torch.Tensor = None,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
controlnet_block_samples=None,
|
controlnet_block_samples=None,
|
||||||
controlnet_single_block_samples=None,
|
controlnet_single_block_samples=None,
|
||||||
@@ -659,11 +671,7 @@ class ChromaTransformer2DModel(
|
|||||||
)
|
)
|
||||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||||
block,
|
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
|
||||||
hidden_states,
|
|
||||||
encoder_hidden_states,
|
|
||||||
temb,
|
|
||||||
image_rotary_emb,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -672,6 +680,7 @@ class ChromaTransformer2DModel(
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
attention_mask=attention_mask,
|
||||||
joint_attention_kwargs=joint_attention_kwargs,
|
joint_attention_kwargs=joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -704,6 +713,7 @@ class ChromaTransformer2DModel(
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
attention_mask=attention_mask,
|
||||||
joint_attention_kwargs=joint_attention_kwargs,
|
joint_attention_kwargs=joint_attention_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ else:
|
|||||||
"AudioLDM2UNet2DConditionModel",
|
"AudioLDM2UNet2DConditionModel",
|
||||||
]
|
]
|
||||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||||
_import_structure["chroma"] = ["ChromaPipeline"]
|
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
|
||||||
_import_structure["cogvideo"] = [
|
_import_structure["cogvideo"] = [
|
||||||
"CogVideoXPipeline",
|
"CogVideoXPipeline",
|
||||||
"CogVideoXImageToVideoPipeline",
|
"CogVideoXImageToVideoPipeline",
|
||||||
@@ -537,7 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
)
|
)
|
||||||
from .aura_flow import AuraFlowPipeline
|
from .aura_flow import AuraFlowPipeline
|
||||||
from .blip_diffusion import BlipDiffusionPipeline
|
from .blip_diffusion import BlipDiffusionPipeline
|
||||||
from .chroma import ChromaPipeline
|
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||||
from .cogvideo import (
|
from .cogvideo import (
|
||||||
CogVideoXFunControlPipeline,
|
CogVideoXFunControlPipeline,
|
||||||
CogVideoXImageToVideoPipeline,
|
CogVideoXImageToVideoPipeline,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
else:
|
else:
|
||||||
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
|
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
|
||||||
|
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
|
||||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
try:
|
try:
|
||||||
if not (is_transformers_available() and is_torch_available()):
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
@@ -31,6 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
else:
|
else:
|
||||||
from .pipeline_chroma import ChromaPipeline
|
from .pipeline_chroma import ChromaPipeline
|
||||||
|
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -52,12 +52,21 @@ EXAMPLE_DOC_STRING = """
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
>>> from diffusers import ChromaPipeline
|
>>> from diffusers import ChromaPipeline
|
||||||
|
|
||||||
>>> pipe = ChromaPipeline.from_single_file(
|
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||||
... "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
|
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||||
|
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
|
||||||
|
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||||
|
... "black-forest-labs/FLUX.1-schnell",
|
||||||
|
... transformer=transformer,
|
||||||
|
... text_encoder=text_encoder,
|
||||||
|
... tokenizer=tokenizer,
|
||||||
|
... torch_dtype=torch.bfloat16,
|
||||||
... )
|
... )
|
||||||
>>> pipe.to("cuda")
|
>>> pipe.enable_model_cpu_offload()
|
||||||
>>> prompt = "A cat holding a sign that says hello world"
|
>>> prompt = "A cat holding a sign that says hello world"
|
||||||
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
|
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||||
|
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
|
||||||
>>> image.save("chroma.png")
|
>>> image.save("chroma.png")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -235,6 +244,7 @@ class ChromaPipeline(
|
|||||||
|
|
||||||
dtype = self.text_encoder.dtype
|
dtype = self.text_encoder.dtype
|
||||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
attention_mask = attention_mask.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
_, seq_len, _ = prompt_embeds.shape
|
_, seq_len, _ = prompt_embeds.shape
|
||||||
|
|
||||||
@@ -242,7 +252,10 @@ class ChromaPipeline(
|
|||||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||||
|
|
||||||
return prompt_embeds
|
attention_mask = attention_mask.repeat(1, num_images_per_prompt)
|
||||||
|
attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
|
||||||
|
|
||||||
|
return prompt_embeds, attention_mask
|
||||||
|
|
||||||
def encode_prompt(
|
def encode_prompt(
|
||||||
self,
|
self,
|
||||||
@@ -250,8 +263,10 @@ class ChromaPipeline(
|
|||||||
negative_prompt: Union[str, List[str]] = None,
|
negative_prompt: Union[str, List[str]] = None,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
num_images_per_prompt: int = 1,
|
num_images_per_prompt: int = 1,
|
||||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
max_sequence_length: int = 512,
|
max_sequence_length: int = 512,
|
||||||
lora_scale: Optional[float] = None,
|
lora_scale: Optional[float] = None,
|
||||||
@@ -268,7 +283,7 @@ class ChromaPipeline(
|
|||||||
torch device
|
torch device
|
||||||
num_images_per_prompt (`int`):
|
num_images_per_prompt (`int`):
|
||||||
number of images that should be generated per prompt
|
number of images that should be generated per prompt
|
||||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
provided, text embeddings will be generated from `prompt` input argument.
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
lora_scale (`float`, *optional*):
|
lora_scale (`float`, *optional*):
|
||||||
@@ -293,7 +308,7 @@ class ChromaPipeline(
|
|||||||
batch_size = prompt_embeds.shape[0]
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
if prompt_embeds is None:
|
if prompt_embeds is None:
|
||||||
prompt_embeds = self._get_t5_prompt_embeds(
|
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
max_sequence_length=max_sequence_length,
|
max_sequence_length=max_sequence_length,
|
||||||
@@ -323,12 +338,13 @@ class ChromaPipeline(
|
|||||||
" the batch size of `prompt`."
|
" the batch size of `prompt`."
|
||||||
)
|
)
|
||||||
|
|
||||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
||||||
prompt=negative_prompt,
|
prompt=negative_prompt,
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
max_sequence_length=max_sequence_length,
|
max_sequence_length=max_sequence_length,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
if self.text_encoder is not None:
|
if self.text_encoder is not None:
|
||||||
@@ -336,7 +352,14 @@ class ChromaPipeline(
|
|||||||
# Retrieve the original scale by scaling back the LoRA layers
|
# Retrieve the original scale by scaling back the LoRA layers
|
||||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||||
|
|
||||||
return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
|
return (
|
||||||
|
prompt_embeds,
|
||||||
|
text_ids,
|
||||||
|
prompt_attention_mask,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
negative_text_ids,
|
||||||
|
negative_prompt_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
||||||
def encode_image(self, image, device, num_images_per_prompt):
|
def encode_image(self, image, device, num_images_per_prompt):
|
||||||
@@ -394,7 +417,9 @@ class ChromaPipeline(
|
|||||||
width,
|
width,
|
||||||
negative_prompt=None,
|
negative_prompt=None,
|
||||||
prompt_embeds=None,
|
prompt_embeds=None,
|
||||||
|
prompt_attention_mask=None,
|
||||||
negative_prompt_embeds=None,
|
negative_prompt_embeds=None,
|
||||||
|
negative_prompt_attention_mask=None,
|
||||||
callback_on_step_end_tensor_inputs=None,
|
callback_on_step_end_tensor_inputs=None,
|
||||||
max_sequence_length=None,
|
max_sequence_length=None,
|
||||||
):
|
):
|
||||||
@@ -428,6 +453,14 @@ class ChromaPipeline(
|
|||||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||||
|
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
|
||||||
|
|
||||||
|
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
|
||||||
|
)
|
||||||
|
|
||||||
if max_sequence_length is not None and max_sequence_length > 512:
|
if max_sequence_length is not None and max_sequence_length > 512:
|
||||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||||
|
|
||||||
@@ -534,6 +567,25 @@ class ChromaPipeline(
|
|||||||
|
|
||||||
return latents, latent_image_ids
|
return latents, latent_image_ids
|
||||||
|
|
||||||
|
def _prepare_attention_mask(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
sequence_length,
|
||||||
|
dtype,
|
||||||
|
attention_mask=None,
|
||||||
|
):
|
||||||
|
if attention_mask is None:
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
# Extend the prompt attention mask to account for image tokens in the final sequence
|
||||||
|
attention_mask = torch.cat(
|
||||||
|
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask.to(dtype)
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guidance_scale(self):
|
def guidance_scale(self):
|
||||||
return self._guidance_scale
|
return self._guidance_scale
|
||||||
@@ -566,18 +618,20 @@ class ChromaPipeline(
|
|||||||
negative_prompt: Union[str, List[str]] = None,
|
negative_prompt: Union[str, List[str]] = None,
|
||||||
height: Optional[int] = None,
|
height: Optional[int] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
num_inference_steps: int = 28,
|
num_inference_steps: int = 35,
|
||||||
sigmas: Optional[List[float]] = None,
|
sigmas: Optional[List[float]] = None,
|
||||||
guidance_scale: float = 3.5,
|
guidance_scale: float = 5.0,
|
||||||
num_images_per_prompt: Optional[int] = 1,
|
num_images_per_prompt: Optional[int] = 1,
|
||||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||||
latents: Optional[torch.FloatTensor] = None,
|
latents: Optional[torch.Tensor] = None,
|
||||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||||
|
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
@@ -618,11 +672,11 @@ class ChromaPipeline(
|
|||||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||||
to make generation deterministic.
|
to make generation deterministic.
|
||||||
latents (`torch.FloatTensor`, *optional*):
|
latents (`torch.Tensor`, *optional*):
|
||||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
tensor will ge generated by sampling using the supplied random `generator`.
|
tensor will ge generated by sampling using the supplied random `generator`.
|
||||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||||
provided, text embeddings will be generated from `prompt` input argument.
|
provided, text embeddings will be generated from `prompt` input argument.
|
||||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||||
@@ -636,10 +690,18 @@ class ChromaPipeline(
|
|||||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||||
argument.
|
argument.
|
||||||
|
prompt_attention_mask (torch.Tensor, *optional*):
|
||||||
|
Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
|
||||||
|
Chroma requires a single padding token remain unmasked. Please refer to
|
||||||
|
https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
|
||||||
|
negative_prompt_attention_mask (torch.Tensor, *optional*):
|
||||||
|
Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
|
||||||
|
prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to
|
||||||
|
https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
|
||||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
The output format of the generate image. Choose between
|
The output format of the generate image. Choose between
|
||||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
@@ -678,7 +740,9 @@ class ChromaPipeline(
|
|||||||
width,
|
width,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
prompt_embeds=prompt_embeds,
|
prompt_embeds=prompt_embeds,
|
||||||
|
prompt_attention_mask=prompt_attention_mask,
|
||||||
negative_prompt_embeds=negative_prompt_embeds,
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||||
max_sequence_length=max_sequence_length,
|
max_sequence_length=max_sequence_length,
|
||||||
)
|
)
|
||||||
@@ -704,13 +768,17 @@ class ChromaPipeline(
|
|||||||
(
|
(
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
text_ids,
|
text_ids,
|
||||||
|
prompt_attention_mask,
|
||||||
negative_prompt_embeds,
|
negative_prompt_embeds,
|
||||||
negative_text_ids,
|
negative_text_ids,
|
||||||
|
negative_prompt_attention_mask,
|
||||||
) = self.encode_prompt(
|
) = self.encode_prompt(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
prompt_embeds=prompt_embeds,
|
prompt_embeds=prompt_embeds,
|
||||||
negative_prompt_embeds=negative_prompt_embeds,
|
negative_prompt_embeds=negative_prompt_embeds,
|
||||||
|
prompt_attention_mask=prompt_attention_mask,
|
||||||
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||||
device=device,
|
device=device,
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
@@ -730,6 +798,7 @@ class ChromaPipeline(
|
|||||||
generator,
|
generator,
|
||||||
latents,
|
latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Prepare timesteps
|
# 5. Prepare timesteps
|
||||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||||
image_seq_len = latents.shape[1]
|
image_seq_len = latents.shape[1]
|
||||||
@@ -740,6 +809,20 @@ class ChromaPipeline(
|
|||||||
self.scheduler.config.get("base_shift", 0.5),
|
self.scheduler.config.get("base_shift", 0.5),
|
||||||
self.scheduler.config.get("max_shift", 1.15),
|
self.scheduler.config.get("max_shift", 1.15),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attention_mask = self._prepare_attention_mask(
|
||||||
|
batch_size=latents.shape[0],
|
||||||
|
sequence_length=image_seq_len,
|
||||||
|
dtype=latents.dtype,
|
||||||
|
attention_mask=prompt_attention_mask,
|
||||||
|
)
|
||||||
|
negative_attention_mask = self._prepare_attention_mask(
|
||||||
|
batch_size=latents.shape[0],
|
||||||
|
sequence_length=image_seq_len,
|
||||||
|
dtype=latents.dtype,
|
||||||
|
attention_mask=negative_prompt_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
timesteps, num_inference_steps = retrieve_timesteps(
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
self.scheduler,
|
self.scheduler,
|
||||||
num_inference_steps,
|
num_inference_steps,
|
||||||
@@ -801,6 +884,7 @@ class ChromaPipeline(
|
|||||||
encoder_hidden_states=prompt_embeds,
|
encoder_hidden_states=prompt_embeds,
|
||||||
txt_ids=text_ids,
|
txt_ids=text_ids,
|
||||||
img_ids=latent_image_ids,
|
img_ids=latent_image_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
@@ -814,6 +898,7 @@ class ChromaPipeline(
|
|||||||
encoder_hidden_states=negative_prompt_embeds,
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
txt_ids=negative_text_ids,
|
txt_ids=negative_text_ids,
|
||||||
img_ids=latent_image_ids,
|
img_ids=latent_image_ids,
|
||||||
|
attention_mask=negative_attention_mask,
|
||||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
|
|||||||
1039
src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
Normal file
1039
src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -272,6 +272,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
|
|||||||
requires_backends(cls, ["torch", "transformers"])
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaImg2ImgPipeline(metaclass=DummyObject):
    """Placeholder standing in for `ChromaImg2ImgPipeline`.

    Construction and both loader classmethods do nothing except hand the
    object (or class) together with the backend list to `requires_backends`.
    NOTE(review): `requires_backends` and `DummyObject` are defined outside
    this view; presumably they report the missing "torch"/"transformers"
    backends to the user -- confirm against their definitions.
    """

    # Backends this placeholder asks `requires_backends` to check.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored; only the backend check runs.
        requires_backends(self, self._backends)

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Run the backend check for the class; arguments are ignored."""
        requires_backends(cls, cls._backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Run the backend check for the class; arguments are ignored."""
        requires_backends(cls, cls._backends)
|
||||||
|
|
||||||
|
|
||||||
class ChromaPipeline(metaclass=DummyObject):
|
class ChromaPipeline(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
|||||||
170
tests/pipelines/chroma/test_pipeline_chroma_img2img.py
Normal file
170
tests/pipelines/chroma/test_pipeline_chroma_img2img.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, T5EncoderModel
|
||||||
|
|
||||||
|
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||||
|
from diffusers.utils.testing_utils import floats_tensor, torch_device
|
||||||
|
|
||||||
|
from ..test_pipelines_common import (
|
||||||
|
FluxIPAdapterTesterMixin,
|
||||||
|
PipelineTesterMixin,
|
||||||
|
check_qkv_fusion_matches_attn_procs_length,
|
||||||
|
check_qkv_fusion_processors_exist,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaImg2ImgPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
):
    """Fast tests for `ChromaImg2ImgPipeline` built from tiny components.

    The shared mixins consume `pipeline_class`, `params`, `batch_params`, the
    `test_*` switches, and the `get_dummy_components` / `get_dummy_inputs`
    factories defined here.
    """

    pipeline_class = ChromaImg2ImgPipeline
    params = frozenset({"prompt", "height", "width", "guidance_scale", "prompt_embeds"})
    batch_params = frozenset({"prompt"})

    # xformers checks are switched off: there is no xformers processor for Flux.
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        """Build the tiny pipeline components.

        The torch RNG is re-seeded with 0 immediately before the transformer,
        the text encoder and the VAE are created.

        Args:
            num_layers: forwarded to `ChromaTransformer2DModel(num_layers=...)`.
            num_single_layers: forwarded to
                `ChromaTransformer2DModel(num_single_layers=...)`.

        Returns:
            A dict of constructor keyword arguments for `pipeline_class`.
        """
        tiny_t5_repo = "hf-internal-testing/tiny-random-t5"

        transformer_config = {
            "patch_size": 1,
            "in_channels": 4,
            "num_layers": num_layers,
            "num_single_layers": num_single_layers,
            "attention_head_dim": 16,
            "num_attention_heads": 2,
            "joint_attention_dim": 32,
            "axes_dims_rope": [4, 4, 8],
            "approximator_hidden_dim": 32,
            "approximator_layers": 1,
            "approximator_num_channels": 16,
        }
        vae_config = {
            "sample_size": 32,
            "in_channels": 3,
            "out_channels": 3,
            "block_out_channels": (4,),
            "layers_per_block": 1,
            "latent_channels": 1,
            "norm_num_groups": 1,
            "use_quant_conv": False,
            "use_post_quant_conv": False,
            "shift_factor": 0.0609,
            "scaling_factor": 1.5035,
        }

        # Seed right before each model is instantiated, in this exact order.
        torch.manual_seed(0)
        transformer = ChromaTransformer2DModel(**transformer_config)

        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained(tiny_t5_repo)
        tokenizer = AutoTokenizer.from_pretrained(tiny_t5_repo)

        torch.manual_seed(0)
        vae = AutoencoderKL(**vae_config)

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
            "image_encoder": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return call kwargs for the pipeline: a seeded 1x3x32x32 image plus generator.

        Args:
            device: device the dummy image is moved to; a device string starting
                with "mps" selects the global torch generator instead of a CPU one.
            seed: seed for both the image's `random.Random` and the torch generator.
        """
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        generator = (
            torch.manual_seed(seed)
            if str(device).startswith("mps")
            else torch.Generator(device="cpu").manual_seed(seed)
        )

        return {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 48,
            "strength": 0.8,
            "output_type": "np",
        }

    def test_chroma_different_prompts(self):
        """Changing only the prompt must change the generated image."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        baseline_inputs = self.get_dummy_inputs(torch_device)
        baseline_image = pipe(**baseline_inputs).images[0]

        altered_inputs = self.get_dummy_inputs(torch_device)
        altered_inputs["prompt"] = "a different prompt"
        altered_image = pipe(**altered_inputs).images[0]

        max_diff = np.abs(baseline_image - altered_image).max()

        # The outputs should differ, but for some reason the difference is
        # small, hence the very low threshold.
        assert max_diff > 1e-6

    def test_fused_qkv_projections(self):
        """Fusing and then unfusing the QKV projections must leave outputs unchanged."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        pipe.set_progress_bar_config(disable=None)

        def corner_slice():
            # Fresh inputs on every run, so each run gets a newly seeded generator.
            images = pipe(**self.get_dummy_inputs(device)).images
            return images[0, -3:, -3:, -1]

        original_image_slice = corner_slice()

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        image_slice_fused = corner_slice()

        pipe.transformer.unfuse_qkv_projections()
        image_slice_disabled = corner_slice()

        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
            "Fusion of QKV projections shouldn't affect the outputs."
        )
        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        )
        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
            "Original outputs should match when fused QKV projections are disabled."
        )

    def test_chroma_image_output_shape(self):
        """Output height/width equal the request rounded down to a multiple of `vae_scale_factor * 2`."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        # The same inputs dict (and therefore the same generator) is reused for every size.
        inputs = self.get_dummy_inputs(torch_device)
        multiple = pipe.vae_scale_factor * 2

        for height, width in [(32, 32), (72, 57)]:
            expected_height = height - height % multiple
            expected_width = width - width % multiple

            inputs["height"] = height
            inputs["width"] = width
            image = pipe(**inputs).images[0]

            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)
|
||||||
Reference in New Issue
Block a user