Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00

Compare commits: `onnx-cpu-d...batched-cf` (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | 2e383fc901 | |
```diff
@@ -289,80 +289,107 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin
         self,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
+        do_true_cfg: bool = False,
     ):
         r"""

         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in all text-encoders
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device

-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
+        # Set LoRA scale if applicable
         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
             if self.text_encoder is not None and USE_PEFT_BACKEND:
                 scale_lora_layers(self.text_encoder, lora_scale)
             if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                 scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)

+        if do_true_cfg and negative_prompt is not None:
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_batch_size = len(negative_prompt)
+
+            if negative_batch_size != batch_size:
+                raise ValueError(
+                    f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
+                )
+
+            # Concatenate prompts
+            prompts = prompt + negative_prompt
+            prompts_2 = (
+                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
+            )
+        else:
+            prompts = prompt
+            prompts_2 = prompt_2

         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
-            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+            if prompts_2 is None:
+                prompts_2 = prompts

-            # We only use the pooled prompt output from the CLIPTextModel
+            # Get pooled prompt embeddings from CLIPTextModel
             pooled_prompt_embeds = self._get_clip_prompt_embeds(
-                prompt=prompt,
+                prompt=prompts,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
             )
             prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
+                prompt=prompts_2,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )

+        if do_true_cfg and negative_prompt is not None:
+            # Split embeddings back into positive and negative parts
+            total_batch_size = batch_size * num_images_per_prompt
+            positive_indices = slice(0, total_batch_size)
+            negative_indices = slice(total_batch_size, 2 * total_batch_size)
+
+            positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
+            negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
+
+            positive_prompt_embeds = prompt_embeds[positive_indices]
+            negative_prompt_embeds = prompt_embeds[negative_indices]
+
+            pooled_prompt_embeds = positive_pooled_prompt_embeds
+            prompt_embeds = positive_prompt_embeds

+        # Unscale LoRA layers
         if self.text_encoder is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)

         if self.text_encoder_2 is not None:
             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)

         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

-        return prompt_embeds, pooled_prompt_embeds, text_ids
+        if do_true_cfg and negative_prompt is not None:
+            negative_text_ids = text_ids.clone()
+
+            return (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                negative_text_ids,
+            )
+        else:
+            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None, None

     def check_inputs(
         self,
```
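The hunk above routes positive and negative prompts through a single encoder pass: it concatenates the two prompt lists, encodes once, then slices the embeddings back apart. A minimal sketch of that concat-encode-split pattern follows, using a hypothetical stand-in `encode` function (the pipeline itself calls `_get_clip_prompt_embeds` and `_get_t5_prompt_embeds`):

```python
import torch

def encode(texts, dim=8):
    # Stand-in encoder: one deterministic vector per text (hypothetical).
    return torch.stack([torch.full((dim,), float(len(t))) for t in texts])

prompt = ["a cat", "a dog"]
negative_prompt = ["blurry", "low quality"]
batch_size = len(prompt)
num_images_per_prompt = 1

# Encode positives and negatives in one pass, positives first.
embeds = encode(prompt + negative_prompt)

# Split back with the same slice arithmetic as the diff.
total_batch_size = batch_size * num_images_per_prompt
positive = embeds[slice(0, total_batch_size)]
negative = embeds[slice(total_batch_size, 2 * total_batch_size)]
assert positive.shape == negative.shape == (batch_size, 8)
```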
```diff
@@ -687,38 +714,52 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
             text_ids,
+            negative_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            negative_text_ids,
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
+            do_true_cfg=do_true_cfg,
         )

         # perform "real" CFG as suggested for distilled Flux models in https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md
-        do_true_cfg = true_cfg > 1 and negative_prompt is not None
+        # do_true_cfg = true_cfg > 1 and negative_prompt is not None
+        # if do_true_cfg:
+        #     (
+        #         negative_prompt_embeds,
+        #         negative_pooled_prompt_embeds,
+        #         negative_text_ids,
+        #     ) = self.encode_prompt(
+        #         prompt=negative_prompt,
+        #         prompt_2=negative_prompt_2,
+        #         prompt_embeds=negative_prompt_embeds,
+        #         pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        #         device=device,
+        #         num_images_per_prompt=num_images_per_prompt,
+        #         max_sequence_length=max_sequence_length,
+        #         lora_scale=lora_scale,
+        #     )
         if do_true_cfg:
-            (
-                negative_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                negative_text_ids,
-            ) = self.encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                lora_scale=lora_scale,
-            )
+            # Concatenate embeddings
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+            text_ids = torch.cat([negative_text_ids, text_ids], dim=0)

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
```
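Note the ordering at this call site: `encode_prompt` builds its internal batch positives-first, but the concatenation here puts the negative embeddings first. The `chunk(2)` in the final hunk relies on exactly that negative-first layout. A small sketch of the round trip, with illustrative shapes:

```python
import torch

prompt_embeds = torch.randn(2, 512, 4096)           # positive half (shapes illustrative)
negative_prompt_embeds = torch.randn(2, 512, 4096)  # negative half

# Negative-first concatenation, as in the diff.
batched = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

# chunk(2) later recovers the halves in the same order.
neg, pos = batched.chunk(2)
assert torch.equal(neg, negative_prompt_embeds)
assert torch.equal(pos, prompt_embeds)
```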
```diff
@@ -732,6 +773,9 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin
             generator,
             latents,
         )
+        if do_true_cfg:
+            latent_image_ids = latent_image_ids.repeat(prompt_embeds.shape[0], 1, 1)
+            # expand the latents, too?

         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
```
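Once the embedding batch doubles, the positional ids fed to the transformer must match that batch dimension; `repeat` with an extra leading count prepends the batch axis to a 2-D id table. A sketch with illustrative shapes (4096 patches, batch of 4 after doubling); the diff's own question ("expand the latents, too?") is answered in the next hunk, where the latents are duplicated each step:

```python
import torch

latent_image_ids = torch.zeros(4096, 3)    # one (num_patches, 3) id table
prompt_embeds = torch.randn(4, 512, 4096)  # negative + positive halves stacked

# repeat(...) broadcasts the 2-D table across the doubled batch dimension.
latent_image_ids = latent_image_ids.repeat(prompt_embeds.shape[0], 1, 1)
assert latent_image_ids.shape == (4, 4096, 3)
```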
```diff
@@ -767,11 +811,13 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin
                 if self.interrupt:
                     continue

+                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
+
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

                 noise_pred = self.transformer(
-                    hidden_states=latents,
+                    hidden_states=latent_model_input,
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
```
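This is the standard batched-CFG trick: duplicate the latents so one transformer call covers the negative and positive conditions together, and expand the scalar timestep to the doubled batch. A minimal sketch with illustrative shapes:

```python
import torch

latents = torch.randn(2, 4096, 64)  # (batch, num_patches, channels), illustrative
do_true_cfg = True

latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
t = torch.tensor(999.0)
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

assert latent_model_input.shape[0] == 4
assert timestep.shape == (4,)
```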
```diff
@@ -783,18 +829,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin
                 )[0]

                 if do_true_cfg:
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
-
+                    neg_noise_pred, noise_pred = noise_pred.chunk(2)
                     noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
```
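With the batched pass in place, the separate negative transformer call is deleted: `chunk(2)` undoes the negative-first concatenation, and the prediction is steered as `neg + true_cfg * (pos - neg)`, i.e. plain classifier-free guidance with scale `true_cfg`. A sketch with illustrative shapes:

```python
import torch

true_cfg = 4.0
noise_pred = torch.randn(4, 4096, 64)  # [negative half | positive half]

# Undo the negative-first concatenation from the call site.
neg_noise_pred, noise_pred = noise_pred.chunk(2)

# Classifier-free guidance: extrapolate from the negative prediction
# toward the positive one, scaled by true_cfg.
noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
assert noise_pred.shape == (2, 4096, 64)
```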