Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 21:14:44 +08:00)

Compare commits: custom-cod...chroma-fix

26 Commits
Commits in this compare (SHA1):

- 4c5ac3df99
- 0d38346a55
- fe877d760d
- 8cafe5e316
- 7cd4e26ef5
- 3b3624b8ed
- 57df5f9234
- 79b55007ef
- 1999bffda8
- 8cc7de7c79
- acc1a49250
- 414de99853
- 43d041adf4
- 03165b9269
- 544dad4c25
- 7cdd7d2df0
- 172b2ef73b
- d74985c160
- ad13450cfe
- 602af7411e
- 188b0d2a2f
- 9019e92899
- 6ac443d5f5
- 8bdb806816
- 96910d0a22
- f6501cabb0
@@ -353,6 +353,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
+ "ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
@@ -945,6 +946,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
+ ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
@@ -2543,7 +2543,9 @@ class FusedFluxAttnProcessor2_0:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -2776,7 +2778,9 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
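These two processor hunks thread the new `attention_mask` argument into `F.scaled_dot_product_attention` as `attn_mask`. A minimal, self-contained sketch of what that call does (all shapes and values below are illustrative, not taken from the diff):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 6, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# Boolean mask broadcastable to (batch, heads, q_len, kv_len):
# True = keep, False = do not attend. Here the last two key positions are padding.
attn_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.bool)
attn_mask[..., -2:] = False

out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```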
@@ -250,15 +250,21 @@ class ChromaSingleTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
**joint_attention_kwargs,
)
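The `attention_mask[:, None, None, :] * attention_mask[:, None, :, None]` line expands a per-token 1D keep/drop mask into a pairwise 2D mask, so padding tokens neither attend nor get attended to. A toy illustration with made-up values:

```python
import torch

# Per-token mask for one sequence of length 5: 1.0 = real token, 0.0 = padding (illustrative).
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])

# Outer product via broadcasting -> shape (batch, 1, seq_len, seq_len);
# entry (i, j) is 1 only when both token i and token j are unmasked.
pairwise_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

print(pairwise_mask.shape)  # torch.Size([1, 1, 5, 5])
print(pairwise_mask[0, 0])  # ones in the top-left 3x3 block, zeros elsewhere
```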
@@ -312,6 +318,7 @@ class ChromaTransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@ class ChromaTransformerBlock(nn.Module):
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
**joint_attention_kwargs,
)

@@ -570,6 +581,7 @@ class ChromaTransformer2DModel(
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
+ attention_mask: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@ class ChromaTransformer2DModel(
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
- block,
- hidden_states,
- encoder_hidden_states,
- temb,
- image_rotary_emb,
+ block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
)

else:
@@ -672,6 +680,7 @@ class ChromaTransformer2DModel(
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)

@@ -704,6 +713,7 @@ class ChromaTransformer2DModel(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
joint_attention_kwargs=joint_attention_kwargs,
)
@@ -148,7 +148,7 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
- _import_structure["chroma"] = ["ChromaPipeline"]
+ _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
@@ -537,7 +537,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
- from .chroma import ChromaPipeline
+ from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
+ _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_chroma import ChromaPipeline
+ from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
else:
import sys
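The two edits in this `__init__` keep the lazy-import bookkeeping and the eager (TYPE_CHECKING / slow-import) branch in sync. A rough, simplified sketch of how such a module-level lazy re-export can work (diffusers actually uses its `_LazyModule` helper; the `__getattr__` below only illustrates the idea):

```python
import importlib

# Maps submodule name -> public symbols it provides, mirroring the diff above.
_import_structure = {
    "pipeline_chroma": ["ChromaPipeline"],
    "pipeline_chroma_img2img": ["ChromaImg2ImgPipeline"],
}
_symbol_to_module = {name: mod for mod, names in _import_structure.items() for name in names}


def __getattr__(name):
    # Import the submodule only when one of its symbols is first accessed.
    if name in _symbol_to_module:
        module = importlib.import_module(f".{_symbol_to_module[name]}", __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```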
@@ -1,4 +1,4 @@
- # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -52,12 +52,21 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline

- >>> pipe = ChromaPipeline.from_single_file(
- ... "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
+ >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+ >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+ >>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
+ >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-schnell",
+ ... transformer=transformer,
+ ... text_encoder=text_encoder,
+ ... tokenizer=tokenizer,
+ ... torch_dtype=torch.bfloat16,
... )
- >>> pipe.to("cuda")
+ >>> pipe.enable_model_cpu_offload()
>>> prompt = "A cat holding a sign that says hello world"
- >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+ >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
+ >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma.png")
```
"""
@@ -235,6 +244,7 @@ class ChromaPipeline(

dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ attention_mask = attention_mask.to(dtype=dtype, device=device)

_, seq_len, _ = prompt_embeds.shape

@@ -242,7 +252,10 @@ class ChromaPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

- return prompt_embeds
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
+ attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, attention_mask

def encode_prompt(
self,
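With this change `_get_t5_prompt_embeds` returns the tokenizer's attention mask alongside the embeddings, and both are duplicated the same way for `num_images_per_prompt`. A standalone shape check of that repeat/view pattern (sizes are made up):

```python
import torch

batch_size, seq_len, dim, num_images_per_prompt = 2, 5, 8, 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)
attention_mask = torch.ones(batch_size, seq_len)

# Duplicate each prompt's embeddings and mask once per requested image.
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

attention_mask = attention_mask.repeat(1, num_images_per_prompt)
attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)

print(prompt_embeds.shape)   # torch.Size([6, 5, 8])
print(attention_mask.shape)  # torch.Size([6, 5])
```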
@@ -250,8 +263,10 @@ class ChromaPipeline(
negative_prompt: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
do_classifier_free_guidance: bool = True,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
@@ -268,7 +283,7 @@ class ChromaPipeline(
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
@@ -293,7 +308,7 @@ class ChromaPipeline(
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
- prompt_embeds = self._get_t5_prompt_embeds(
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
@@ -323,12 +338,13 @@ class ChromaPipeline(
" the batch size of `prompt`."
)

- negative_prompt_embeds = self._get_t5_prompt_embeds(
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)

negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

if self.text_encoder is not None:
@@ -336,7 +352,14 @@ class ChromaPipeline(
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

- return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+ return (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ )

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
@@ -394,7 +417,9 @@ class ChromaPipeline(
width,
negative_prompt=None,
prompt_embeds=None,
+ prompt_attention_mask=None,
negative_prompt_embeds=None,
+ negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -428,6 +453,14 @@ class ChromaPipeline(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError(
+ "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
+ )
+
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

@@ -534,6 +567,25 @@ class ChromaPipeline(

return latents, latent_image_ids

+ def _prepare_attention_mask(
+ self,
+ batch_size,
+ sequence_length,
+ dtype,
+ attention_mask=None,
+ ):
+ if attention_mask is None:
+ return attention_mask
+
+ # Extend the prompt attention mask to account for image tokens in the final sequence
+ attention_mask = torch.cat(
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+ dim=1,
+ )
+ attention_mask = attention_mask.to(dtype)
+
+ return attention_mask
+
@property
def guidance_scale(self):
return self._guidance_scale
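`_prepare_attention_mask` concatenates the text mask with a block of ones covering the image tokens, so the mask matches the joint text+image sequence the transformer sees. A tiny illustration with made-up sizes:

```python
import torch

prompt_attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # 4 text tokens, last one padding (illustrative)
image_seq_len = 6                                              # number of latent image tokens (illustrative)

# Image tokens are always kept, so the text mask is extended with ones on the right.
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(1, image_seq_len)],
    dim=1,
)
print(attention_mask)        # tensor([[1., 1., 1., 0., 1., 1., 1., 1., 1., 1.]])
print(attention_mask.shape)  # torch.Size([1, 10])
```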
@@ -566,18 +618,20 @@ class ChromaPipeline(
negative_prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
- num_inference_steps: int = 28,
+ num_inference_steps: int = 35,
sigmas: Optional[List[float]] = None,
- guidance_scale: float = 3.5,
+ guidance_scale: float = 5.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -618,11 +672,11 @@ class ChromaPipeline(
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
+ latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
+ prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
@@ -636,10 +690,18 @@ class ChromaPipeline(
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
+ prompt_attention_mask (torch.Tensor, *optional*):
+ Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
+ Chroma requires a single padding token remain unmasked. Please refer to
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+ negative_prompt_attention_mask (torch.Tensor, *optional*):
+ Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
+ prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -678,7 +740,9 @@ class ChromaPipeline(
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -704,13 +768,17 @@ class ChromaPipeline(
(
prompt_embeds,
text_ids,
+ prompt_attention_mask,
negative_prompt_embeds,
negative_text_ids,
+ negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
do_classifier_free_guidance=self.do_classifier_free_guidance,
device=device,
num_images_per_prompt=num_images_per_prompt,
@@ -730,6 +798,7 @@ class ChromaPipeline(
generator,
latents,
)

# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
@@ -740,6 +809,20 @@ class ChromaPipeline(
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)

+ attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=prompt_attention_mask,
+ )
+ negative_attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=negative_prompt_attention_mask,
+ )

timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
@@ -801,6 +884,7 @@ class ChromaPipeline(
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
+ attention_mask=attention_mask,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
@@ -814,6 +898,7 @@ class ChromaPipeline(
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
+ attention_mask=negative_attention_mask,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py (new file, 1039 lines)
File diff suppressed because it is too large
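The new img2img pipeline file itself is suppressed above, so here is a hedged usage sketch assembled from the docstring example earlier in this compare and the dummy inputs in the test below; the init-image path and the `strength` value are assumptions, not taken from the suppressed file:

```python
import torch
from transformers import AutoModel, AutoTokenizer

from diffusers import ChromaImg2ImgPipeline, ChromaTransformer2DModel
from diffusers.utils import load_image

ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")

pipe = ChromaImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

init_image = load_image("input.png")  # any RGB image; the path is a placeholder
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, image=init_image, strength=0.8).images[0]
image.save("chroma_img2img.png")
```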
@@ -272,6 +272,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])


+ class ChromaImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class ChromaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
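This dummy class is what diffusers exposes when `torch` or `transformers` is missing: importing `ChromaImg2ImgPipeline` still succeeds, but any use fails with a clear message. A simplified sketch of the mechanism (the real `DummyObject`/`requires_backends` helpers live in `diffusers.utils` and differ in detail):

```python
class DummyObject(type):
    """Metaclass for placeholder classes that raise as soon as they are used (simplified)."""

    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires the backends {getattr(cls, '_backends', [])}, which are not installed."
        )


class ChromaImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]


# Importing the name works, but instantiation fails fast with an informative error:
try:
    ChromaImg2ImgPipeline()
except ImportError as err:
    print(err)
```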
tests/pipelines/chroma/test_pipeline_chroma_img2img.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import random
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)


class ChromaImg2ImgPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
):
    pipeline_class = ChromaImg2ImgPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = ChromaTransformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=32,
            axes_dims_rope=[4, 4, 8],
            approximator_hidden_dim=32,
            approximator_layers=1,
            approximator_num_channels=16,
        )

        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
            "image_encoder": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 48,
            "strength": 0.8,
            "output_type": "np",
        }
        return inputs

    def test_chroma_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = "a different prompt"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()

        # Outputs should be different here
        # For some reasons, they don't show large differences
        assert max_diff > 1e-6

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
            "Fusion of QKV projections shouldn't affect the outputs."
        )
        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        )
        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
            "Original outputs should match when fused QKV projections are disabled."
        )

    def test_chroma_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)
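The output-shape test at the end relies on the pipeline snapping the requested height/width down to a multiple of `vae_scale_factor * 2`. The arithmetic in isolation (the scale-factor value here is an assumption for illustration):

```python
vae_scale_factor = 2  # illustrative value; the real pipeline derives it from the VAE config

for height, width in [(32, 32), (72, 57)]:
    expected_height = height - height % (vae_scale_factor * 2)
    expected_width = width - width % (vae_scale_factor * 2)
    print((height, width), "->", (expected_height, expected_width))

# (32, 32) -> (32, 32)
# (72, 57) -> (72, 56)
```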